# Source code for gwinferno.numpyro_distributions

"""
a module defining additional numpyro distributions
"""

import jax.numpy as jnp
from jax import lax
from jax import random
from jax import vmap
from jax.lax import broadcast_shapes
from jax.scipy.integrate import trapezoid
from numpyro.distributions import Distribution
from numpyro.distributions import constraints
from numpyro.distributions.util import is_prng_key
from numpyro.distributions.util import promote_shapes
from numpyro.distributions.util import validate_sample

from .models.bsplines.smoothing import apply_difference_prior


def cumtrapz(y, x):
    """Cumulative trapezoidal integral of ``y`` over the sample points ``x``.

    Args:
        y: samples of the integrand; the first axis runs along ``x``. Extra
            trailing axes are integrated independently.
        x: 1-D array of sample locations, same length as ``y``'s first axis.

    Returns:
        Array with the same shape as ``y``. Element ``k`` holds the integral
        from ``x[0]`` to ``x[k]``, so the first element is always 0.
    """
    y = jnp.asarray(y)
    difs = jnp.diff(jnp.asarray(x))
    # Let the interval widths broadcast against any trailing axes of y.
    difs = difs.reshape(difs.shape + (1,) * (y.ndim - 1))
    # Interval k spans [x[k], x[k+1]] and uses the endpoint values y[k], y[k+1].
    res = jnp.cumsum(difs * (y[1:] + y[:-1]) / 2.0, axis=0)
    start = jnp.zeros((1,) + res.shape[1:], dtype=res.dtype)
    return jnp.concatenate([start, res])
