Source code for gwinferno.models.parametric.parametric

import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

from gwinferno.cosmology import PLANCK_2015_LVK_Cosmology as Planck15

from ...distributions import betadist
from ...distributions import powerlaw_logit_pdf
from ...distributions import powerlaw_pdf
from ...distributions import smooth
from ...distributions import truncnorm_pdf

# Subset of models from https://github.com/ColmTalbot/gwpopulation

"""
=============================================
This file contains a small subset of functions/models from https://github.com/ColmTalbot/gwpopulation re-implemented with jax.numpy
=============================================
"""

"""
***************************************
MASS MODELS
***************************************
"""


def powerlaw_primary_ratio_pdf(m1, q, alpha, beta, mmin, mmax):
    """Joint (m1, q) density built as the product of two ``powerlaw_pdf`` terms.

    Args:
        m1: Primary-mass value(s).
        q: Mass-ratio value(s).
        alpha: Second argument of the ``powerlaw_pdf`` call on ``m1``.
        beta: Second argument of the ``powerlaw_pdf`` call on ``q``.
        mmin: Third argument of the ``m1`` call; the ``q`` call receives
            ``mmin / m1`` in the same position.
        mmax: Fourth argument of the ``m1`` call.

    Returns:
        ``powerlaw_pdf(q, beta, mmin / m1, 1) * powerlaw_pdf(m1, alpha, mmin, mmax)``.
    """
    # NOTE(review): assumes powerlaw_pdf takes (value, index, low, high) -- confirm
    # against the distributions module.
    ratio_term = powerlaw_pdf(q, beta, mmin / m1, 1)
    primary_term = powerlaw_pdf(m1, alpha, mmin, mmax)
    return ratio_term * primary_term
def powerlaw_primary_ratio_falloff_pdf(m1, q, alpha, beta, mmin, mmax, fall_off):
    """Joint (m1, q) density whose primary-mass factor uses ``powerlaw_logit_pdf``.

    Args:
        m1: Primary-mass value(s).
        q: Mass-ratio value(s).
        alpha: Second argument of the ``powerlaw_logit_pdf`` call on ``m1``.
        beta: Second argument of the ``powerlaw_pdf`` call on ``q``.
        mmin: Third argument of the ``m1`` call; the ``q`` call receives
            ``mmin / m1`` in the same position.
        mmax: Fourth argument of the ``m1`` call.
        fall_off: Forwarded unchanged as the last argument of
            ``powerlaw_logit_pdf``.

    Returns:
        Product of the mass-ratio term and the primary-mass term.
    """
    ratio_term = powerlaw_pdf(q, beta, mmin / m1, 1)
    primary_term = powerlaw_logit_pdf(m1, alpha, mmin, mmax, fall_off)
    return ratio_term * primary_term
def plpeak_primary_ratio_pdf(m1, q, alpha, beta, mmin, mmax, mpp, sigpp, lam, delta=None):
    """Joint (m1, q) density: ``powerlaw_pdf`` in q times ``plpeak_primary_pdf`` in m1.

    Args:
        m1: Primary-mass value(s).
        q: Mass-ratio value(s).
        alpha, mmin, mmax, mpp, sigpp, lam: Forwarded to ``plpeak_primary_pdf``.
            ``mmin`` additionally sets the ``mmin / m1`` argument of the
            mass-ratio term.
        beta: Second argument of the ``powerlaw_pdf`` call on ``q``.
        delta: Optional. When given, it is forwarded to ``plpeak_primary_pdf``
            and the mass-ratio term is also multiplied by
            ``smooth(delta, q * m1, mmin)``.

    Returns:
        The product of the (optionally tapered) mass-ratio term and the
        primary-mass term.
    """
    ratio_term = powerlaw_pdf(q, beta, mmin / m1, 1)
    primary_term = plpeak_primary_pdf(m1, alpha, mmin, mmax, mpp, sigpp, lam, delta=delta)
    if delta is not None:
        # The taper is evaluated at q * m1, using the same mmin as the mass model.
        ratio_term = ratio_term * smooth(delta, q * m1, mmin)
    return ratio_term * primary_term
def plpeak_primary_pdf(m1, alpha, mmin, mmax, mpp, sigpp, lam, delta=None):
    """Two-component primary-mass density mixing ``powerlaw_pdf`` and ``truncnorm_pdf``.

    Args:
        m1: Primary-mass value(s).
        alpha: Second argument of the ``powerlaw_pdf`` call.
        mmin, mmax: Passed as the last two arguments to both component calls.
        mpp, sigpp: Second and third arguments of the ``truncnorm_pdf`` call.
        lam: Mixing weight; the ``truncnorm_pdf`` term is scaled by ``lam`` and
            the ``powerlaw_pdf`` term by ``1 - lam``.
        delta: Optional. When given, the power-law term (only) is multiplied by
            ``smooth(delta, m1, mmin)``; the ``truncnorm_pdf`` term is left
            untouched.

    Returns:
        The weighted sum of the two components.
    """
    powerlaw_part = (1 - lam) * powerlaw_pdf(m1, alpha, mmin, mmax)
    if delta is not None:
        powerlaw_part = powerlaw_part * smooth(delta, m1, mmin)
    peak_part = lam * truncnorm_pdf(m1, mpp, sigpp, mmin, mmax)
    return powerlaw_part + peak_part
