Implement censored log-probabilities via the Clip Op #22

Merged · 2 commits · Nov 10, 2021
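In short: this PR adds a `logcdf` dispatcher (`_logcdf`) with implementations for the Uniform, Normal, and Poisson `RandomVariable`s, plus a graph rewrite (`find_censored_rvs`) that recognizes `at.clip` applied to a measurable variable and replaces it with a `CensoredRV` whose log-probability is given by `censor_logprob`. A minimal usage sketch (assuming the dict-based `joint_logprob` API aeppl exposed at the time; the names below are illustrative):

import aesara.tensor as at

from aeppl import joint_logprob

# A standard normal censored to [-1, 1]: values outside the bounds are
# clipped, so probability mass piles up at the bounds.
x_rv = at.random.normal(0.0, 1.0, name="x")
cens_x_rv = at.clip(x_rv, -1.0, 1.0)
cens_x_rv.name = "cens_x"

cens_x_vv = cens_x_rv.clone()
logp = joint_logprob({cens_x_rv: cens_x_vv})

# At the lower bound the logprob is log(Phi(-1)), the censored tail mass
print(logp.eval({cens_x_vv: -1.0}))  # approx. -1.8379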
1 change: 1 addition & 0 deletions aeppl/__init__.py
@@ -14,5 +14,6 @@
import aeppl.cumsum
import aeppl.mixture
import aeppl.scan
import aeppl.truncation

# isort: on
73 changes: 72 additions & 1 deletion aeppl/logprob.py
@@ -38,6 +38,18 @@ def logprob(rv_var, *rv_values, **kwargs):
    return logprob


def logcdf(rv_var, rv_value, **kwargs):
    """Create a graph for the logcdf of a ``RandomVariable``."""
    logcdf = _logcdf(
        rv_var.owner.op, rv_value, *rv_var.owner.inputs, name=rv_var.name, **kwargs
    )

    if rv_var.name:
        logcdf.name = f"{rv_var.name}_logcdf"

    return logcdf


@singledispatch
def _logprob(
    op: Op,
@@ -52,7 +64,23 @@ def _logprob(
    for a ``RandomVariable``, register a new function on this dispatcher.

    """
-    raise NotImplementedError()
+    raise NotImplementedError(f"Logprob method not implemented for {op}")


@singledispatch
def _logcdf(
    op: Op,
    value: TensorVariable,
    *inputs: TensorVariable,
    **kwargs,
):
    """Create a graph for the logcdf of a ``RandomVariable``.

    This function dispatches on the type of ``op``, which should be a subclass
    of ``RandomVariable``. If you want to implement new logcdf graphs
    for a ``RandomVariable``, register a new function on this dispatcher.
    """
    raise NotImplementedError(f"Logcdf method not implemented for {op}")


@_logprob.register(arb.UniformRV)
@@ -66,6 +94,24 @@ def uniform_logprob(op, values, *inputs, **kwargs):
    )


@_logcdf.register(arb.UniformRV)
def uniform_logcdf(op, value, *inputs, **kwargs):
    lower, upper = inputs[3:]

    res = at.switch(
        at.lt(value, lower),
        -np.inf,
        at.switch(
            at.lt(value, upper),
            at.log(value - lower) - at.log(upper - lower),
            0,
        ),
    )

    res = Assert("lower <= upper")(res, at.all(at.le(lower, upper)))
    return res


@_logprob.register(arb.NormalRV)
def normal_logprob(op, values, *inputs, **kwargs):
    (value,) = values
@@ -79,6 +125,21 @@ def normal_logprob(op, values, *inputs, **kwargs):
    return res


@_logcdf.register(arb.NormalRV)
def normal_logcdf(op, value, *inputs, **kwargs):
    mu, sigma = inputs[3:]

    z = (value - mu) / sigma
    res = at.switch(
        at.lt(z, -1.0),
        at.log(at.erfcx(-z / at.sqrt(2.0)) / 2.0) - at.sqr(z) / 2.0,
        at.log1p(-at.erfc(z / at.sqrt(2.0)) / 2.0),
    )

    res = Assert("sigma > 0")(res, at.all(at.gt(sigma, 0.0)))
    return res


@_logprob.register(arb.HalfNormalRV)
def halfnormal_logprob(op, values, *inputs, **kwargs):
    (value,) = values
@@ -346,6 +407,16 @@ def poisson_logprob(op, values, *inputs, **kwargs):
    return res


@_logcdf.register(arb.PoissonRV)
def poisson_logcdf(op, value, *inputs, **kwargs):
    (mu,) = inputs[3:]
    value = at.floor(value)
    res = at.log(at.gammaincc(value + 1, mu))
    res = at.switch(at.le(0, value), res, -np.inf)
    res = Assert("0 <= mu")(res, at.all(at.le(0.0, mu)))
    return res


@_logprob.register(arb.NegBinomialRV)
def nbinom_logprob(op, values, *inputs, **kwargs):
    (value,) = values
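The Poisson `logcdf` above relies on the identity P(X <= k) = Q(k + 1, mu), where Q is the regularized upper incomplete gamma function (`at.gammaincc`). A quick cross-check against SciPy (a sketch; only `logcdf` and the registrations above come from this PR):

import aesara.tensor as at
import numpy as np
import scipy.stats as stats

from aeppl.logprob import logcdf

mu = at.scalar("mu")
x_rv = at.random.poisson(mu, name="x")

# Build the logcdf graph at the value 3 and compare with scipy's Poisson CDF
x_logcdf = logcdf(x_rv, at.as_tensor(3.0))
np.testing.assert_allclose(x_logcdf.eval({mu: 2.5}), stats.poisson.logcdf(3, 2.5))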
126 changes: 126 additions & 0 deletions aeppl/truncation.py
@@ -0,0 +1,126 @@
import warnings
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.scalar.basic import Clip
from aesara.scalar.basic import clip as scalar_clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.var import TensorConstant

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import _logcdf, _logprob
from aeppl.opt import rv_sinking_db


class CensoredRV(Elemwise):
    """A placeholder used to specify a log-likelihood for a censored RV sub-graph."""


MeasurableVariable.register(CensoredRV)


@local_optimizer(tracks=[Elemwise])
def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
    # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return None  # pragma: no cover

    if isinstance(node.op, CensoredRV):
        return None  # pragma: no cover

    if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
        return None

    clipped_var = node.outputs[0]
    if clipped_var not in rv_map_feature.rv_values:
        return None

    base_var, lower_bound, upper_bound = node.inputs

    if not (base_var.owner and isinstance(base_var.owner.op, MeasurableVariable)):
        return None

    if base_var in rv_map_feature.rv_values:
        warnings.warn(
            f"Value variables were assigned to both the input ({base_var}) and "
            f"output ({clipped_var}) of a censored random variable."
        )
        return None

    # Replace bounds by `+-inf` if `y = clip(x, x, ?)` or `y = clip(x, ?, x)`
    # This is used in `censor_logprob` to generate a more succinct logprob graph
    # for one-sided censored random variables
    lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
    upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)

    censored_op = CensoredRV(scalar_clip)
    # Make base_var unmeasurable
    unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
    censored_rv_node = censored_op.make_node(
        unmeasurable_base_var, lower_bound, upper_bound
    )
    censored_rv = censored_rv_node.outputs[0]

    censored_rv.name = clipped_var.name

    return [censored_rv]


rv_sinking_db.register("find_censored_rvs", find_censored_rvs, -5, "basic")


@_logprob.register(CensoredRV)
def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
    (value,) = values

    base_rv_op = base_rv.owner.op
    base_rv_inputs = base_rv.owner.inputs

    logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
    logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)

    if base_rv_op.name:
        logprob.name = f"{base_rv_op}_logprob"
        logcdf.name = f"{base_rv_op}_logcdf"

    is_lower_bounded, is_upper_bounded = False, False
    if not (
        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
    ):
        is_upper_bounded = True

        logccdf = at.log(1 - at.exp(logcdf))
        # For right-censored discrete RVs, we need to add an extra term
        # corresponding to the pmf at the upper bound
        if base_rv_op.dtype == "int64":
            logccdf = at.logaddexp(logccdf, logprob)

        logprob = at.switch(
            at.eq(value, upper_bound),
            logccdf,
            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
        )
    if not (
        isinstance(lower_bound, TensorConstant)
        and np.all(np.isneginf(lower_bound.value))
    ):
        is_lower_bounded = True
        logprob = at.switch(
            at.eq(value, lower_bound),
            logcdf,
            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
        )

    if is_lower_bounded and is_upper_bounded:
        logprob = Assert("lower_bound <= upper_bound")(
            logprob, at.all(at.le(lower_bound, upper_bound))
        )

    return logprob
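In other words, `censor_logprob` assembles the standard censored log-density: all mass below the lower bound collapses onto logcdf(lower) at that bound, all mass above the upper bound onto log(1 - cdf(upper)) (plus the pmf at the bound for discrete RVs), the base logprob applies strictly inside, and values outside the bounds get -inf. A NumPy/SciPy sketch of the continuous normal case for reference (a hypothetical helper, not part of this PR):

import numpy as np
import scipy.stats as stats

def censored_normal_logpdf(value, mu, sigma, lower, upper):
    """Mirror of the switch logic in `censor_logprob` for a normal base RV."""
    logp = stats.norm.logpdf(value, mu, sigma)
    # Upper bound: the whole right tail, log(1 - cdf(upper))
    logp = np.where(value == upper, stats.norm.logsf(upper, mu, sigma), logp)
    # Lower bound: the whole left tail, log(cdf(lower))
    logp = np.where(value == lower, stats.norm.logcdf(lower, mu, sigma), logp)
    # Values outside the clipping bounds are impossible
    return np.where((value < lower) | (value > upper), -np.inf, logp)

value = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(censored_normal_logpdf(value, 0.0, 1.0, -1.0, 1.0))
# [-inf, log Phi(-1), logpdf(0), log Phi(-1), -inf]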
81 changes: 78 additions & 3 deletions tests/test_logprob.py
@@ -5,7 +5,7 @@
import pytest
import scipy.stats as stats

-from aeppl.logprob import logprob
+from aeppl.logprob import logcdf, logprob

# @pytest.fixture(scope="module", autouse=True)
# def set_aesara_flags():
@@ -33,7 +33,7 @@ def create_aesara_params(dist_params, obs, size):


def scipy_logprob_tester(
-    rv_var, obs, dist_params, test_fn=None, check_broadcastable=True
+    rv_var, obs, dist_params, test_fn=None, check_broadcastable=True, test_logcdf=False
):
    """Test for correspondence between `RandomVariable` and NumPy shape and
    broadcast dimensions.
@@ -46,7 +46,10 @@

    test_fn = getattr(stats, name)

-    aesara_res = logprob(rv_var, at.as_tensor(obs))
+    if not test_logcdf:
+        aesara_res = logprob(rv_var, at.as_tensor(obs))
+    else:
+        aesara_res = logcdf(rv_var, at.as_tensor(obs))
    aesara_res_val = aesara_res.eval(dist_params)

    numpy_res = np.asarray(test_fn(obs, *dist_params.values()))
@@ -83,6 +86,26 @@ def scipy_logprob(obs, l, u):
    scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
    "dist_params, obs, size",
    [
        ((0, 1), np.array([-1, 0, 0.5, 1, 2], dtype=np.float64), ()),
        ((-2, -1), np.array([-3, -2, -0.5, -1, 0], dtype=np.float64), ()),
    ],
)
def test_uniform_logcdf(dist_params, obs, size):

    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
    dist_params = dict(zip(dist_params_at, dist_params))

    x = at.random.uniform(*dist_params_at, size=size_at)

    def scipy_logcdf(obs, l, u):
        return stats.uniform.logcdf(obs, loc=l, scale=u - l)

    scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True)


@pytest.mark.parametrize(
    "dist_params, obs, size",
    [
@@ -101,6 +124,26 @@ def test_normal_logprob(dist_params, obs, size):
    scipy_logprob_tester(x, obs, dist_params, test_fn=stats.norm.logpdf)


@pytest.mark.parametrize(
    "dist_params, obs, size",
    [
        ((0, 1), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
        ((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
        ((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), (2, 3)),
    ],
)
def test_normal_logcdf(dist_params, obs, size):

    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
    dist_params = dict(zip(dist_params_at, dist_params))

    x = at.random.normal(*dist_params_at, size=size_at)

    scipy_logprob_tester(
        x, obs, dist_params, test_fn=stats.norm.logcdf, test_logcdf=True
    )


@pytest.mark.parametrize(
    "dist_params, obs, size",
    [
@@ -620,6 +663,38 @@ def scipy_logprob(obs, mu):
    scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
    "dist_params, obs, size, error",
    [
        ((-1,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), True),
        ((1.0,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), False),
        ((0.5,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (3, 2), False),
        (
            (np.array([0.01, 0.2, 200]),),
            np.array([-1, 1, 84], dtype=np.int64),
            (),
            False,
        ),
    ],
)
def test_poisson_logcdf(dist_params, obs, size, error):

    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
    dist_params = dict(zip(dist_params_at, dist_params))

    x = at.random.poisson(*dist_params_at, size=size_at)

    cm = contextlib.suppress() if not error else pytest.raises(AssertionError)

    def scipy_logcdf(obs, mu):
        return stats.poisson.logcdf(obs, mu)

    with cm:
        scipy_logprob_tester(
            x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True
        )


@pytest.mark.parametrize(
    "dist_params, obs, size, error",
    [