Skip to content

Commit

Permalink
Allow logcdf inference in CustomDist
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 15, 2023
1 parent e2eb26d commit 1ed4475
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
8 changes: 5 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,11 @@ def rv_op(
def custom_dist_logp(op, values, size, *params, **kwargs):
return logp(values[0], *params[: len(dist_params)])

@_logcdf.register(rv_type)
def custom_dist_logcdf(op, value, size, *params, **kwargs):
return logcdf(value, *params[: len(dist_params)])
if logcdf is not None:

@_logcdf.register(rv_type)
def custom_dist_logcdf(op, value, size, *params, **kwargs):
return logcdf(value, *params[: len(dist_params)])

@_moment.register(rv_type)
def custom_dist_get_moment(op, rv, size, *params):
Expand Down
16 changes: 16 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,22 @@ def custom_dist(mu, sigma, size):
ip = m.initial_point()
np.testing.assert_allclose(m.compile_logp()(ip), ref_m.compile_logp()(ip))

def test_logcdf_inference(self):
def custom_dist(mu, sigma, size):
return pt.exp(pm.Normal.dist(mu, sigma, size=size))

mu = 1
sigma = 1.25
test_value = 0.9

custom_lognormal = CustomDist.dist(mu, sigma, dist=custom_dist)
ref_lognormal = LogNormal.dist(mu, sigma)

np.testing.assert_allclose(
pm.logcdf(custom_lognormal, test_value).eval(),
pm.logcdf(ref_lognormal, test_value).eval(),
)

def test_random_multiple_rngs(self):
def custom_dist(p, sigma, size):
idx = pm.Bernoulli.dist(p=p)
Expand Down

0 comments on commit 1ed4475

Please sign in to comment.