""" *************************************** SPIN MODELS *************************************** """
def beta_spin_magnitude(a, alpha, beta, amax=1):
    """Evaluate ``betadist`` on a spin magnitude.

    Args:
        a: Spin-magnitude value(s).
        alpha, beta: Second and third positional arguments of ``betadist``.
        amax: Forwarded as the ``scale`` keyword of ``betadist`` (default 1).

    Returns:
        ``betadist(a, alpha, beta, scale=amax)``.
    """
    density = betadist(a, alpha, beta, scale=amax)
    return density
def iid_spin_magnitude(a1, a2, alpha_mag, beta_mag, amax=1):
    """Product of two ``betadist`` evaluations sharing the same parameters.

    Args:
        a1, a2: Spin-magnitude value(s) of the two components.
        alpha_mag, beta_mag: Parameters passed to ``betadist`` in both calls.
        amax: Forwarded as ``scale`` to both calls (default 1).

    Returns:
        ``betadist(a1, ...) * betadist(a2, ...)``.
    """
    first = betadist(a1, alpha_mag, beta_mag, scale=amax)
    second = betadist(a2, alpha_mag, beta_mag, scale=amax)
    return first * second
def independent_spin_magnitude_beta_dist(
    a1,
    a2,
    alpha_mag1,
    beta_mag1,
    alpha_mag2,
    beta_mag2,
    amax1=1,
    amax2=1,
):
    """Product of two ``betadist`` evaluations, each with its own parameters.

    Args:
        a1, a2: Spin-magnitude value(s) of the two components.
        alpha_mag1, beta_mag1: Parameters for the ``betadist`` call on ``a1``.
        alpha_mag2, beta_mag2: Parameters for the ``betadist`` call on ``a2``.
        amax1, amax2: ``scale`` keyword for the respective call (default 1).

    Returns:
        ``betadist(a1, ...) * betadist(a2, ...)``.
    """
    first = betadist(a1, alpha_mag1, beta_mag1, scale=amax1)
    second = betadist(a2, alpha_mag2, beta_mag2, scale=amax2)
    return first * second
def mixture_isoalign_spin_tilt(ct, xi_tilt, sigma_tilt):
    """Mixture of a flat term on [-1, 1] and a ``truncnorm_pdf`` term.

    Args:
        ct: Tilt value(s); judging by the [-1, 1] bounds this is presumably a
            cosine of the tilt angle.
        xi_tilt: Weight on the ``truncnorm_pdf`` term; the flat term is
            weighted by ``1 - xi_tilt``.
        sigma_tilt: Third argument of ``truncnorm_pdf``.

    Returns:
        ``mask * (1 - xi_tilt) / 2 + xi_tilt * truncnorm_pdf(ct, 1, sigma_tilt, -1, 1)``
        where ``mask`` is 0 for ``ct`` outside [-1, 1] and 1 otherwise.
    """
    outside = (ct > 1) | (ct < -1)
    # Flat term has height 1/2 on [-1, 1] and is zeroed outside.
    mask = jnp.where(outside, 0, 1)
    flat_part = mask * (1 - xi_tilt) / 2
    aligned_part = xi_tilt * truncnorm_pdf(ct, 1, sigma_tilt, -1, 1)
    return flat_part + aligned_part
def iid_spin_tilt(ct1, ct2, xi_tilt, sigma_tilt):
    """Product of ``mixture_isoalign_spin_tilt`` on both tilts with shared parameters.

    Args:
        ct1, ct2: Tilt value(s) of the two components.
        xi_tilt, sigma_tilt: Forwarded unchanged to both evaluations.

    Returns:
        ``mixture_isoalign_spin_tilt(ct1, ...) * mixture_isoalign_spin_tilt(ct2, ...)``.
    """
    first = mixture_isoalign_spin_tilt(ct1, xi_tilt, sigma_tilt)
    second = mixture_isoalign_spin_tilt(ct2, xi_tilt, sigma_tilt)
    return first * second
def independent_spin_tilt(ct1, ct2, xi_tilt_1, xi_tilt_2, sigma_tilt1, sigma_tilt2):
    """Product of ``mixture_isoalign_spin_tilt`` on both tilts with separate parameters.

    Args:
        ct1, ct2: Tilt value(s) of the two components.
        xi_tilt_1, sigma_tilt1: Parameters used for the ``ct1`` evaluation.
        xi_tilt_2, sigma_tilt2: Parameters used for the ``ct2`` evaluation.

    Returns:
        ``mixture_isoalign_spin_tilt(ct1, ...) * mixture_isoalign_spin_tilt(ct2, ...)``.
    """
    first = mixture_isoalign_spin_tilt(ct1, xi_tilt_1, sigma_tilt1)
    second = mixture_isoalign_spin_tilt(ct2, xi_tilt_2, sigma_tilt2)
    return first * second
