"""
a module for basic distribution pdf calculations with jax
"""
import jax.numpy as jnp
from jax.scipy.special import betaln
from jax.scipy.special import erf
from jax.scipy.special import xlogy
"""
=============================================
This file contains some functions copied from https://github.com/ColmTalbot/gwpopulation re-implemented with jax.numpy
=============================================
"""
[docs]
def smooth(dx, x, xmin):
    """
    smooth low-end smoothing window rising from 0 at xmin to 1 at xmin + dx:
    S(x) = 0                                                       for x < xmin
    S(x) = 1 / (1 + exp(dx / (x - xmin) + dx / (x - xmin - dx)))   for xmin <= x < xmin + dx
    S(x) = 1                                                       for x >= xmin + dx
    Args:
        dx (float): width of the smoothing window
        x (array_like): points to evaluate the window at
        xmin (float): lower edge of the window
    Returns:
        array_like: smoothing window evaluated at x
    """
    # jnp.subtract always yields an array, so x == xmin divides to inf
    # (window value 0) instead of raising ZeroDivisionError on Python floats.
    shifted = jnp.subtract(x, xmin)
    func = jnp.exp(dx / shifted + dx / (shifted - dx))
    # Hard step used everywhere outside the smoothing window.
    step = jnp.where(jnp.less(x, xmin), 0, 1)
    # Both bounds must hold at once; OR-ing them is true for every x and
    # would apply the window formula outside [xmin, xmin + dx).
    in_window = jnp.greater_equal(x, xmin) & jnp.less(x, xmin + dx)
    return jnp.where(in_window, 1.0 / (func + 1.0), step)
[docs]
def logistic_function(x, L, k, x0):
    """
    logistic_function evaluate the logistic (sigmoid) curve L / (1 + exp(-k * (x - x0)))
    Args:
        x (array_like): points at which the curve is evaluated
        L (float): maximum value (upper asymptote) of the curve
        k (float): logistic growth rate (positive values for left truncation,
            negative for right truncation)
        x0 (float): x-value of the sigmoid's midpoint, where the curve equals L / 2
    Returns:
        array_like: logistic function evaluated at x
    """
    exponent = -k * (x - x0)
    denominator = 1 + jnp.exp(exponent)
    return L / denominator
[docs]
def logistic_unit(x, x0, sgn=1, sc=4):
    """
    logistic_unit unit-height logistic curve used to soft truncate a distribution at x0
    Evaluates logistic_function with maximum 1.0 and growth rate -sgn * sc,
    i.e. 1 / (1 + exp(sgn * sc * (x - x0))).
    Args:
        x (array_like): input array to truncate
        x0 (float): value of array we want to apply a soft truncation to
        sgn (int, optional): Which side do we truncate on (1 for right, -1 for left). Defaults to 1.
        sc (int, optional): scale of truncation, where higher values is sharper. Defaults to 4.
    Returns:
        array_like: soft truncation factor at x0 evaluated at x
    """
    growth_rate = -1 * sgn * sc
    return logistic_function(x, L=1.0, k=growth_rate, x0=x0)
[docs]
def log_logistic_unit(x, x0, sgn=1, sc=4):
    """
    log_logistic_unit natural log of the logistic unit used to soft truncate a distribution
    Computes log(1 / (1 + exp(sgn * sc * (x - x0)))) = -log(1 + exp(sgn * sc * (x - x0))).
    Args:
        x (array_like): input array to truncate
        x0 (float): value of array we want to apply a soft truncation to
        sgn (int, optional): Which side do we truncate on (1 for right, -1 for left). Defaults to 1.
        sc (int, optional): scale of truncation, where higher values is sharper. Defaults to 4.
    Returns:
        array_like: log of the soft truncation factor at x0 evaluated at x
    """
    # logaddexp(0, t) = log(1 + exp(t)) is evaluated without overflow for
    # large |t|. Selecting between two closed forms with jnp.where still
    # evaluates the unused one, whose log(0) = -inf turns gradients into NaN.
    return -jnp.logaddexp(0.0, sgn * sc * (x - x0))
[docs]
def powerlaw_logit_pdf(xx, alpha, low=None, high=None, low_fall_off=4.0, high_fall_off=4.0):
    r"""
    powerlaw_logit_pdf pdf of a powerlaw with optional soft (logistic) truncation at either end:
    $$ p(x) \propto x^{\alpha} \, \frac{1}{1 + e^{-s_\mathrm{low}(x - x_\mathrm{min})}} \, \frac{1}{1 + e^{s_\mathrm{high}(x - x_\mathrm{max})}} $$
    Each logistic factor is applied only when the corresponding bound is given.
    WARNING: this is not a normalized pdf!
    Args:
        xx (array_like): points to evaluate pdf at
        alpha (float): power law index
        low (float, optional): low end soft truncation bound. Defaults to None (no low end truncation).
        high (float, optional): high end soft truncation bound. Defaults to None (no high end truncation).
        low_fall_off (float, optional): scale of the logistic unit at the low end,
            where higher values is sharper. Defaults to 4.0.
        high_fall_off (float, optional): scale of the logistic unit at the high end,
            where higher values is sharper. Defaults to 4.0.
    Returns:
        array_like: unnormalized pdf evaluated at xx
    """
    prob = jnp.power(xx, alpha)
    if low is not None:
        # sgn=-1 suppresses values below `low`
        prob = prob * logistic_unit(xx, low, sgn=-1.0, sc=low_fall_off)
    if high is not None:
        # sgn=+1 suppresses values above `high`
        prob = prob * logistic_unit(xx, high, sgn=1.0, sc=high_fall_off)
    return prob
