Skip to content

Commit

Permalink
Add logcdf methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Nov 2, 2021
1 parent ffa7110 commit 9e9d530
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 4 deletions.
73 changes: 72 additions & 1 deletion aeppl/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def logprob(rv_var, *rv_values, **kwargs):
return logprob


def logcdf(rv_var, rv_value, **kwargs):
"""Create a graph for the logcdf of a ``RandomVariable``."""
logcdf = _logcdf(
rv_var.owner.op, rv_value, *rv_var.owner.inputs, name=rv_var.name, **kwargs
)

if rv_var.name:
logcdf.name = f"{rv_var.name}_logcdf"

return logcdf


@singledispatch
def _logprob(
op: Op,
Expand All @@ -52,7 +64,23 @@ def _logprob(
for a ``RandomVariable``, register a new function on this dispatcher.
"""
raise NotImplementedError()
raise NotImplementedError(f"Logprob method not implemented for {op}")


@singledispatch
def _logcdf(
op: Op,
value: TensorVariable,
*inputs: TensorVariable,
**kwargs,
):
"""Create a graph for the logcdf of a ``RandomVariable``.
This function dispatches on the type of ``op``, which should be a subclass
of ``RandomVariable``. If you want to implement new logcdf graphs
for a ``RandomVariable``, register a new function on this dispatcher.
"""
raise NotImplementedError(f"Logcdf method not implemented for {op}")


@_logprob.register(arb.UniformRV)
Expand All @@ -66,6 +94,24 @@ def uniform_logprob(op, values, *inputs, **kwargs):
)


@_logcdf.register(arb.UniformRV)
def uniform_logcdf(op, value, *inputs, **kwargs):
lower, upper = inputs[3:]

res = at.switch(
at.lt(value, lower),
-np.inf,
at.switch(
at.lt(value, upper),
at.log(value - lower) - at.log(upper - lower),
0,
),
)

res = Assert("lower <= upper")(res, at.all(at.le(lower, upper)))
return res


@_logprob.register(arb.NormalRV)
def normal_logprob(op, values, *inputs, **kwargs):
(value,) = values
Expand All @@ -79,6 +125,21 @@ def normal_logprob(op, values, *inputs, **kwargs):
return res


@_logcdf.register(arb.NormalRV)
def normal_logcdf(op, value, *inputs, **kwargs):
mu, sigma = inputs[3:]

z = (value - mu) / sigma
res = at.switch(
at.lt(z, -1.0),
at.log(at.erfcx(-z / at.sqrt(2.0)) / 2.0) - at.sqr(z) / 2.0,
at.log1p(-at.erfc(z / at.sqrt(2.0)) / 2.0),
)

res = Assert("sigma > 0")(res, at.all(at.gt(sigma, 0.0)))
return res


@_logprob.register(arb.HalfNormalRV)
def halfnormal_logprob(op, values, *inputs, **kwargs):
(value,) = values
Expand Down Expand Up @@ -346,6 +407,16 @@ def poisson_logprob(op, values, *inputs, **kwargs):
return res


@_logcdf.register(arb.PoissonRV)
def poisson_logcdf(op, value, *inputs, **kwargs):
(mu,) = inputs[3:]
value = at.floor(value)
res = at.log(at.gammaincc(value + 1, mu))
res = at.switch(at.le(0, value), res, -np.inf)
res = Assert("0 <= mu")(res, at.all(at.le(0.0, mu)))
return res


@_logprob.register(arb.NegBinomialRV)
def nbinom_logprob(op, values, *inputs, **kwargs):
(value,) = values
Expand Down
81 changes: 78 additions & 3 deletions tests/test_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import scipy.stats as stats

from aeppl.logprob import logprob
from aeppl.logprob import logcdf, logprob

# @pytest.fixture(scope="module", autouse=True)
# def set_aesara_flags():
Expand Down Expand Up @@ -33,7 +33,7 @@ def create_aesara_params(dist_params, obs, size):


def scipy_logprob_tester(
rv_var, obs, dist_params, test_fn=None, check_broadcastable=True
rv_var, obs, dist_params, test_fn=None, check_broadcastable=True, test_logcdf=False
):
"""Test for correspondence between `RandomVariable` and NumPy shape and
broadcast dimensions.
Expand All @@ -46,7 +46,10 @@ def scipy_logprob_tester(

test_fn = getattr(stats, name)

aesara_res = logprob(rv_var, at.as_tensor(obs))
if not test_logcdf:
aesara_res = logprob(rv_var, at.as_tensor(obs))
else:
aesara_res = logcdf(rv_var, at.as_tensor(obs))
aesara_res_val = aesara_res.eval(dist_params)

numpy_res = np.asarray(test_fn(obs, *dist_params.values()))
Expand Down Expand Up @@ -83,6 +86,26 @@ def scipy_logprob(obs, l, u):
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
"dist_params, obs, size",
[
((0, 1), np.array([-1, 0, 0.5, 1, 2], dtype=np.float64), ()),
((-2, -1), np.array([-3, -2, -0.5, -1, 0], dtype=np.float64), ()),
],
)
def test_uniform_logcdf(dist_params, obs, size):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.uniform(*dist_params_at, size=size_at)

def scipy_logcdf(obs, l, u):
return stats.uniform.logcdf(obs, loc=l, scale=u - l)

scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True)


@pytest.mark.parametrize(
"dist_params, obs, size",
[
Expand All @@ -101,6 +124,26 @@ def test_normal_logprob(dist_params, obs, size):
scipy_logprob_tester(x, obs, dist_params, test_fn=stats.norm.logpdf)


@pytest.mark.parametrize(
"dist_params, obs, size",
[
((0, 1), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), (2, 3)),
],
)
def test_normal_logcdf(dist_params, obs, size):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.normal(*dist_params_at, size=size_at)

scipy_logprob_tester(
x, obs, dist_params, test_fn=stats.norm.logcdf, test_logcdf=True
)


@pytest.mark.parametrize(
"dist_params, obs, size",
[
Expand Down Expand Up @@ -620,6 +663,38 @@ def scipy_logprob(obs, mu):
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
"dist_params, obs, size, error",
[
((-1,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), True),
((1.0,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), False),
((0.5,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (3, 2), False),
(
(np.array([0.01, 0.2, 200]),),
np.array([-1, 1, 84], dtype=np.int64),
(),
False,
),
],
)
def test_poisson_logcdf(dist_params, obs, size, error):

dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.poisson(*dist_params_at, size=size_at)

cm = contextlib.suppress() if not error else pytest.raises(AssertionError)

def scipy_logcdf(obs, mu):
return stats.poisson.logcdf(obs, mu)

with cm:
scipy_logprob_tester(
x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True
)


@pytest.mark.parametrize(
"dist_params, obs, size, error",
[
Expand Down

0 comments on commit 9e9d530

Please sign in to comment.