class Sine(Distribution):
    """Distribution with density proportional to ``sin(x)`` on ``[minimum, maximum]``.

    The normalized density is ``sin(x) / (cos(minimum) - cos(maximum))``. With
    the default bounds ``[0, pi]`` this reduces to ``sin(x) / 2``.

    Args:
        minimum: lower bound of the support (default 0).
        maximum: upper bound of the support (default pi).
        validate_args: forwarded to the numpyro ``Distribution`` base class.

    Note: ``icdf`` uses ``arccos``, whose range is ``[0, pi]``, so the bounds are
    expected to lie inside that interval.
    """

    arg_constraints = {
        "minimum": constraints.real,
        "maximum": constraints.real,
    }
    reparametrized_params = ["minimum", "maximum"]

    def __init__(self, minimum=0.0, maximum=jnp.pi, validate_args=None):
        self.minimum, self.maximum = promote_shapes(minimum, maximum)
        self._support = constraints.interval(minimum, maximum)
        batch_shape = lax.broadcast_shapes(jnp.shape(minimum), jnp.shape(maximum))
        super(Sine, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def sample(self, key, sample_shape=()):
        """Draw samples by inverse-CDF transform of uniform variates."""
        assert is_prng_key(key)
        return self.icdf(random.uniform(key, shape=sample_shape + self.batch_shape))

    @validate_sample
    def log_prob(self, value):
        """Log density, normalized over ``[minimum, maximum]``; ``-inf`` outside it."""
        norm = jnp.cos(self.minimum) - jnp.cos(self.maximum)
        lp = jnp.log(jnp.sin(value) / norm)
        outside = jnp.less(value, self.minimum) | jnp.greater(value, self.maximum)
        # log of a negative sin gives NaN; treat that as zero probability too.
        return jnp.where(jnp.isnan(lp) | outside, -jnp.inf, lp)

    def cdf(self, value):
        """CDF, clamped to 0 below ``minimum`` and 1 above ``maximum``."""
        cdf = jnp.atleast_1d(
            (jnp.cos(value) - jnp.cos(self.minimum)) / (jnp.cos(self.maximum) - jnp.cos(self.minimum))
        )
        cdf = jnp.where(jnp.less(value, self.minimum), 0.0, cdf)
        cdf = jnp.where(jnp.greater(value, self.maximum), 1.0, cdf)
        return cdf

    def icdf(self, q):
        """Inverse CDF: solve ``cos(minimum) - cos(x) = q * norm`` for ``x``."""
        norm = jnp.cos(self.minimum) - jnp.cos(self.maximum)
        return jnp.arccos(jnp.cos(self.minimum) - q * norm)
class Cosine(Distribution):
    """Distribution with density proportional to ``cos(x)`` on ``[minimum, maximum]``.

    The normalized density is ``cos(x) / (sin(maximum) - sin(minimum))``. With
    the default bounds ``[-pi/2, pi/2]`` this reduces to ``cos(x) / 2``.

    Args:
        minimum: lower bound of the support (default -pi/2).
        maximum: upper bound of the support (default pi/2).
        validate_args: forwarded to the numpyro ``Distribution`` base class.

    Note: ``icdf`` uses ``arcsin``, whose range is ``[-pi/2, pi/2]``, so the
    bounds are expected to lie inside that interval.
    """

    arg_constraints = {
        "minimum": constraints.real,
        "maximum": constraints.real,
    }
    reparametrized_params = ["minimum", "maximum"]

    def __init__(self, minimum=-jnp.pi / 2.0, maximum=jnp.pi / 2.0, validate_args=None):
        self.minimum, self.maximum = promote_shapes(minimum, maximum)
        self._support = constraints.interval(minimum, maximum)
        batch_shape = lax.broadcast_shapes(jnp.shape(minimum), jnp.shape(maximum))
        super(Cosine, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def sample(self, key, sample_shape=()):
        """Draw samples by inverse-CDF transform of uniform variates."""
        assert is_prng_key(key)
        return self.icdf(random.uniform(key, shape=sample_shape + self.batch_shape))

    @validate_sample
    def log_prob(self, value):
        """Log density, normalized over ``[minimum, maximum]``; ``-inf`` outside it."""
        norm = jnp.sin(self.maximum) - jnp.sin(self.minimum)
        lp = jnp.log(jnp.cos(value) / norm)
        outside = jnp.less(value, self.minimum) | jnp.greater(value, self.maximum)
        # log of a negative cos gives NaN; treat that as zero probability too.
        return jnp.where(jnp.isnan(lp) | outside, -jnp.inf, lp)

    def cdf(self, value):
        """CDF, clamped to 0 below ``minimum`` and 1 above ``maximum``."""
        cdf = jnp.atleast_1d(
            (jnp.sin(value) - jnp.sin(self.minimum)) / (jnp.sin(self.maximum) - jnp.sin(self.minimum))
        )
        cdf = jnp.where(jnp.less(value, self.minimum), 0.0, cdf)
        cdf = jnp.where(jnp.greater(value, self.maximum), 1.0, cdf)
        return cdf

    def icdf(self, q):
        """Inverse CDF: solve ``sin(x) - sin(minimum) = q * (sin(maximum) - sin(minimum))``."""
        norm = jnp.sin(self.minimum) - jnp.sin(self.maximum)
        return jnp.arcsin(jnp.sin(self.minimum) - q * norm)
class Powerlaw(Distribution):
    """Power-law distribution ``p(x) ∝ x**alpha`` truncated to ``[minimum, maximum]``.

    The usual normalization divides by zero when ``alpha == -1``. That case is
    selected with ``jnp.where`` and uses the logarithmic closed form in
    ``log_prob``, ``cdf`` and ``icdf``.

    Args:
        alpha: power-law index.
        minimum: lower truncation bound of the density (default 0).
        maximum: upper truncation bound of the density (default 1).
        low: lower bound of the declared ``support`` (default 0).
        high: upper bound of the declared ``support`` (default 1).
        validate_args: forwarded to the numpyro ``Distribution`` base class.
    """

    arg_constraints = {
        "minimum": constraints.real,
        "maximum": constraints.real,
        "alpha": constraints.real,
    }
    reparametrized_params = ["minimum", "maximum", "alpha"]

    def __init__(self, alpha, minimum=0.0, maximum=1.0, low=0.0, high=1.0, validate_args=None):
        self.minimum, self.maximum, self.alpha = promote_shapes(minimum, maximum, alpha)
        # The declared support is the fixed [low, high] interval, not [minimum, maximum].
        self._support = constraints.interval(low, high)
        param_shapes = [jnp.shape(p) for p in (minimum, maximum, alpha)]
        super(Powerlaw, self).__init__(
            batch_shape=broadcast_shapes(*param_shapes),
            validate_args=validate_args,
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def sample(self, key, sample_shape=()):
        """Draw samples by inverse-CDF transform of uniform variates."""
        assert is_prng_key(key)
        u = random.uniform(key, shape=sample_shape + self.batch_shape)
        return self.icdf(u)

    @validate_sample
    def log_prob(self, value):
        """Log density; the most negative finite float outside ``[minimum, maximum]``."""
        ap1 = 1.0 + self.alpha
        log_norm = jnp.log(ap1 / (self.maximum**ap1 - self.minimum**ap1))
        generic = self.alpha * jnp.log(value) + log_norm
        loguniform = -jnp.log(value) - jnp.log(self.maximum / self.minimum)
        in_body = jnp.where(jnp.equal(self.alpha, -1.0), loguniform, generic)
        outside = jnp.less(value, self.minimum) | jnp.greater(value, self.maximum)
        return jnp.where(outside, jnp.nan_to_num(-jnp.inf), in_body)

    def cdf(self, value):
        """CDF, clipped into ``[0, 1]`` (which also handles values outside the bounds)."""
        ap1 = self.alpha + 1.0
        lo = self.minimum**ap1
        generic = jnp.atleast_1d(value**ap1 - lo) / (self.maximum**ap1 - lo)
        loguniform = jnp.log(value / self.minimum) / jnp.log(self.maximum / self.minimum)
        out = jnp.where(jnp.equal(self.alpha, -1.0), loguniform, generic)
        return jnp.maximum(jnp.minimum(out, 1.0), 0.0)

    def icdf(self, q):
        """Inverse CDF mapping ``q`` in ``[0, 1]`` to ``[minimum, maximum]``."""
        ap1 = 1.0 + self.alpha
        lo = self.minimum**ap1
        generic = (lo + q * (self.maximum**ap1 - lo)) ** (1.0 / ap1)
        loguniform = self.minimum * jnp.exp(q * jnp.log(self.maximum / self.minimum))
        return jnp.where(jnp.equal(self.alpha, -1.0), loguniform, generic)
class PowerlawRedshift(Distribution):
    """Grid-tabulated redshift density ``p(z) ∝ dVcdz(z) * (1 + z)**(lamb - 1)``.

    The unnormalized density is evaluated on ``zgrid`` and normalized with the
    trapezoid rule. ``cdf`` and ``icdf`` interpolate a cumulative table built on
    the same grid.

    Args:
        lamb: power-law index applied to ``(1 + z)``.
        maximum: values above this get the most negative finite float in ``log_prob``.
        zgrid: redshift grid on which the density is tabulated.
        dVcdz: values tabulated on ``zgrid`` (presumably the differential
            comoving volume dVc/dz — confirm with callers).
        low, high: bounds of the declared ``support`` (defaults 0 and 1000).
        validate_args: forwarded to the numpyro ``Distribution`` base class.

    NOTE(review): ``norm`` integrates over the whole ``zgrid``, and ``maximum``
    is applied only as a mask in ``log_prob``. This assumes ``zgrid`` ends at
    ``maximum`` — confirm.
    """

    arg_constraints = {
        "maximum": constraints.positive,
        "lamb": constraints.real,
    }
    reparametrized_params = ["maximum", "lamb"]

    def __init__(self, lamb, maximum, zgrid, dVcdz, low=0.0, high=1000.0, validate_args=None):
        self.maximum, self.lamb = promote_shapes(maximum, lamb)
        self._support = constraints.interval(low, high)
        super(PowerlawRedshift, self).__init__(
            batch_shape=broadcast_shapes(jnp.shape(maximum), jnp.shape(lamb)),
            validate_args=validate_args,
        )
        self.zs = zgrid
        self.dVdc_ = dVcdz
        # Tabulate the density on the grid, then normalize it numerically.
        self.pdfs = self.dVdc_ * (1 + self.zs) ** (lamb - 1)
        self.norm = trapezoid(self.pdfs, self.zs)
        self.pdfs /= self.norm
        # Force the tabulated CDF to end at exactly 1.
        self.cdfgrid = cumtrapz(self.pdfs, self.zs).at[-1].set(1)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def sample(self, key, sample_shape=()):
        """Draw samples by interpolating the inverse of the tabulated CDF."""
        assert is_prng_key(key)
        u = random.uniform(key, shape=sample_shape + self.batch_shape)
        return self.icdf(u)

    @validate_sample
    def log_prob(self, value, dVdc=None):
        """Normalized log density.

        ``dVdc`` may be passed precomputed at ``value``; otherwise it is
        linearly interpolated from the tabulated grid.
        """
        if dVdc is None:
            dVdc = jnp.interp(value, self.zs, self.dVdc_)
        log_density = jnp.log(dVdc) + (self.lamb - 1.0) * jnp.log(1.0 + value) - jnp.log(self.norm)
        below_max = jnp.less_equal(value, self.maximum)
        return jnp.where(below_max, log_density, jnp.nan_to_num(-jnp.inf))

    def cdf(self, value):
        """CDF by linear interpolation of the tabulated cumulative integral."""
        return jnp.interp(value, self.zs, self.cdfgrid)

    def icdf(self, q):
        """Inverse CDF by linear interpolation with the table axes swapped."""
        return jnp.interp(q, self.cdfgrid, self.zs)
class PowerlawSmoothedPowerlaw(Distribution):
    """Power law on ``[minimum, maximum]`` with power-law tails out to ``[low, high]``.

    The normalized density is continuous and piecewise::

        k1 * x**alpha_min          for low <= x < minimum
        k2 * x**alpha              for minimum <= x <= maximum
        k3 * x**(-alpha_max)       for maximum < x <= high

    The user-supplied ``alpha_max`` is negated in ``__init__``, so
    ``self.alpha_max`` stores the exponent of the upper tail.
    ``k2``/``k1`` and ``k3``/``k2`` are fixed by continuity at ``minimum`` and
    ``maximum``. ``k1`` is fixed by requiring unit total mass on ``[low, high]``.

    Args:
        alpha: index of the central power law.
        minimum, maximum: break points between the segments.
        alpha_max: (positive) slope magnitude of the upper tail.
        alpha_min: (positive) index of the lower tail.
        low, high: bounds of the support.
        validate_args: forwarded to the numpyro ``Distribution`` base class.

    Note: the closed-form segment integrals divide by ``alpha + 1``,
    ``alpha_min + 1`` and ``self.alpha_max + 1``, so exponents of exactly -1 are
    not supported.
    """

    arg_constraints = {
        "minimum": constraints.positive,
        "maximum": constraints.positive,
        "alpha": constraints.real,
        "alpha_max": constraints.positive,
        "alpha_min": constraints.positive,
    }
    reparametrized_params = ["minimum", "maximum", "alpha", "alpha_max", "alpha_min"]

    def __init__(self, alpha, minimum, maximum, alpha_max, alpha_min, low, high, validate_args=None):
        (
            self.minimum,
            self.maximum,
            self.alpha,
            self.alpha_max,
            self.alpha_min,
        ) = promote_shapes(minimum, maximum, alpha, alpha_max, alpha_min)
        # Store the upper-tail exponent with its sign flipped: density ~ x**(-alpha_max) there.
        self.alpha_max = -self.alpha_max
        self._support = constraints.interval(low, high)
        self.low, self.high = low, high
        batch_shape = broadcast_shapes(
            jnp.shape(maximum),
            jnp.shape(minimum),
            jnp.shape(alpha),
            jnp.shape(alpha_max),
            jnp.shape(alpha_min),
        )
        super(PowerlawSmoothedPowerlaw, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

        # Continuity ratios relative to the low-segment amplitude k1.
        mid_ratio = self.minimum ** (self.alpha_min - self.alpha)  # k2 / k1
        high_ratio = mid_ratio * self.maximum ** (self.alpha - self.alpha_max)  # k3 / k1
        # Integrals of the bare power laws over each segment.
        amin1, a1, amax1 = self.alpha_min + 1, self.alpha + 1, self.alpha_max + 1
        low_int = (self.minimum**amin1 - self.low**amin1) / amin1
        mid_int = (self.maximum**a1 - self.minimum**a1) / a1
        high_int = (self.high**amax1 - self.maximum**amax1) / amax1
        # Unit total mass: k1 * (low_int + mid_ratio * mid_int + high_ratio * high_int) = 1.
        self.k1 = 1.0 / (low_int + mid_ratio * mid_int + high_ratio * high_int)
        self.k2 = self.k1 * mid_ratio
        self.k3 = self.k2 * self.maximum ** (self.alpha - self.alpha_max)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def sample(self, key, sample_shape=()):
        """Draw samples by inverse-CDF transform of uniform variates."""
        assert is_prng_key(key)
        return self.icdf(random.uniform(key, shape=sample_shape + self.batch_shape))

    def icdf(self, q):
        """Analytic inverse CDF of the three-segment density."""
        amin1 = self.alpha_min + 1.0
        a1 = self.alpha + 1.0
        amax1 = self.alpha_max + 1.0
        # Probability mass contained in the low and middle segments.
        c1 = self.k1 * (self.minimum**amin1 - self.low**amin1) / amin1
        c2 = self.k2 * (self.maximum**a1 - self.minimum**a1) / a1
        # Invert F(x) within each segment; jnp.where picks the segment containing q.
        x_low = (self.low**amin1 + q * amin1 / self.k1) ** (1.0 / amin1)
        x_mid = (self.minimum**a1 + (q - c1) * a1 / self.k2) ** (1.0 / a1)
        x_high = (self.maximum**amax1 + (q - c1 - c2) * amax1 / self.k3) ** (1.0 / amax1)
        return jnp.where(jnp.less(q, c1), x_low, jnp.where(jnp.less(q, c1 + c2), x_mid, x_high))

    @validate_sample
    def log_prob(self, value):
        """Log density; the most negative finite float outside ``[low, high]``."""
        log_value = jnp.log(value)
        low_pl = jnp.where(jnp.less(value, self.minimum), jnp.log(self.k1) + log_value * self.alpha_min, 0.0)
        high_pl = jnp.where(jnp.greater(value, self.maximum), jnp.log(self.k3) + log_value * self.alpha_max, 0.0)
        in_mid = jnp.greater_equal(value, self.minimum) & jnp.less_equal(value, self.maximum)
        mid_pl = jnp.where(in_mid, jnp.log(self.k2) + log_value * self.alpha, 0.0)
        # Exactly one of the three segment terms is non-zero for any value.
        outside = jnp.less(value, self.low) | jnp.greater(value, self.high)
        return jnp.where(outside, jnp.nan_to_num(-jnp.inf), low_pl + mid_pl + high_pl)
class BSplineDistribution(Distribution):
    """Density whose log is a linear combination of basis functions on a grid.

    The log density on ``grid`` is ``sum_i cs[i] * grid_dmat[i, ...]``. It is
    exponentiated and normalized with the trapezoid rule. ``cdf``/``icdf``
    interpolate a cumulative table on the same grid.

    Args:
        minimum, maximum: bounds of the declared support.
        cs: basis coefficients, contracted against the first axis of ``grid_dmat``.
        grid: points on which the density is tabulated.
        grid_dmat: basis values on ``grid``; NaN entries mark points outside the support.
        validate_args: forwarded to the numpyro ``Distribution`` base class.

    NOTE(review): ``batch_shape`` is broadcast from ``cs``'s full shape even
    though ``cs`` is constrained as ``real_vector``. Confirm that this is intended.
    """

    arg_constraints = {
        "maximum": constraints.real,
        "minimum": constraints.real,
        "cs": constraints.real_vector,
    }
    reparametrized_params = ["maximum", "minimum", "cs"]

    def __init__(self, minimum, maximum, cs, grid, grid_dmat, validate_args=None):
        self.maximum, self.minimum, self.cs = promote_shapes(maximum, minimum, cs)
        self._support = constraints.interval(minimum, maximum)
        super(BSplineDistribution, self).__init__(
            batch_shape=lax.broadcast_shapes(jnp.shape(maximum), jnp.shape(minimum), jnp.shape(cs)),
            validate_args=validate_args,
        )
        self.grid = grid
        # grid_dmat holds NaN where the grid lies outside the support; mapping
        # those to -inf makes exp() give zero density there.
        raw_log_density = jnp.einsum("i,i...->...", self.cs, grid_dmat)
        self.lpdfs = jnp.nan_to_num(raw_log_density, nan=-jnp.inf)
        self.pdfs = jnp.exp(self.lpdfs)
        self.norm = trapezoid(self.pdfs, self.grid)
        self.pdfs /= self.norm
        # Force the tabulated CDF to end at exactly 1.
        self.cdfgrid = cumtrapz(self.pdfs, self.grid).at[-1].set(1)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    def sample(self, key, sample_shape=()):
        """Draw samples by interpolating the inverse of the tabulated CDF."""
        assert is_prng_key(key)
        u = random.uniform(key, shape=sample_shape + self.batch_shape)
        return self.icdf(u)

    def _log_prob_nonorm(self, value):
        """Unnormalized log density, linearly interpolated from the grid."""
        return jnp.interp(value, self.grid, self.lpdfs)

    @validate_sample
    def log_prob(self, value):
        """Normalized log density."""
        return self._log_prob_nonorm(value) - jnp.log(self.norm)

    def cdf(self, value):
        """CDF by linear interpolation of the tabulated cumulative integral."""
        return jnp.interp(value, self.grid, self.cdfgrid)

    def icdf(self, q):
        """Inverse CDF by linear interpolation with the table axes swapped."""
        return jnp.interp(q, self.cdfgrid, self.grid)
class PSplineCoeficientPrior(Distribution):
    """Prior over a length-``N`` P-spline coefficient vector.

    The log density is ``apply_difference_prior(value, inv_var, diff_order)``.

    Args:
        N: number of coefficients (the event size).
        inv_var: positive penalty strength passed to ``apply_difference_prior``.
        diff_order: finite-difference order passed to ``apply_difference_prior`` (default 2).
        validate_args: forwarded to the numpyro ``Distribution`` base class.
    """

    arg_constraints = {"inv_var": constraints.positive}
    reparametrized_params = ["inv_var"]

    def __init__(self, N, inv_var, diff_order=2, validate_args=None):
        (self.inv_var,) = promote_shapes(inv_var)
        self._support = constraints.real_vector
        batch_shape = lax.broadcast_shapes(jnp.shape(inv_var))
        super(PSplineCoeficientPrior, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args, event_shape=(N,)
        )
        self.diff_order = diff_order
        self.N = N

    # The support is real_vector, so its event_dim is 1, not 0.
    @constraints.dependent_property(is_discrete=False, event_dim=1)
    def support(self):
        return self._support

    def sample(self, key, sample_shape=()):
        """Return a placeholder array of ones (not a genuine draw from the prior).

        The shape includes ``event_shape`` so that it matches what ``log_prob`` accepts.
        """
        assert is_prng_key(key)
        return jnp.ones(shape=sample_shape + self.batch_shape + self.event_shape)

    @validate_sample
    def log_prob(self, value):
        """Log density of a single coefficient vector of shape ``(N,)``."""
        assert jnp.shape(value) == (self.N,), f"expected shape {(self.N,)}, got {jnp.shape(value)}"
        return apply_difference_prior(value, self.inv_var, self.diff_order)