Source code for gwinferno.pipeline.parser

"""
a module that stores tools for parsing CLI arguments and config files for analysis pipelines
"""

import sys
from argparse import ArgumentParser
from importlib import import_module

import jax.numpy as jnp
import yaml


[docs] class PopModel(object):
[docs] def __init__(self, model, params): self.model = model self.params = params
[docs] class PopPrior(object):
[docs] def __init__(self, dist, params): self.dist = dist self.params = params
[docs] class PopMixtureModel(PopModel):
[docs] def __init__(self, model, mix_dist, mix_params, components, component_params): self.model = model self.components = components self.mixing_dist = mix_dist self.mixing_params = mix_params self.component_params = component_params
[docs] def load_model_from_python_file(path): fn = path.split("/")[-1] direct = path.replace(f"/{fn}", "") sys.path.append(direct) return getattr(import_module(fn.replace(".py", "")), "model")
[docs] def load_dist_from_string(dist): split_d = dist.split(".") module = ".".join(split_d[:-1]) function = split_d[-1] return getattr(import_module(module), function)
[docs] class ConfigReader(object):
[docs] def __init__(self): self.models = {} self.priors = {} self.sampling_params = [] self.label = None self.outdir = None self.data_args = None self.sampler_args = None
def parse(self, yml_file): with open(yml_file, "r") as f: yml = yaml.safe_load(f) self.label = yml.pop("label", "label") self.outdir = yml.pop("outdir", "./") self.data_conf = yml.pop("data", {}) self.sampler_conf = yml.pop("sampler", {}) self.likelihood_kwargs = yml.pop("likelihood", {}) self.construct_model_and_prior_dicts(yml["models"]) def construct_model_and_prior_dicts(self, yml): if "python_file" in yml: self.models["file_path"] = yml["python_file"] else: for param in yml: if "Mixture" in yml[param]["model"]: self.add_mixture_model(param, yml[param]) else: self.add_model(param, yml[param]) def add_prior(self, key, subd): if "prior" in subd and "prior_params" in subd: for k in subd["prior_params"]: if type(subd["prior_params"][k]) is list: subd["prior_params"][k] = jnp.array(subd["prior_params"][k]) self.priors[key] = PopPrior(load_dist_from_string(subd["prior"]), subd["prior_params"]) self.sampling_params.append(key) elif "value" in subd: if type(subd["value"]) is list: self.priors[key] = jnp.array(subd["value"]) else: self.priors[key] = subd["value"] def add_model(self, param, subd): self.models[param] = PopModel(load_dist_from_string(subd["model"]), [p for p in subd["hyper_params"]]) for hp in subd["hyper_params"]: self.add_prior(f"{param}_{hp}", subd["hyper_params"][hp]) if "iid" in subd: self.add_iid_model(param, subd["iid"]["shared_parameter"]) def add_iid_model(self, param, shared_param): self.models[shared_param] = param def add_mixture_model(self, param, subd): model = load_dist_from_string(subd["model"]) mix_dist = load_dist_from_string(subd["mixture_dist"]["model"]) mix_params = [p for p in subd["mixture_dist"]["hyper_params"]] N = len(subd["mixture_dist"]["hyper_params"][mix_params[0]]["prior_params"]["concentration"]) for hp in mix_params: self.add_prior(f"{param}_mixture_dist_{hp}", subd["mixture_dist"]["hyper_params"][hp]) components = [] component_params = [] for i in range(N): name = f"component_{i + 1}" components.append(load_dist_from_string(subd[name]["model"])) component_params.append([p for p in subd[name]["hyper_params"]]) for hp in subd[name]["hyper_params"]: self.add_prior(f"{param}_component_{i + 1}_{hp}", subd[name]["hyper_params"][hp]) self.models[param] = PopMixtureModel(model, mix_dist, mix_params, components, component_params) if "iid" in subd[name]: self.add_iid_model(param, subd[name]["iid"]["shared_parameter"])
[docs] def load_base_parser(): parser = ArgumentParser() parser.add_argument("--data-dir", type=str, default="/home/bruce.edelman/projects/GWTC3_allevents/") parser.add_argument( "--inj-file", type=str, default="/home/bruce.edelman/projects/GWTC3_allevents/o1o2o3_mixture_injections.hdf5", ) parser.add_argument("--outdir", type=str, default="results") parser.add_argument("--mmin", type=float, default=3.0) parser.add_argument("--mmax", type=float, default=100.0) parser.add_argument("--chains", type=int, default=1) parser.add_argument("--samples", type=int, default=1500) parser.add_argument("--thinning", type=int, default=1) parser.add_argument("--warmup", type=int, default=500) parser.add_argument("--skip-inference", action="store_true", default=False) return parser