[docs]
def powerlaw_pdf(xx, alpha, low, high, floor=0.0):
    """
    powerlaw_pdf pdf of sharp truncated powerlaw, normalized on [low, high]:
    p(x) = norm * x**alpha for low <= x <= high, and `floor` otherwise, with
    norm = 1 / log(high / low) when alpha == -1 and
    norm = (1 + alpha) / (high**(1 + alpha) - low**(1 + alpha)) otherwise.
    Args:
        xx (array_like): points to evaluate pdf at
        alpha (float): power law index
        low (float): low end truncation bound
        high (float): high end truncation bound
        floor (float, optional): lower bound of pdf (Defaults to 0.0)
    Returns:
        array_like: pdf evaluated at xx
    """
    is_log_case = jnp.equal(alpha, -1)
    # jnp.where evaluates both normalisations. Swapping in a harmless index of
    # 1 when alpha == -1 keeps the general branch free of 0/0, which would be
    # NaN for arrays and ZeroDivisionError for plain Python floats.
    index = jnp.where(is_log_case, 1.0, 1.0 + alpha)
    general_norm = index / (jnp.power(high, index) - jnp.power(low, index))
    # jnp.divide keeps low == 0 from raising when the log branch is unused.
    log_norm = 1.0 / jnp.log(jnp.divide(high, low))
    norm = jnp.where(is_log_case, log_norm, general_norm)
    prob = jnp.power(xx, alpha) * norm
    outside = jnp.less(xx, low) | jnp.greater(xx, high)
    return jnp.where(outside, floor, prob)
[docs]
def truncnorm_pdf(xx, mu, sig, low, high, log=False):
    r"""
    truncnorm_pdf pdf of a normal distribution truncated to [low, high]:
    $$ p(x) \propto \mathcal{N}(x | \mu, \sigma)\Theta(x-x_\mathrm{min})\Theta(x_\mathrm{max}-x) $$
    `log=True` makes this a log-normal distribution!
    Args:
        xx (array_like): points to evaluate pdf at
        mu (float): mean of the normal (of log(x) when `log=True`)
        sig (float): standard deviation of the normal (of log(x) when `log=True`)
        low (float): low end truncation bound
        high (float): high end truncation bound
        log (bool, optional): evaluate a truncated log-normal instead. Defaults to False.
    Returns:
        array_like: pdf evaluated at xx, 0 outside [low, high]
    """
    root_two = 2**0.5
    if log:
        # Gaussian in log(x); the extra 1 / xx is the log-normal density factor.
        loc, lower, upper = jnp.log(xx), jnp.log(low), jnp.log(high)
        continuous_norm = 1 / (xx * sig * (2 * jnp.pi) ** 0.5)
    else:
        loc, lower, upper = xx, low, high
        continuous_norm = 1 / (sig * (2 * jnp.pi) ** 0.5)
    prob = jnp.exp(-jnp.power(loc - mu, 2) / (2 * sig**2))
    # Probability mass of the untruncated distribution between the bounds.
    left_tail_cdf = 0.5 * (1 + erf((lower - mu) / (sig * root_two)))
    right_tail_cdf = 0.5 * (1 + erf((upper - mu) / (sig * root_two)))
    norm = continuous_norm / (right_tail_cdf - left_tail_cdf)
    outside = jnp.less(xx, low) | jnp.greater(xx, high)
    return jnp.where(outside, 0, prob * norm)
[docs]
def betadist(xx, alpha, beta, scale=1.0, floor=0.0):
    """
    betadist pdf of Beta distribution evaluated at xx with optional max value of scale:
    p(x) = x**(alpha - 1) * (scale - x)**(beta - 1) / (scale**(alpha + beta - 1) * B(alpha, beta))
    for 0 <= x <= scale, and `floor` otherwise.
    Args:
        xx (array_like): points to evaluate pdf at
        alpha (float): alpha shape parameter
        beta (float): beta shape parameter
        scale (float, optional): maximum value of support in Beta distribution. Defaults to 1.0.
        floor (float, optional): lower bound of pdf (Defaults to 0.0)
    Returns:
        array_like: pdf evaluated at xx
    """
    # xlogy(a, y) is defined as 0 when a == 0, so the support edges stay
    # finite for alpha == 1 or beta == 1 (a plain 0 * log(0) would be NaN).
    ln_beta = (
        xlogy(alpha - 1, xx)
        + xlogy(beta - 1, scale - xx)
        - (alpha + beta - 1) * jnp.log(scale)
        - betaln(alpha, beta)
    )
    in_support = jnp.greater_equal(xx, 0) & jnp.less_equal(xx, scale)
    return jnp.where(in_support, jnp.exp(ln_beta), floor)