Source code for gwinferno.pipeline.analysis

"""
A module containing the core calculations for hierarchical population inference.
"""

from functools import partial

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import jit
from jax import random
from jax.scipy.special import logsumexp
from numpyro.infer import SVI
from numpyro.infer import Trace_ELBO
from numpyro.infer import autoguide
from numpyro.optim import Adam

from .parser import PopMixtureModel
from .parser import PopModel

NP_KERNEL_MAP = {"NUTS": numpyro.infer.NUTS, "HMC": numpyro.infer.HMC}


def find_map(rng_key, numpyro_model, *model_args, Niter=100, lr=0.01):
    """Find the MAP estimate for a given NumPyro model using SVI with Adam optimizing the ELBO.

    Parameters
    ----------
    rng_key : jax.random.PRNGKey
        RNG key to be passed to SVI.run().
    numpyro_model : callable
        Python callable containing `numpyro.primitives`.
    *model_args
        Positional arguments passed to `numpyro_model`.
    Niter : int, optional
        Number of iterations to run variational inference. Defaults to 100.
    lr : float, optional
        Learning rate used for the Adam optimizer. Defaults to 0.01.

    Returns
    -------
    SVIRunResult.params : dict
        Parameters resulting from the MAP optimization.
    """
    guide = autoguide.AutoDelta(numpyro_model)
    optimizer = Adam(lr)
    svi = SVI(numpyro_model, guide, optimizer, Trace_ELBO())
    svi_results = svi.run(rng_key, Niter, *model_args)
    return svi_results.params
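# Illustrative sketch (not part of gwinferno): using `find_map` to get a MAP estimate for a
# toy NumPyro model. The model, data, and site name below are hypothetical and exist only
# for this example.
def _example_find_map():
    def toy_model(data):
        mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
        numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

    data = jnp.array([0.8, 1.1, 0.9])
    # The returned dict holds one AutoDelta location parameter per latent site.
    map_params = find_map(random.PRNGKey(0), toy_model, data, Niter=500, lr=0.05)
    return map_params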
@partial(jit, static_argnames=["log"])
def per_event_log_bayes_factors(weights, log=False):
    r"""Calculates per-event log Bayes factors via importance sampling

    .. math::
        \mathrm{BF}_i = \int p(d_i|\theta)p(\theta|\Lambda) d\theta \approx
        \frac{1}{N_s}\sum_{j=1}^{N_s} \frac{p(\theta_j|\Lambda)}{p(\theta_j|\Lambda_\emptyset)}

    where the :math:`\theta_j` are posterior samples for event :math:`i` obtained under the
    fiducial prior :math:`p(\theta|\Lambda_\emptyset)`.

    Parameters
    ----------
    weights : jax.DeviceArray
        JAX array of weights to integrate over. Expected size of `(N_events, N_samples)`.
    log : bool, optional
        Flag to perform calculations in log probability. Interprets weights as log weights.
        This is generally more numerically stable but can break autograd with truncated models.
        Defaults to False.

    Returns
    -------
    jax.DeviceArray
        Array of per-event log Bayes factors.
    jax.DeviceArray
        Array of per-event log effective sample sizes from the Monte Carlo integrals.
    jax.DeviceArray
        Array of per-event estimated variances of the log Monte Carlo integrals.
    """
    if log:
        logweights = weights
        logBFs = logsumexp(logweights, axis=1)
        logn_effs = 2 * logBFs - logsumexp(2 * logweights, axis=1)
        logBFs -= jnp.log(logweights.shape[1])
    else:
        BFs = jnp.sum(weights, axis=1)
        n_effs = BFs**2 / jnp.sum(weights**2, axis=1)
        BFs /= weights.shape[1]
        logBFs = jnp.log(BFs)
        logn_effs = jnp.log(n_effs)
    variances = 1 / jnp.exp(logn_effs) - 1 / weights.shape[1]
    return logBFs, logn_effs, variances
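# Illustrative sketch (not part of gwinferno): the log and linear code paths of
# `per_event_log_bayes_factors` estimate the same quantities and should agree on
# well-behaved weights. The toy weight array below (2 events x 4 posterior samples)
# is hypothetical.
def _example_per_event_log_bayes_factors():
    weights = jnp.array([[0.2, 0.5, 0.1, 0.4], [1.0, 0.3, 0.7, 0.2]])
    logBFs_lin, logneff_lin, var_lin = per_event_log_bayes_factors(weights, log=False)
    logBFs_log, logneff_log, var_log = per_event_log_bayes_factors(jnp.log(weights), log=True)
    # Both routes estimate log BF_i = log(mean_j w_ij); they agree up to floating point error.
    return jnp.allclose(logBFs_lin, logBFs_log), jnp.allclose(logneff_lin, logneff_log)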
@partial(jit, static_argnames=["log"])
def detection_efficiency(weights, Ninj, log=False):
    r"""Calculates the detection efficiency -- the expected fraction of sources detected
    from a population parameterized by :math:`\Lambda` -- estimated by importance sampling
    over the found injections from a fiducial population parameterized by
    :math:`\Lambda_\emptyset`:

    .. math::
        \mu = \int P(\mathrm{det}|\theta)p(\theta|\Lambda) d\theta \approx
        \frac{1}{N_\mathrm{inj}}\sum_{i=1}^{N_\mathrm{found}}
        \frac{p(\theta_i|\Lambda)}{p(\theta_i|\Lambda_\emptyset)}

    with Monte Carlo integration over found injections, along with the effective sample size
    of the Monte Carlo integral.

    Parameters
    ----------
    weights : jax.DeviceArray
        JAX array of weights to integrate over. Expected size of `(N_found_injections,)`.
    Ninj : int
        Total number of generated injections.
    log : bool, optional
        Flag to perform calculations in log probability. Interprets weights as log weights.
        This is slower but more numerically stable. Defaults to False.

    Returns
    -------
    jax.DeviceArray
        Log detection efficiency.
    jax.DeviceArray
        Log effective sample size (N_eff) of the Monte Carlo integral.
    jax.DeviceArray
        Variance estimated for the log of the Monte Carlo integral.
    """
    if log:
        logweights = weights
        logmu = logsumexp(logweights) - jnp.log(Ninj)
        mu = jnp.exp(logmu)
        var = jnp.sum(jnp.exp(logweights) ** 2) / Ninj**2 - mu**2 / Ninj
        logn_eff = 2 * logmu - jnp.log(var)
    else:
        mu = jnp.sum(weights) / Ninj
        var = jnp.sum(weights**2) / Ninj**2 - mu**2 / Ninj
        logmu = jnp.log(mu)
        logn_eff = 2 * logmu - jnp.log(var)
    variance = 1 / jnp.exp(logn_eff) - 1 / Ninj
    return logmu, logn_eff, variance
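# Illustrative sketch (not part of gwinferno): estimating the detection efficiency from toy
# injection weights. The weights and injection count are hypothetical; `Ninj` counts all
# generated injections while `weights` only covers the found ones.
def _example_detection_efficiency():
    found_weights = jnp.array([0.9, 1.2, 0.4, 0.8, 1.1])  # weights at 5 found injections
    Ninj = 20  # total injections generated, including those not detected
    logmu, logn_eff, variance = detection_efficiency(found_weights, Ninj, log=False)
    # exp(logmu) ~ sum(found_weights) / Ninj = 4.4 / 20 = 0.22
    return jnp.exp(logmu), jnp.exp(logn_eff), variance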
def hierarchical_likelihood(
    pe_weights,
    inj_weights,
    total_inj,
    Nobs,
    Tobs,
    surveyed_hypervolume=None,
    categorical=False,
    marginal_qs=False,
    indv_weights=None,
    rngkey=None,
    pop_frac=None,
    reconstruct_rate=True,
    marginalize_selection=False,
    min_neff_cut=True,
    max_variance_cut=False,
    posterior_predictive_check=False,
    param_names=None,
    pedata=None,
    injdata=None,
    m2min=3.0,
    m1min=5.0,
    mmax=100.0,
    log=False,
):
    """Performs the hierarchical likelihood calculation by importance sampling over found
    injections and over each event's posterior (PE) samples, assuming a fiducial prior density.

    Parameters
    ----------
    pe_weights : jax.DeviceArray
        Array of weights evaluated at PE samples to integrate over. If `log=True` this is
        expected to be the **log** of the weights. Expected size of `(N_events, N_samples)`.
    inj_weights : jax.DeviceArray
        Array of weights evaluated at found injections to integrate over. If `log=True` this
        is expected to be the **log** of the weights. Expected size of `(N_found_injections,)`.
    total_inj : int
        Total number of generated injections, before cutting on found injections.
    Nobs : int
        Total number of observed events being analyzed.
    Tobs : float
        Time spent observing to produce the catalog (in yrs).
    surveyed_hypervolume : float, optional
        Total VT (normalization of the redshift model). Defaults to `None`.
    categorical : bool, optional
        If `True`, use latent categorical parameters to assign each event to one of many
        subpopulations. Defaults to `False`.
    marginal_qs : bool, optional
        If `True`, record per-event subpopulation fractions (computed from `indv_weights`)
        as deterministic sites during the posterior predictive check. Defaults to `False`.
    indv_weights : jax.DeviceArray, optional
        Per-subpopulation weights evaluated at PE samples, used to compute the subpopulation
        fractions when `marginal_qs=True`. Defaults to `None`.
    rngkey : jax.random.PRNGKey, optional
        RNG key used to sample the categorical variable. Needed if `categorical=True`.
        Defaults to `None`.
    pop_frac : tuple of floats, optional
        Tuple of true astrophysical population fractions. Length is the number of categorical
        subpopulations, needs to sum to 1, and is needed if `categorical=True`. Defaults to
        `None`.
    reconstruct_rate : bool, optional
        Flag to reconstruct the marginalized merger rate. Defaults to `True`.
    marginalize_selection : bool, optional
        Flag to marginalize over the uncertainty in the selection Monte Carlo integral.
        Defaults to `False`.
    min_neff_cut : bool, optional
        Flag to apply the `min_neff` cut on the likelihood, ensuring the Monte Carlo integrals
        converge. Defaults to `True`.
    max_variance_cut : bool, optional
        Flag to apply a cut requiring the estimated variance of the total log likelihood to be
        less than 1. If this is `True`, then `marginalize_selection` and `min_neff_cut` must
        both be `False`. Defaults to `False`.
    posterior_predictive_check : bool, optional
        Flag to sample from the PE/injection data to perform a posterior predictive check.
        Defaults to `False`.
    param_names : iterable, optional
        Parameters to sample for PPCs. Defaults to `None`.
    pedata : dict, optional
        Dictionary of PE data needed to perform PPCs. Defaults to `None`.
    injdata : dict, optional
        Dictionary of found-injection data needed to perform PPCs. Defaults to `None`.
    m2min : float, optional
        Minimum mass for secondary components (solar masses). Defaults to `3.0`.
    m1min : float, optional
        Minimum mass for primary components (solar masses). Defaults to `5.0`.
    mmax : float, optional
        Maximum mass for primary components (solar masses). Defaults to `100.0`.
    log : bool, optional
        Flag to perform calculations in log space. Interprets weights as log weights.
        This is slower but more numerically stable. Defaults to `False`.

    Returns
    -------
    float or None
        Marginalized merger rate in units of `Gpc^-3 yr^-1`, or `None` if
        `reconstruct_rate=False`.
    """
    if max_variance_cut and (marginalize_selection or min_neff_cut):
        raise ValueError(
            "max_variance_cut is True which requires marginalize_selection and "
            "min_neff_cut to be False but got "
            f"marginalize_selection = {marginalize_selection} "
            f"and min_neff_cut = {min_neff_cut}",
        )
    rate = None
    if categorical:
        with numpyro.plate("nObs", Nobs) as i:
            Qs = numpyro.sample(
                "Qs",
                dist.Categorical(probs=jnp.array(pop_frac)),
                rng_key=rngkey,
            ).reshape((-1, 1))
            mix_pe_weights = jnp.where(Qs[i] == 0, pe_weights[0][i], pe_weights[1][i])
            logBFs, logn_effs, variances = per_event_log_bayes_factors(mix_pe_weights, log=log)
    else:
        logBFs, logn_effs, variances = per_event_log_bayes_factors(pe_weights, log=log)
    log_det_eff, logn_eff_inj, variance = detection_efficiency(inj_weights, total_inj, log=log)

    numpyro.deterministic("log_nEff_inj", logn_eff_inj)
    numpyro.deterministic("log_nEffs", logn_effs)
    numpyro.deterministic("logBFs", logBFs)
    numpyro.deterministic("detection_efficiency", jnp.exp(log_det_eff))
    numpyro.deterministic("variance_log_BFs", variances)
    numpyro.deterministic("variance_log_detection_efficiency", variance)

    if reconstruct_rate:
        total_vt = numpyro.deterministic("surveyed_hypervolume", surveyed_hypervolume / 1.0e9 * Tobs)
        unscaled_rate = numpyro.sample("unscaled_rate", dist.Gamma(Nobs))
        rate = numpyro.deterministic("rate", unscaled_rate / jnp.exp(log_det_eff) / total_vt)

    if marginalize_selection:
        log_det_eff = log_det_eff - (3 + Nobs) / (2 * jnp.exp(logn_eff_inj))

    if min_neff_cut:
        log_det_eff = jnp.where(
            jnp.greater_equal(logn_eff_inj, jnp.log(4 * Nobs)),
            log_det_eff,
            jnp.inf,
        )

    sel = numpyro.deterministic(
        "selection_factor",
        jnp.where(jnp.isinf(log_det_eff), jnp.nan_to_num(-jnp.inf), -Nobs * log_det_eff),
    )
    sumlogBFs = numpyro.deterministic("sum_logBFs", jnp.sum(logBFs))
    log_l = sel + sumlogBFs
    log_l = numpyro.deterministic(
        "log_l",
        jnp.where(
            jnp.isnan(log_l),
            jnp.nan_to_num(-jnp.inf),
            jnp.nan_to_num(log_l),
        ),
    )

    # TODO: clean this up, make value of min_neff a function kwarg
    if min_neff_cut:
        min_n_effs = jnp.exp(jnp.min(jnp.nan_to_num(logn_effs)))
        log_l = numpyro.deterministic(
            "neff_less_Nobs",
            jnp.where(
                jnp.less_equal(min_n_effs, Nobs),
                jnp.nan_to_num(-jnp.inf),
                log_l,
            ),
        )

    variance = numpyro.deterministic(
        "variance_log_likelihood",
        Nobs**2 * variance + variances.sum(),
    )
    if max_variance_cut:
        log_l = numpyro.deterministic(
            "variance_less_1",
            jnp.where(
                jnp.less_equal(variance, 1),
                log_l,
                jnp.nan_to_num(-jnp.inf),
            ),
        )

    numpyro.factor("log_likelihood", log_l)

    if posterior_predictive_check:
        if param_names is not None and injdata is not None and pedata is not None:
            if log:
                pe_weights = jnp.exp(pe_weights)
                inj_weights = jnp.exp(inj_weights)
            cond = jnp.less(pedata["mass_1"], m1min) | jnp.greater(pedata["mass_1"], mmax)
            pe_weights = jnp.where(
                cond | jnp.less(pedata["mass_1"] * pedata["mass_ratio"], m2min),
                0,
                pe_weights,
            )
            inj_weights = jnp.where(
                jnp.less(injdata["mass_1"], m1min)
                | jnp.greater(injdata["mass_1"], mmax)
                | jnp.less(injdata["mass_1"] * injdata["mass_ratio"], m2min),
                0,
                inj_weights,
            )
            for ev in range(Nobs):
                k = random.PRNGKey(ev)
                k1, k2 = random.split(k)
                obs_idx = random.choice(
                    k1,
                    pe_weights.shape[1],
                    p=pe_weights[ev, :] / jnp.sum(pe_weights[ev, :]),
                )
                if marginal_qs:
                    for i in range(len(indv_weights)):
                        numpyro.deterministic(
                            f"cat_frac_subpop_{i + 1}_event_{ev}",
                            indv_weights[i][ev, obs_idx] / pe_weights[ev, obs_idx],
                        )
                pred_idx = random.choice(k2, inj_weights.shape[0], p=inj_weights / jnp.sum(inj_weights))
                for p in param_names:
                    numpyro.deterministic(f"{p}_obs_event_{ev}", pedata[p][ev, obs_idx])
                    numpyro.deterministic(f"{p}_pred_event_{ev}", injdata[p][pred_idx])
    return rate
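# Illustrative sketch (not part of gwinferno): `hierarchical_likelihood` emits NumPyro
# sample/factor sites, so it must be called from inside a NumPyro model. The weight arrays,
# hypervolume, and catalog numbers below are hypothetical placeholders for quantities that
# would normally depend on population models and data (see `construct_hierarchical_model`).
def _example_hierarchical_likelihood_model(log_pe_weights, log_inj_weights, total_inj, Nobs, Tobs, vt_norm):
    def model():
        hierarchical_likelihood(
            log_pe_weights,   # (N_events, N_samples) log weights at PE samples
            log_inj_weights,  # (N_found_injections,) log weights at found injections
            total_inj=total_inj,
            Nobs=Nobs,
            Tobs=Tobs,
            surveyed_hypervolume=vt_norm,
            reconstruct_rate=True,
            min_neff_cut=True,
            log=True,
        )

    return model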
def construct_hierarchical_model(
    model_dict,
    prior_dict,
    marginalize_selection=False,
    min_neff_cut=True,
    max_variance_cut=False,
    posterior_predictive_check=True,
):
    """Constructs a NumPyro model for hierarchical population inference from a dictionary of
    population models and a dictionary of hyperparameter priors, wiring the per-parameter
    population models into `hierarchical_likelihood`.
    """
    source_param_names = [k for k in model_dict.keys()]
    hyper_params = {k: None for k in prior_dict.keys()}
    pop_models = {k: None for k in model_dict.keys()}
    if "redshift" in pop_models.keys():
        z_grid = jnp.linspace(1e-9, prior_dict["redshift_maximum"], 1000)

    def model(samps, injs, Ninj, Nobs, Tobs):
        for k, v in prior_dict.items():
            try:
                hyper_params[k] = numpyro.sample(k, v.dist(**v.params))
            except AttributeError:
                hyper_params[k] = v
        iid_mapping = {}
        for k, v in model_dict.items():
            if isinstance(v, PopMixtureModel):
                components = [
                    v.components[i](**{p: hyper_params[f"{k}_component_{i + 1}_{p}"] for p in v.component_params[i]})
                    for i in range(len(v.components))
                ]
                mixing_dist = v.mixing_dist(**{p: hyper_params[f"{k}_mixture_dist_{p}"] for p in v.mixing_params})
                pop_models[k] = v.model(mixing_dist, components)
            elif isinstance(v, PopModel):
                hps = {p: hyper_params[f"{k}_{p}"] for p in v.params}
                if k == "redshift":
                    hps["grid"] = z_grid
                pop_models[k] = v.model(**hps)
            elif isinstance(v, str):
                iid_mapping[v] = k
            else:
                raise ValueError(f"Unknown model type: {type(v)}:{v}")
        for shared_param, param in iid_mapping.items():
            pop_models[shared_param] = pop_models[param]

        inj_weights = jnp.sum(jnp.array([pop_models[k].log_prob(injs[k]) for k in source_param_names]), axis=0) - jnp.log(injs["prior"])
        pe_weights = jnp.sum(jnp.array([pop_models[k].log_prob(samps[k]) for k in source_param_names]), axis=0) - jnp.log(samps["prior"])

        hierarchical_likelihood(
            pe_weights,
            inj_weights,
            total_inj=Ninj,
            Nobs=Nobs,
            Tobs=Tobs,
            surveyed_hypervolume=pop_models["redshift"].norm,
            marginalize_selection=marginalize_selection,
            min_neff_cut=min_neff_cut,
            max_variance_cut=max_variance_cut,
            posterior_predictive_check=posterior_predictive_check,
            pedata=samps,
            injdata=injs,
            param_names=source_param_names,
            m1min=2.0,
            m2min=2.0,
            mmax=100.0,
            log=True,
        )

    return model
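# Illustrative sketch (not part of gwinferno): running MCMC on a model built with
# `construct_hierarchical_model`. The `model_dict`/`prior_dict` contents and the data
# dictionaries (`samps`, `injs`) are assumed to be prepared elsewhere (e.g. by the parsing
# utilities in `gwinferno.pipeline.parser`); the names and sampler settings here are
# placeholders for this example only.
def _example_run_hierarchical_inference(model_dict, prior_dict, samps, injs, Ninj, Nobs, Tobs):
    model = construct_hierarchical_model(model_dict, prior_dict, min_neff_cut=True)
    kernel = NP_KERNEL_MAP["NUTS"](model)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=500, num_samples=1000)
    mcmc.run(random.PRNGKey(0), samps, injs, Ninj, Nobs, Tobs)
    return mcmc.get_samples()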