Source code for gwinferno.models.spline_perturbation

"""
A module containing population models built on B-spline perturbations.
"""

import jax.numpy as jnp
import numpy as np
from jax.scipy.integrate import trapezoid

from ..distributions import powerlaw_pdf
from ..interpolation import BSpline
from ..interpolation import LogXBSpline
from .parametric.parametric import PowerlawRedshiftModel


class PowerlawBasisSplinePrimaryPowerlawRatio(object):
    """Power law in primary mass with a B-spline perturbation, power law in mass ratio.

    The primary-mass density is ``powerlaw_pdf(m1) * exp(spline)``, where the
    spline basis is evaluated in log mass and the result is normalized
    numerically on a fixed 1000-point grid between ``mmin`` and ``mmax``. The
    mass-ratio density is a plain power law truncated below at ``m2min / m1``.

    The spline design matrices are computed once, at construction, on the
    parameter-estimation and injection samples. The density methods look those
    matrices up by the dimensionality of their input, so they must be called
    with the same sample arrays that were given to ``__init__``.
    """

    def __init__(
        self,
        n_splines_m: int,
        m1pe: dict,
        m1inj: dict,
        mmin: float = 3.0,
        m2min: float = 3.0,
        mmax: float = 100.0,
        k: int = 4,
        basis: BSpline = BSpline,
        **kwargs
    ):
        """Build the knot vector, the basis and the cached design matrices.

        Args:
            n_splines_m (int): Number of basis functions used to create the B-Spline in primary mass.
            m1pe (dict): Primary-mass parameter-estimation samples. Annotated as ``dict`` but
                passed straight to ``np.log``, so presumably array-like -- TODO confirm with callers.
            m1inj (dict): Primary-mass injection samples (same caveat as ``m1pe``).
            mmin (float, optional): Minimum primary mass distribution cutoff. Defaults to 3.
            m2min (float, optional): Minimum secondary mass. Defaults to 3.
            mmax (float, optional): Maximum primary mass distribution cutoff. Defaults to 100.
            k (int, optional): Spline order parameter used both for the knot layout and for the
                basis object. Defaults to 4.
            basis (object, optional): The type of basis class you wish to use. Defaults to BSpline.
            **kwargs: Forwarded unchanged to the ``basis`` constructor.
        """
        self.m2min = m2min
        self.n_splines_m = n_splines_m
        self.mmin = mmin
        self.mmax = mmax
        # Fixed grid on which the primary-mass density is normalized.
        self.ms = jnp.linspace(mmin, mmax, 1000)
        self.n_splines = n_splines_m

        # Interior knots are uniform in log mass; k - 1 extra knots with the
        # same spacing are appended on each side.
        log_mmin = np.log(mmin)
        log_mmax = np.log(mmax)
        interior_knots = np.linspace(log_mmin, log_mmax, n_splines_m - k + 2)
        dx = interior_knots[1] - interior_knots[0]
        padding = dx * np.arange(1, k)
        knots = np.concatenate(
            [
                log_mmin - padding[::-1],
                interior_knots,
                log_mmax + padding,
            ]
        )
        self.knots = knots

        # The basis must use the same k the knot vector was laid out for
        # (this was previously hard-coded to 4 regardless of the argument).
        self.interpolator = basis(
            n_splines_m,
            knots=knots,
            interior_knots=interior_knots,
            xrange=(log_mmin, log_mmax),
            k=k,
            **kwargs
        )

        self.pe_design_matrix = jnp.array(self.interpolator.bases(np.log(m1pe)))
        self.inj_design_matrix = jnp.array(self.interpolator.bases(np.log(m1inj)))
        # Indexed by ``ndim - 1`` of the evaluated array: 1-D -> injections, 2-D -> PE samples.
        self.dmats = [self.inj_design_matrix, self.pe_design_matrix]
        self.norm_design_matrix = jnp.array(self.interpolator.bases(np.log(self.ms)))

    def smoothing(self, ms: jnp.ndarray, mmin: float, delta_m: float):
        """Low-mass taper that rises from 0 at ``mmin`` to 1 at ``mmin + delta_m``.

        Args:
            ms (jnp.ndarray): Black hole masses
            mmin (float): minimum black hole mass
            delta_m (float): Width of the tapered region above ``mmin``

        Returns:
            jnp.ndarray: 0 where ``ms <= mmin``,
            ``1 / (exp(delta_m / sm + delta_m / (sm - delta_m)) + 1)`` with
            ``sm = ms - mmin`` where ``0 < sm < delta_m``, and 1 elsewhere.
        """
        sm = ms - mmin
        smoothing_region = jnp.greater(sm, 0) & jnp.less(sm, delta_m)
        window = jnp.where(
            smoothing_region,
            1.0 / (jnp.exp(delta_m / sm + delta_m / (sm - delta_m)) + 1.0),
            1,
        )
        # Replace any non-finite value by 1 before applying the hard lower cutoff.
        window = jnp.where(jnp.isinf(window) | jnp.isnan(window), 1, window)
        return jnp.where(jnp.less_equal(ms, mmin), 0, window)

    def p_m1(self, m1: jnp.ndarray, alpha: float, mmin: float, mmax: float, cs: jnp.ndarray):
        """Normalized primary-mass density: power law times exponentiated spline.

        Args:
            m1 (jnp.ndarray): Ndarray of primary (m1) masses. The spline term is read from the
                design matrix cached at construction (selected by ``m1``'s number of
                dimensions), not recomputed from these values.
            alpha (float): Power-law index; the power law is evaluated with exponent ``-alpha``
            mmin (float): Minimum primary mass cutoff
            mmax (float): Maximum primary mass cutoff
            cs (jnp.ndarray): B-spline coefficients

        Returns:
            jnp.ndarray: Probability density of primary mass
        """
        p_m = powerlaw_pdf(m1, alpha=-alpha, low=mmin, high=mmax)
        ndim = len(m1.shape)
        perturbation = jnp.exp(self.interpolator.project(self.dmats[ndim - 1], cs))
        norm = self.norm_p_m1(alpha=alpha, mmin=mmin, mmax=mmax, cs=cs)
        return p_m * perturbation / norm

    def norm_p_m1(self, alpha: float, mmin: float, mmax: float, cs: jnp.ndarray):
        """Normalization constant of the perturbed primary-mass power law.

        Args:
            alpha (float): Power of the powerlaw (evaluated with exponent ``-alpha``)
            mmin (float): Minimum primary mass cutoff
            mmax (float): Maximum primary mass cutoff
            cs (jnp.ndarray): B-spline coefficients

        Returns:
            Trapezoidal integral of ``powerlaw * exp(spline)`` over the grid ``self.ms``.
        """
        p_m = powerlaw_pdf(self.ms, alpha=-alpha, low=mmin, high=mmax)
        perturbation = jnp.exp(self.interpolator.project(self.norm_design_matrix, cs))
        return trapezoid(y=p_m * perturbation, x=self.ms)

    def p_q(self, q: jnp.ndarray, m1: jnp.ndarray, beta: float):
        """Power-law mass-ratio density conditioned on the primary mass.

        Args:
            q (jnp.ndarray): Mass ratio
            m1 (jnp.ndarray): Primary mass; sets the lower bound ``self.m2min / m1``
            beta (float): Power-law index

        Returns:
            jnp.ndarray: Probability density of mass ratio on ``[m2min / m1, 1]``
        """
        p_q = powerlaw_pdf(q, alpha=beta, low=self.m2min / m1, high=1)
        return p_q

    def __call__(self, m1: jnp.ndarray, q: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Evaluate the joint density ``p(m1) * p(q | m1)``.

        Args:
            m1 (jnp.ndarray): Primary masses
            q (jnp.ndarray): Mass ratio
            **kwargs: Must contain ``beta`` (used for the mass ratio); every other
                entry is forwarded to :meth:`p_m1` (``alpha``, ``mmin``, ``mmax``, ``cs``).

        Returns:
            jnp.ndarray: Product of the primary-mass and mass-ratio densities
        """
        beta = kwargs.pop("beta")
        p_m1 = self.p_m1(m1, **kwargs)
        p_q = self.p_q(q, m1, beta=beta)
        return p_m1 * p_q
class PowerlawBasisSplinePrimaryRatio(object):
    """Power laws in primary mass and mass ratio, each with a B-spline perturbation.

    The joint density is
    ``powerlaw(m1) * exp(spline_m) * powerlaw(q | m1) * exp(spline_q)``
    divided by a normalization obtained by trapezoidal integration on a fixed
    (mass ratio x primary mass) grid. The mass spline is evaluated in log mass,
    the mass-ratio spline directly in ``q`` on ``[0, 1]``.

    Design matrices for the parameter-estimation and injection samples are
    computed once at construction and selected later by the dimensionality of
    the evaluated array, so the density methods must be called with the same
    sample arrays that were given to ``__init__``.
    """

    def __init__(
        self,
        n_splines_m: int,
        n_splines_q: int,
        m1pe: dict,
        qpe: dict,
        m1inj: dict,
        qinj: dict,
        mmin: float = 2.0,
        mmax: float = 100.0,
        k: int = 4,
    ):
        """Build both knot vectors, both bases and the cached design matrices.

        Args:
            n_splines_m (int): Number of basis functions used to create the B-Spline in primary mass.
            n_splines_q (int): Number of basis functions used to create the B-Spline for the mass ratio.
            m1pe (dict): Primary-mass parameter-estimation samples. Annotated as ``dict`` but
                passed straight to ``np.log``, so presumably array-like -- TODO confirm with callers.
            qpe (dict): Mass-ratio parameter-estimation samples (same caveat).
            m1inj (dict): Primary-mass injection samples (same caveat).
            qinj (dict): Mass-ratio injection samples (same caveat).
            mmin (float, optional): Minimum primary mass cutoff. Defaults to 2.
            mmax (float, optional): Maximum primary mass cutoff. Defaults to 100.
            k (int, optional): Spline order parameter used both for the knot layouts and for
                the basis objects. Defaults to 4.
        """
        self.n_splines_m = n_splines_m
        self.n_splines_q = n_splines_q
        self.mmin = mmin
        self.mmax = mmax

        # Normalization grid: 1000 masses by 500 mass ratios. With the default
        # meshgrid indexing, axis 0 of mm/qq runs over qs and axis 1 over ms.
        self.ms = jnp.linspace(mmin, mmax, 1000)
        self.qs = jnp.linspace(mmin / mmax, 1, 500)
        self.mm, self.qq = jnp.meshgrid(self.ms, self.qs)

        # Mass knots: uniform in log mass, padded with k - 1 knots on each side.
        log_mmin = np.log(mmin)
        log_mmax = np.log(mmax)
        interior_mknots = np.linspace(log_mmin, log_mmax, n_splines_m - k + 2)
        dx = interior_mknots[1] - interior_mknots[0]
        knotsm = np.concatenate(
            [
                log_mmin - dx * np.arange(1, k)[::-1],
                interior_mknots,
                log_mmax + dx * np.arange(1, k),
            ]
        )
        self.knotsm = knotsm

        # Mass-ratio knots: uniform on [0, 1], padded the same way.
        interior_qknots = np.linspace(0, 1, n_splines_q - k + 2)
        dxq = interior_qknots[1] - interior_qknots[0]
        knotsq = np.concatenate(
            [
                -dxq * np.arange(1, k)[::-1],
                interior_qknots,
                1 + dxq * np.arange(1, k),
            ]
        )
        self.knotsq = knotsq

        # Both bases must use the same k the knot vectors were laid out for
        # (this was previously hard-coded to 4 regardless of the argument).
        self.interpolator = BSpline(
            n_splines_m,
            knots=knotsm,
            interior_knots=interior_mknots,
            xrange=(log_mmin, log_mmax),
            k=k,
        )
        self.pe_design_matrix = jnp.array(self.interpolator.bases(np.log(m1pe)))
        self.inj_design_matrix = jnp.array(self.interpolator.bases(np.log(m1inj)))
        # Indexed by ``ndim - 1`` of the evaluated array: 1-D -> injections, 2-D -> PE samples.
        self.dmats = [self.inj_design_matrix, self.pe_design_matrix]

        self.qinterpolator = BSpline(
            n_splines_q,
            knots=knotsq,
            interior_knots=interior_qknots,
            xrange=(0, 1),
            k=k,
        )
        self.qpe_design_matrix = jnp.array(self.qinterpolator.bases(qpe))
        self.qinj_design_matrix = jnp.array(self.qinterpolator.bases(qinj))
        self.qdmats = [self.qinj_design_matrix, self.qpe_design_matrix]
        # NOTE(review): this used to read ``self.qknots``, an attribute that is
        # never assigned, so construction always failed with AttributeError.
        # The number of mass-ratio basis functions is the presumed intent for
        # these shape tuples -- confirm against any consumer of ``qshapes``.
        self.qshapes = [(self.n_splines_q, 1), (self.n_splines_q, 1, 1)]

        self.norm_design_matrix = jnp.array(self.interpolator.bases(np.log(self.mm)))
        self.qnorm_design_matrix = jnp.array(self.qinterpolator.bases(self.qq))

    def p_m1(self, m1: jnp.ndarray, alpha: float, mmin: float, mmax: float, cs: jnp.ndarray):
        """Unnormalized primary-mass factor: power law times exponentiated spline.

        Args:
            m1 (jnp.ndarray): Ndarray of primary (m1) masses. The spline term is read from the
                design matrix cached at construction (selected by ``m1``'s number of
                dimensions), not recomputed from these values.
            alpha (float): Power-law index; the power law is evaluated with exponent ``-alpha``
            mmin (float): Minimum primary mass cutoff
            mmax (float): Maximum primary mass cutoff
            cs (jnp.ndarray): B-Spline coefficients

        Returns:
            jnp.ndarray: Primary-mass factor; the joint normalization is applied in ``__call__``
        """
        p_m = powerlaw_pdf(m1, alpha=-alpha, low=mmin, high=mmax)
        ndim = len(m1.shape)
        perturbation = jnp.exp(self.interpolator.project(self.dmats[ndim - 1], cs))
        return p_m * perturbation

    def norm_pm1q(self, alpha: float, mmin: float, mmax: float, cs: jnp.ndarray, beta: float, vs: jnp.ndarray):
        """Normalization constant of the joint (primary mass, mass ratio) density.

        Args:
            alpha (float): Power of the power-law (evaluated with exponent ``-alpha``)
            mmin (float): Minimum primary mass cutoff
            mmax (float): Maximum primary mass cutoff
            cs (jnp.ndarray): B-Spline coefficients
            beta (float): Mass ratio power-law index
            vs (jnp.ndarray): B-Spline coefficients for the mass ratio

        Returns:
            Trapezoidal integral of the unnormalized joint density, first over
            ``self.qs`` (axis 0 of the grid) and then over ``self.ms``.
        """
        p_m = powerlaw_pdf(self.mm, alpha=-alpha, low=mmin, high=mmax)
        perturbation = jnp.exp(self.interpolator.project(self.norm_design_matrix, cs))
        p_q = powerlaw_pdf(self.qq, alpha=beta, low=mmin / self.mm, high=1)
        qperturbation = jnp.exp(self.qinterpolator.project(self.qnorm_design_matrix, vs))
        p_mq = p_m * perturbation * p_q * qperturbation
        return trapezoid(trapezoid(p_mq, self.qs, axis=0), self.ms)

    def p_q(self, q: jnp.ndarray, m1: jnp.ndarray, beta: float, mmin: float, vs: jnp.ndarray):
        """Unnormalized mass-ratio factor: power law times exponentiated spline.

        Args:
            q (jnp.ndarray): Mass ratio. The spline term is read from the design matrix
                cached at construction (selected by ``q``'s number of dimensions).
            m1 (jnp.ndarray): Primary mass; sets the lower bound ``mmin / m1``
            beta (float): Mass ratio power-law index
            mmin (float): Minimum primary mass cutoff
            vs (jnp.ndarray): B-Spline coefficients

        Returns:
            jnp.ndarray: Mass-ratio factor; the joint normalization is applied in ``__call__``
        """
        p_q = powerlaw_pdf(q, alpha=beta, low=mmin / m1, high=1)
        ndim = len(q.shape)
        perturbation = jnp.exp(self.qinterpolator.project(self.qdmats[ndim - 1], vs))
        return p_q * perturbation

    def __call__(self, m1: jnp.ndarray, q: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Evaluate the normalized joint density of primary mass and mass ratio.

        Args:
            m1 (jnp.ndarray): Primary mass
            q (jnp.ndarray): Mass ratio
            **kwargs: Must contain ``beta`` and ``vs``; ``mmin`` is optional and defaults to
                ``self.mmin``. The remaining entries (``alpha``, ``mmax``, ``cs``) are
                forwarded to both :meth:`p_m1` and :meth:`norm_pm1q`.

        Returns:
            jnp.ndarray: ``p_m1 * p_q / norm``
        """
        beta = kwargs.pop("beta")
        mmin = kwargs.pop("mmin", self.mmin)
        vs = kwargs.pop("vs")
        p_m1 = self.p_m1(m1, mmin=mmin, **kwargs)
        p_q = self.p_q(q, m1, beta=beta, mmin=mmin, vs=vs)
        norm = self.norm_pm1q(beta=beta, mmin=mmin, vs=vs, **kwargs)
        return p_m1 * p_q / norm
class PowerlawSplineRedshiftModel(PowerlawRedshiftModel):
    """Power-law redshift model multiplied by an exponentiated spline perturbation.

    Relies on attributes that are not set in this class (``zmin``, ``zmax``,
    ``zs``, ``dVdz_``, ``dVdzs``) and are presumably provided by
    ``PowerlawRedshiftModel.__init__`` -- confirm against the parent class.
    """

    def __init__(self, n_splines: int, z_pe: dict, z_inj: dict, basis: LogXBSpline = LogXBSpline):
        """Initialize the parent model, then build the basis and cache design matrices.

        Args:
            n_splines (int): Number of basis functions used to create B-Spline
            z_pe (dict): Redshift parameter estimation
            z_inj (dict): Redshift injections
            basis (LogXBSpline, optional): Bases to be used in the spline perturbation.
                Defaults to LogXBSpline.
        """
        super().__init__(z_pe=z_pe, z_inj=z_inj)
        self.n_splines = n_splines

        spline = basis(n_splines, xrange=(self.zmin, self.zmax), k=4, normalize=False)
        self.interpolator = spline

        pe_matrix = jnp.array(spline.bases(z_pe))
        self.pe_design_matrix = pe_matrix
        inj_matrix = jnp.array(spline.bases(z_inj))
        self.inj_design_matrix = inj_matrix
        # Indexed by ``ndim - 1`` of the evaluated array: 1-D -> injections, 2-D -> PE samples.
        self.dmats = [inj_matrix, pe_matrix]

        self.norm_design_matrix = jnp.array(spline.bases(self.zs))

    def normalization(self, lamb: float, cs: jnp.ndarray):
        """Integrate the unnormalized redshift density over the grid ``self.zs``.

        Args:
            lamb (float): Power-law exponent for the redshift model
            cs (jnp.ndarray): B-Spline coefficients

        Returns:
            Trapezoidal integral of ``dVdz_ * (1 + zs) ** (lamb - 1) * exp(spline)``.
        """
        powerlaw_term = self.dVdz_ * jnp.power(1.0 + self.zs, lamb - 1)
        spline_term = jnp.exp(self.interpolator.project(self.norm_design_matrix, cs))
        integrand = powerlaw_term * spline_term
        return trapezoid(integrand, self.zs)

    def prob(self, z: jnp.ndarray, dVdz: jnp.ndarray, lamb: float, cs: jnp.ndarray):
        """Unnormalized redshift density at ``z``.

        Args:
            z (jnp.ndarray): Redshift. The spline term is read from the design matrix cached
                at construction, selected by ``z``'s number of dimensions.
            dVdz (jnp.ndarray): Differential co-moving volume element with respect to redshift.
            lamb (float): Power-law exponent for redshift model
            cs (jnp.ndarray): B-Spline coefficients

        Returns:
            jnp.ndarray: ``dVdz * (1 + z) ** (lamb - 1) * exp(spline)``
        """
        design_matrix = self.dmats[len(z.shape) - 1]
        spline_term = jnp.exp(self.interpolator.project(design_matrix, cs))
        return dVdz * jnp.power(1.0 + z, lamb - 1.0) * spline_term

    def __call__(self, z: jnp.ndarray, lamb: float, cs: jnp.ndarray) -> jnp.ndarray:
        """Normalized redshift density, set to zero above ``self.zmax``.

        Args:
            z (jnp.ndarray): Redshift
            lamb (float): Power-law exponent for redshift model
            cs (jnp.ndarray): B-Spline coefficients

        Returns:
            jnp.ndarray: ``prob / normalization`` where ``z <= zmax``, else 0
        """
        # The cached volume element is picked the same way as the design matrix.
        dVdz = self.dVdzs[len(z.shape) - 1]
        below_zmax = jnp.less_equal(z, self.zmax)
        density = self.prob(z, dVdz, lamb, cs) / self.normalization(lamb, cs)
        return jnp.where(below_zmax, density, 0)