From 5aca6cbfc63025aa9b81ceec3c536e6c94ee042e Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 20 Jan 2025 18:21:05 +0500 Subject: [PATCH] feat: Levy distribution (#1943) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Levy distribution * fix: correct log scale calculation and update entropy method in Levy distribution * doc: add documentation and methods for Lévy distribution --- docs/source/distributions.rst | 8 +++ numpyro/distributions/__init__.py | 2 + numpyro/distributions/continuous.py | 99 +++++++++++++++++++++++++++++ test/test_distributions.py | 4 ++ 4 files changed, 113 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index e02da4d00..ba1a82725 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -256,6 +256,14 @@ Laplace :show-inheritance: :member-order: bysource +Levy +^^^^ +.. autoclass:: numpyro.distributions.continuous.Levy + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + LKJ ^^^ .. autoclass:: numpyro.distributions.continuous.LKJ diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index b23c7369f..864d2cfb2 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -32,6 +32,7 @@ InverseGamma, Kumaraswamy, Laplace, + Levy, LKJCholesky, Logistic, LogNormal, @@ -160,6 +161,7 @@ "Kumaraswamy", "Laplace", "LeftTruncatedDistribution", + "Levy", "LKJ", "LKJCholesky", "Logistic", diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index c4b5e1ea1..988389ef3 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -49,6 +49,7 @@ xlogy, ) from jax.scipy.stats import norm as jax_norm +from jax.typing import ArrayLike from numpyro.distributions import constraints from numpyro.distributions.discrete import _to_logits_bernoulli @@ -2966,3 +2967,101 @@ def infer_shapes( batch_shape = lax.broadcast_shapes(concentration, matrix[:-2]) event_shape = matrix[-2:] return batch_shape, event_shape + + +class Levy(Distribution): + r"""Lévy distribution is a special case of Lévy alpha-stable distribution. + Its probability density function is given by, + + .. math:: + f(x\mid \mu, c) = \sqrt{\frac{c}{2\pi(x-\mu)^{3}}} \exp\left(-\frac{c}{2(x-\mu)}\right), \qquad x > \mu + + where :math:`\mu` is the location parameter and :math:`c` is the scale parameter. + + :param loc: Location parameter. + :param scale: Scale parameter. + """ + + arg_constraints = { + "loc": constraints.positive, + "scale": constraints.positive, + } + + def __init__(self, loc, scale, *, validate_args=None): + self.loc, self.scale = promote_shapes(loc, scale) + batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) + self._support = constraints.greater_than(loc) + super(Levy, self).__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False) + def support(self): + return self._support + + @validate_sample + def log_prob(self, value): + r"""Compute the log probability density function of the Lévy distribution. + + .. math:: + \log f(x\mid \mu, c) = \frac{1}{2}\log\left(\frac{c}{2\pi}\right) - \frac{c}{2(x-\mu)} + - \frac{3}{2}\log(x-\mu), \qquad x > \mu + + :param value: A batch of samples from the distribution. + :return: an array with shape `value.shape[:-self.event_shape]` + :rtype: numpy.ndarray + """ + shifted_value = value - self.loc + return -0.5 * ( + jnp.log(2.0 * jnp.pi) - jnp.log(self.scale) + self.scale / shifted_value + ) - 1.5 * jnp.log(shifted_value) + + def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike: + assert is_prng_key(key) + u = random.uniform(key, shape=sample_shape + self.batch_shape) + return self.icdf(u) + + def icdf(self, q: ArrayLike) -> ArrayLike: + r""" + The inverse cumulative distribution function of Lévy distribution is given by, + + .. math:: + F^{-1}(q\mid \mu, c) = \mu + c\left(\Phi^{-1}(1-q/2)\right)^{-2} + + where :math:`\Phi^{-1}` is the inverse of the standard normal cumulative distribution function. + + :param q: quantile values, should belong to [0, 1]. + :return: the samples whose cdf values equals to `q`. + """ + return self.loc + self.scale * jnp.power(ndtri(1 - 0.5 * q), -2) + + def cdf(self, value: ArrayLike) -> ArrayLike: + r"""The cumulative distribution function of Lévy distribution is given by, + + .. math:: + F(x\mid \mu, c) = 2 - 2\Phi\left(\sqrt{\frac{c}{x-\mu}}\right) + + where :math:`\Phi` is the standard normal cumulative distribution function. + + :param value: samples from Lévy distribution. + :return: output of the cumulative distribution function evaluated at `value`. + """ + inv_standardized = self.scale / (value - self.loc) + return 2.0 - 2.0 * ndtr(jnp.sqrt(inv_standardized)) + + @property + def mean(self) -> ArrayLike: + return jnp.broadcast_to(jnp.inf, self.batch_shape) + + @property + def variance(self) -> ArrayLike: + return jnp.broadcast_to(jnp.inf, self.batch_shape) + + def entropy(self) -> ArrayLike: + r"""If :math:`X \sim \text{Levy}(\mu, c)`, then the entropy of :math:`X` is given by, + + .. math:: + H(X) = \frac{1}{2}+\frac{3}{2}\gamma+\frac{1}{2}\ln{\left(16\pi c^2\right)} + + """ + return jnp.broadcast_to( + 0.5 + 1.5 * jnp.euler_gamma + 0.5 * jnp.log(16 * jnp.pi), self.batch_shape + ) + jnp.log(self.scale) diff --git a/test/test_distributions.py b/test/test_distributions.py index 1b0db1350..003c20b9c 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -456,6 +456,7 @@ def __init__( ), dist.Wishart: _wishart_to_scipy, _TruncatedNormal: _truncnorm_to_scipy, + dist.Levy: lambda loc, scale: osp.levy(loc=loc, scale=scale), } @@ -933,6 +934,9 @@ def get_sp_dist(jax_dist): T(dist.DoublyTruncatedPowerLaw, np.pi, 5.0, 50.0), T(dist.DoublyTruncatedPowerLaw, -1.0, 5.0, 50.0), T(dist.DoublyTruncatedPowerLaw, np.pi, 1.0, 2.0), + T(dist.Levy, 0.0, 1.0), + T(dist.Levy, 0.0, np.array([1.0, 2.0, 10.0])), + T(dist.Levy, np.array([1.0, 2.0, 10.0]), np.pi), ] DIRECTIONAL = [