Added scale parameterization to Exponential (#6677)
manulpatel authored Apr 28, 2023
1 parent 371472d commit a617bf2
Showing 2 changed files with 28 additions and 3 deletions.
15 changes: 12 additions & 3 deletions pymc/distributions/continuous.py
@@ -1347,15 +1347,24 @@ class Exponential(PositiveContinuous):
     ----------
     lam : tensor_like of float
         Rate or inverse scale (``lam`` > 0).
+    scale: tensor_like of float
+        Alternative parameter (scale = 1/lam).
     """
     rv_op = exponential
 
     @classmethod
-    def dist(cls, lam: DIST_PARAMETER_TYPES, *args, **kwargs):
-        lam = pt.as_tensor_variable(floatX(lam))
+    def dist(cls, lam=None, scale=None, *args, **kwargs):
+        if lam is not None and scale is not None:
+            raise ValueError("Incompatible parametrization. Can't specify both lam and scale.")
+        elif lam is None and scale is None:
+            raise ValueError("Incompatible parametrization. Must specify either lam or scale.")
+
+        if scale is None:
+            scale = pt.reciprocal(lam)
+
+        scale = pt.as_tensor_variable(floatX(scale))
         # PyTensor exponential op is parametrized in terms of mu (1/lam)
-        return super().dist([pt.reciprocal(lam)], **kwargs)
+        return super().dist([scale], **kwargs)
 
     def moment(rv, size, mu):
         if not rv_size_is_none(size):
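In short, Exponential now accepts either lam (the rate) or scale (the mean), with scale = 1/lam, and rejects calls that pass both or neither. A minimal usage sketch of the new API (not part of the commit itself):

    import pymc as pm

    # After this commit, these two calls build the same underlying
    # PyTensor exponential op, which is parametrized by mu = 1/lam:
    x = pm.Exponential.dist(lam=0.2)    # rate parameterization
    y = pm.Exponential.dist(scale=5.0)  # scale parameterization, scale = 1/lam

    pm.Exponential.dist(lam=0.2, scale=5.0)  # raises ValueError (both given)
    pm.Exponential.dist()                    # raises ValueError (neither given)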
16 changes: 16 additions & 0 deletions tests/distributions/test_continuous.py
@@ -444,6 +444,15 @@ def test_exponential(self):
             lambda q, lam: st.expon.ppf(q, loc=0, scale=1 / lam),
         )
 
+    def test_exponential_wrong_arguments(self):
+        msg = "Incompatible parametrization. Can't specify both lam and scale"
+        with pytest.raises(ValueError, match=msg):
+            pm.Exponential.dist(lam=0.5, scale=5)
+
+        msg = "Incompatible parametrization. Must specify either lam or scale"
+        with pytest.raises(ValueError, match=msg):
+            pm.Exponential.dist()
+
     def test_laplace(self):
         check_logp(
             pm.Laplace,
@@ -2091,6 +2100,13 @@ class TestExponential(BaseTestDistributionRandom):
     ]
 
 
+class TestExponentialScale(BaseTestDistributionRandom):
+    pymc_dist = pm.Exponential
+    pymc_dist_params = {"scale": 5.0}
+    expected_rv_op_params = {"mu": pymc_dist_params["scale"]}
+    checks_to_run = ["check_pymc_params_match_rv_op"]
+
+
 class TestCauchy(BaseTestDistributionRandom):
     pymc_dist = pm.Cauchy
     pymc_dist_params = {"alpha": 2.0, "beta": 5.0}
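The new TestExponentialScale case checks that scale is forwarded unchanged as the op's mu parameter. A quick hypothetical sanity check outside the test suite, relying on the fact that an exponential's mean equals its scale:

    import numpy as np
    import pymc as pm

    # Draw from the scale-parameterized distribution and check the sample mean.
    draws = pm.draw(pm.Exponential.dist(scale=5.0), draws=100_000, random_seed=1)
    print(draws.mean())  # expected to be close to 5.0, since E[X] = scale = 1/lam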