def default_spin_tilt(ct1, ct2, xi_tilt, sigma_tilt):
    """Joint tilt density mixing a flat-flat term with a product of ``truncnorm_pdf`` terms.

    Unlike multiplying two per-component mixtures, the mixing weight here is
    applied once to the joint terms: ``(1 - xi_tilt)`` on the product of the
    two flat terms and ``xi_tilt`` on the product of the two ``truncnorm_pdf``
    terms.

    Args:
        ct1, ct2: Tilt value(s) of the two components.
        xi_tilt: Weight on the ``truncnorm_pdf`` product.
        sigma_tilt: Third argument of both ``truncnorm_pdf`` calls.

    Returns:
        ``(1 - xi_tilt) * flat1 * flat2 + xi_tilt * aligned1 * aligned2``, where
        each flat term is 0.5 inside [-1, 1] and 0 outside.
    """
    flat_1 = jnp.where((ct1 > 1) | (ct1 < -1), 0, 0.5)
    flat_2 = jnp.where((ct2 > 1) | (ct2 < -1), 0, 0.5)
    aligned_1 = truncnorm_pdf(ct1, 1, sigma_tilt, -1, 1)
    aligned_2 = truncnorm_pdf(ct2, 1, sigma_tilt, -1, 1)
    flat_joint = (1 - xi_tilt) * flat_1 * flat_2
    aligned_joint = xi_tilt * aligned_1 * aligned_2
    return flat_joint + aligned_joint
""" *************************************** REDSHIFT MODELS *************************************** """
class PowerlawRedshiftModel:
    """Redshift density proportional to ``dVdz * (1 + z) ** (lamb - 1)``.

    The normalisation is computed with the trapezoid rule on a fixed
    1000-point grid spanning ``[zmin, zmax]``. ``zmin`` is the larger of the
    two input minima and ``zmax`` the smaller of the two input maxima.

    ``Planck15.dVcdz`` is evaluated once at construction on the grid and on
    both input arrays (presumably the differential comoving volume -- confirm
    against the cosmology module).

    NOTE(review): ``log_prob`` and ``__call__`` pick the cached array by the
    rank of ``z`` (1-D -> ``z_inj`` cache, 2-D -> ``z_pe`` cache), so they
    appear to assume ``z`` is one of the arrays given to ``__init__``.
    NOTE(review): only the upper cut ``z <= zmax`` is applied; values below
    ``zmin`` are not masked even though the normalisation starts at ``zmin``.
    """

    def __init__(self, z_pe, z_inj):
        """Precompute the support, the integration grid and the cached dVdz arrays.

        Args:
            z_pe: Redshift array; its dVdz is stored at ``dVdzs[1]``.
            z_inj: Redshift array; its dVdz is stored at ``dVdzs[0]``.
        """
        lower_edges = jnp.array([jnp.min(z_pe), jnp.min(z_inj)])
        upper_edges = jnp.array([jnp.max(z_pe), jnp.max(z_inj)])
        self.zmin = jnp.max(lower_edges)
        self.zmax = jnp.min(upper_edges)
        self.zs = jnp.linspace(self.zmin, self.zmax, 1000)
        self.dVdz_ = jnp.array(Planck15.dVcdz(self.zs))
        # Index 0 -> z_inj, index 1 -> z_pe (looked up later as ndim - 1).
        self.dVdzs = [jnp.array(Planck15.dVcdz(z)) for z in (z_inj, z_pe)]

    def _cached_dVdz(self, z):
        """Return the precomputed dVdz array selected by the rank of ``z``."""
        rank = len(z.shape)
        return self.dVdzs[rank - 1]

    def normalization(self, lamb):
        """Trapezoid integral of ``prob`` over the stored grid ``self.zs``."""
        integrand = self.prob(self.zs, self.dVdz_, lamb)
        return trapezoid(integrand, self.zs)

    def prob(self, z, dVdz, lamb):
        """Unnormalised density ``dVdz * (1 + z) ** (lamb - 1)``."""
        evolution = jnp.power(1.0 + z, lamb - 1.0)
        return dVdz * evolution

    def log_prob(self, z, lamb):
        """Normalised log-density at ``z``.

        Entries with ``z > zmax`` are set to ``jnp.nan_to_num(-jnp.inf)``
        (a large finite negative number) rather than ``-inf``.
        """
        dVdz = self._cached_dVdz(z)
        log_density = jnp.log(dVdz) + (lamb - 1.0) * jnp.log(1.0 + z) - jnp.log(self.normalization(lamb))
        floor = jnp.nan_to_num(-jnp.inf)
        return jnp.where(jnp.less_equal(z, self.zmax), log_density, floor)

    def __call__(self, z, lamb):
        """Normalised density at ``z``; entries with ``z > zmax`` are 0."""
        dVdz = self._cached_dVdz(z)
        density = self.prob(z, dVdz, lamb) / self.normalization(lamb)
        return jnp.where(jnp.less_equal(z, self.zmax), density, 0)