From 5fa64d269814b06d785c523a0542cc2ccd39ed06 Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Wed, 1 Jan 2025 17:15:01 +0500 Subject: [PATCH 1/3] feat: Levy distribution --- numpyro/distributions/__init__.py | 2 + numpyro/distributions/continuous.py | 71 +++++++++++++++++++++++++++++ test/test_distributions.py | 4 ++ 3 files changed, 77 insertions(+) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index b23c7369f..864d2cfb2 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -32,6 +32,7 @@ InverseGamma, Kumaraswamy, Laplace, + Levy, LKJCholesky, Logistic, LogNormal, @@ -160,6 +161,7 @@ "Kumaraswamy", "Laplace", "LeftTruncatedDistribution", + "Levy", "LKJ", "LKJCholesky", "Logistic", diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index af17c4794..7b81a6a92 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -49,6 +49,7 @@ xlogy, ) from jax.scipy.stats import norm as jax_norm +from jax.typing import ArrayLike from numpyro.distributions import constraints from numpyro.distributions.discrete import _to_logits_bernoulli @@ -2966,3 +2967,73 @@ def infer_shapes( batch_shape = lax.broadcast_shapes(concentration, matrix[:-2]) event_shape = matrix[-2:] return batch_shape, event_shape + + +class Levy(Distribution): + r"""Lévy distribution is a special case of Lévy alpha-stable distribution. + Its probability density function is given by, + + .. math:: + f(x\mid \mu, c) = \sqrt{\frac{c}{2\pi(x-\mu)^{3}}} \exp\left(-\frac{c}{2(x-\mu)}\right), \qquad x > \mu + + where :math:`\mu` is the location parameter and :math:`c` is the scale parameter. + + :param loc: Location parameter. + :param scale: Scale parameter. + """ + + arg_constraints = { + "loc": constraints.positive, + "scale": constraints.positive, + } + + def __init__(self, loc, scale, *, validate_args=None): + self.loc, self.scale = promote_shapes(loc, scale) + batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) + self._support = constraints.greater_than(loc) + super(Levy, self).__init__(batch_shape, validate_args=validate_args) + + @constraints.dependent_property(is_discrete=False) + def support(self): + return self._support + + @validate_sample + def log_prob(self, value): + r"""Compute the log probability density function of the Lévy distribution. + + .. math:: + \log f(x\mid \mu, c) = \frac{1}{2}\log\left(\frac{c}{2\pi}\right) - \frac{c}{2(x-\mu)} + - \frac{3}{2}\log(x-\mu), \qquad x > \mu + """ + shifted_value = value - self.loc + return -0.5 * ( + jnp.log(2.0 * jnp.pi * self.scale) + self.scale / shifted_value + ) - 1.5 * jnp.log(shifted_value) + + def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike: + assert is_prng_key(key) + u = random.uniform(key, shape=sample_shape + self.batch_shape) + return self.icdf(u) + + def icdf(self, q: ArrayLike) -> ArrayLike: + return self.loc + self.scale * jnp.power(ndtri(1 - 0.5 * q), -2) + + def cdf(self, value: ArrayLike) -> ArrayLike: + inv_standardized = self.scale / (value - self.loc) + return 2.0 - 2.0 * ndtr(jnp.sqrt(inv_standardized)) + + @property + def mean(self) -> ArrayLike: + return jnp.broadcast_to(jnp.inf, self.batch_shape) + + @property + def variance(self) -> ArrayLike: + return jnp.broadcast_to(jnp.inf, self.batch_shape) + + def entropy(self) -> ArrayLike: + return ( + 0.5 + + 1.5 * jnp.euler_gamma + + 0.5 * jnp.log(16 * jnp.pi) + + jnp.log(self.scale) + ) diff --git a/test/test_distributions.py b/test/test_distributions.py index 03c5813d6..5900ad0f1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -456,6 +456,7 @@ def __init__( ), dist.Wishart: _wishart_to_scipy, _TruncatedNormal: _truncnorm_to_scipy, + dist.Levy: lambda loc, scale: osp.levy(loc=loc, scale=scale), } @@ -933,6 +934,9 @@ def get_sp_dist(jax_dist): T(dist.DoublyTruncatedPowerLaw, np.pi, 5.0, 50.0), T(dist.DoublyTruncatedPowerLaw, -1.0, 5.0, 50.0), T(dist.DoublyTruncatedPowerLaw, np.pi, 1.0, 2.0), + T(dist.Levy, 0.0, 1.0), + T(dist.Levy, 0.0, np.array([1.0, 2.0, 10.0])), + T(dist.Levy, np.array([1.0, 2.0, 10.0]), np.pi), ] DIRECTIONAL = [ From 3e633eef33e84c4e03f2bb87d504d3cf006037a5 Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Wed, 1 Jan 2025 17:56:42 +0500 Subject: [PATCH 2/3] fix: correct log scale calculation and update entropy method in Levy distribution --- numpyro/distributions/continuous.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 7b81a6a92..69109641c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -3007,7 +3007,7 @@ def log_prob(self, value): """ shifted_value = value - self.loc return -0.5 * ( - jnp.log(2.0 * jnp.pi * self.scale) + self.scale / shifted_value + jnp.log(2.0 * jnp.pi) - jnp.log(self.scale) + self.scale / shifted_value ) - 1.5 * jnp.log(shifted_value) def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike: @@ -3031,9 +3031,6 @@ def variance(self) -> ArrayLike: return jnp.broadcast_to(jnp.inf, self.batch_shape) def entropy(self) -> ArrayLike: - return ( - 0.5 - + 1.5 * jnp.euler_gamma - + 0.5 * jnp.log(16 * jnp.pi) - + jnp.log(self.scale) - ) + return jnp.broadcast_to( + 0.5 + 1.5 * jnp.euler_gamma + 0.5 * jnp.log(16 * jnp.pi), self.batch_shape + ) + jnp.log(self.scale) From 11cfe85bd548a2a6acd857ca909038e634371228 Mon Sep 17 00:00:00 2001 From: Qazalbash Date: Fri, 3 Jan 2025 01:47:10 +0500 Subject: [PATCH 3/3] =?UTF-8?q?doc:=20add=20documentation=20and=20methods?= =?UTF-8?q?=20for=20L=C3=A9vy=20distribution?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/distributions.rst | 8 ++++++++ numpyro/distributions/continuous.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index e02da4d00..ba1a82725 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -256,6 +256,14 @@ Laplace :show-inheritance: :member-order: bysource +Levy +^^^^ +.. autoclass:: numpyro.distributions.continuous.Levy + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + LKJ ^^^ .. autoclass:: numpyro.distributions.continuous.LKJ diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 69109641c..26ebeadf1 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -3004,6 +3004,10 @@ def log_prob(self, value): .. math:: \log f(x\mid \mu, c) = \frac{1}{2}\log\left(\frac{c}{2\pi}\right) - \frac{c}{2(x-\mu)} - \frac{3}{2}\log(x-\mu), \qquad x > \mu + + :param value: A batch of samples from the distribution. + :return: an array with shape `value.shape[:-self.event_shape]` + :rtype: numpy.ndarray """ shifted_value = value - self.loc return -0.5 * ( @@ -3016,9 +3020,30 @@ def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLik return self.icdf(u) def icdf(self, q: ArrayLike) -> ArrayLike: + r""" + The inverse cumulative distribution function of Lévy distribution is given by, + + .. math:: + F^{-1}(q\mid \mu, c) = \mu + c\left(\Phi^{-1}(1-q/2)\right)^{-2} + + where :math:`\Phi^{-1}` is the inverse of the standard normal cumulative distribution function. + + :param q: quantile values, should belong to [0, 1]. + :return: the samples whose cdf values equals to `q`. + """ return self.loc + self.scale * jnp.power(ndtri(1 - 0.5 * q), -2) def cdf(self, value: ArrayLike) -> ArrayLike: + r"""The cumulative distribution function of Lévy distribution is given by, + + .. math:: + F(x\mid \mu, c) = 2 - 2\Phi\left(\sqrt{\frac{c}{x-\mu}}\right) + + where :math:`\Phi` is the standard normal cumulative distribution function. + + :param value: samples from Lévy distribution. + :return: output of the cumulative distribution function evaluated at `value`. + """ inv_standardized = self.scale / (value - self.loc) return 2.0 - 2.0 * ndtr(jnp.sqrt(inv_standardized)) @@ -3031,6 +3056,12 @@ def variance(self) -> ArrayLike: return jnp.broadcast_to(jnp.inf, self.batch_shape) def entropy(self) -> ArrayLike: + r"""If :math:`X \sim \text{Levy}(\mu, c)`, then the entropy of :math:`X` is given by, + + .. math:: + H(X) = \frac{1}{2}+\frac{3}{2}\gamma+\frac{1}{2}\ln{\left(16\pi c^2\right)} + + """ return jnp.broadcast_to( 0.5 + 1.5 * jnp.euler_gamma + 0.5 * jnp.log(16 * jnp.pi), self.batch_shape ) + jnp.log(self.scale)