diff --git a/aeppl/__init__.py b/aeppl/__init__.py
index ac215f51..42200a91 100644
--- a/aeppl/__init__.py
+++ b/aeppl/__init__.py
@@ -14,5 +14,6 @@
 import aeppl.cumsum
 import aeppl.mixture
 import aeppl.scan
+import aeppl.truncation
 
 # isort: on
diff --git a/aeppl/logprob.py b/aeppl/logprob.py
index 92821c88..d168d1cc 100644
--- a/aeppl/logprob.py
+++ b/aeppl/logprob.py
@@ -38,6 +38,18 @@ def logprob(rv_var, *rv_values, **kwargs):
     return logprob
 
 
+def logcdf(rv_var, rv_value, **kwargs):
+    """Create a graph for the logcdf of a ``RandomVariable``."""
+    logcdf = _logcdf(
+        rv_var.owner.op, rv_value, *rv_var.owner.inputs, name=rv_var.name, **kwargs
+    )
+
+    if rv_var.name:
+        logcdf.name = f"{rv_var.name}_logcdf"
+
+    return logcdf
+
+
 @singledispatch
 def _logprob(
     op: Op,
@@ -52,7 +64,23 @@ def _logprob(
     for a ``RandomVariable``, register a new function on this dispatcher.
 
     """
-    raise NotImplementedError()
+    raise NotImplementedError(f"Logprob method not implemented for {op}")
+
+
+@singledispatch
+def _logcdf(
+    op: Op,
+    value: TensorVariable,
+    *inputs: TensorVariable,
+    **kwargs,
+):
+    """Create a graph for the logcdf of a ``RandomVariable``.
+
+    This function dispatches on the type of ``op``, which should be a subclass
+    of ``RandomVariable``. If you want to implement new logcdf graphs
+    for a ``RandomVariable``, register a new function on this dispatcher.
+    """
+    raise NotImplementedError(f"Logcdf method not implemented for {op}")
 
 
 @_logprob.register(arb.UniformRV)
@@ -66,6 +94,24 @@ def uniform_logprob(op, values, *inputs, **kwargs):
     )
 
 
+@_logcdf.register(arb.UniformRV)
+def uniform_logcdf(op, value, *inputs, **kwargs):
+    lower, upper = inputs[3:]
+
+    res = at.switch(
+        at.lt(value, lower),
+        -np.inf,
+        at.switch(
+            at.lt(value, upper),
+            at.log(value - lower) - at.log(upper - lower),
+            0,
+        ),
+    )
+
+    res = Assert("lower <= upper")(res, at.all(at.le(lower, upper)))
+    return res
+
+
 @_logprob.register(arb.NormalRV)
 def normal_logprob(op, values, *inputs, **kwargs):
     (value,) = values
@@ -79,6 +125,21 @@ def normal_logprob(op, values, *inputs, **kwargs):
     return res
 
 
+@_logcdf.register(arb.NormalRV)
+def normal_logcdf(op, value, *inputs, **kwargs):
+    mu, sigma = inputs[3:]
+
+    z = (value - mu) / sigma
+    res = at.switch(
+        at.lt(z, -1.0),
+        at.log(at.erfcx(-z / at.sqrt(2.0)) / 2.0) - at.sqr(z) / 2.0,
+        at.log1p(-at.erfc(z / at.sqrt(2.0)) / 2.0),
+    )
+
+    res = Assert("sigma > 0")(res, at.all(at.gt(sigma, 0.0)))
+    return res
+
+
 @_logprob.register(arb.HalfNormalRV)
 def halfnormal_logprob(op, values, *inputs, **kwargs):
     (value,) = values
@@ -346,6 +407,16 @@ def poisson_logprob(op, values, *inputs, **kwargs):
     return res
 
 
+@_logcdf.register(arb.PoissonRV)
+def poisson_logcdf(op, value, *inputs, **kwargs):
+    (mu,) = inputs[3:]
+    value = at.floor(value)
+    res = at.log(at.gammaincc(value + 1, mu))
+    res = at.switch(at.le(0, value), res, -np.inf)
+    res = Assert("0 <= mu")(res, at.all(at.le(0.0, mu)))
+    return res
+
+
 @_logprob.register(arb.NegBinomialRV)
 def nbinom_logprob(op, values, *inputs, **kwargs):
     (value,) = values
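For reference, a minimal sketch of how the new `logcdf` entry point is meant to be used (the variable names below are illustrative, not part of the patch). It dispatches on the `RandomVariable` op via `_logcdf`, so any distribution that registers a log-CDF gets this helper for free:

```python
import aesara
import aesara.tensor as at

from aeppl.logprob import logcdf

# Build the log-CDF graph of a normal RV, evaluated at a separate value variable
x_rv = at.random.normal(0.0, 1.0, name="x")
x_vv = at.scalar("x_value")

x_logcdf = logcdf(x_rv, x_vv)  # dispatches to normal_logcdf

logcdf_fn = aesara.function([x_vv], x_logcdf)
print(logcdf_fn(0.0))  # ~ log(0.5) = -0.693...
```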
diff --git a/aeppl/truncation.py b/aeppl/truncation.py
new file mode 100644
index 00000000..b7647cd8
--- /dev/null
+++ b/aeppl/truncation.py
@@ -0,0 +1,126 @@
+import warnings
+from typing import List, Optional
+
+import aesara.tensor as at
+import numpy as np
+from aesara.assert_op import Assert
+from aesara.graph.basic import Node
+from aesara.graph.fg import FunctionGraph
+from aesara.graph.opt import local_optimizer
+from aesara.scalar.basic import Clip
+from aesara.scalar.basic import clip as scalar_clip
+from aesara.tensor.elemwise import Elemwise
+from aesara.tensor.var import TensorConstant
+
+from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
+from aeppl.logprob import _logcdf, _logprob
+from aeppl.opt import rv_sinking_db
+
+
+class CensoredRV(Elemwise):
+    """A placeholder used to specify a log-likelihood for a censored RV sub-graph."""
+
+
+MeasurableVariable.register(CensoredRV)
+
+
+@local_optimizer(tracks=[Elemwise])
+def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
+    # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
+
+    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    if isinstance(node.op, CensoredRV):
+        return None  # pragma: no cover
+
+    if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
+        return None
+
+    clipped_var = node.outputs[0]
+    if clipped_var not in rv_map_feature.rv_values:
+        return None
+
+    base_var, lower_bound, upper_bound = node.inputs
+
+    if not (base_var.owner and isinstance(base_var.owner.op, MeasurableVariable)):
+        return None
+
+    if base_var in rv_map_feature.rv_values:
+        warnings.warn(
+            f"Value variables were assigned to both the input ({base_var}) and "
+            f"output ({clipped_var}) of a censored random variable."
+        )
+        return None
+
+    # Replace bounds by `+-inf` if `y = clip(x, x, ?)` or `y = clip(x, ?, x)`
+    # This is used in `censor_logprob` to generate a more succinct logprob graph
+    # for one-sided censored random variables
+    lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
+    upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)
+
+    censored_op = CensoredRV(scalar_clip)
+    # Make base_var unmeasurable
+    unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
+    censored_rv_node = censored_op.make_node(
+        unmeasurable_base_var, lower_bound, upper_bound
+    )
+    censored_rv = censored_rv_node.outputs[0]
+
+    censored_rv.name = clipped_var.name
+
+    return [censored_rv]
+
+
+rv_sinking_db.register("find_censored_rvs", find_censored_rvs, -5, "basic")
+
+
+@_logprob.register(CensoredRV)
+def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
+    (value,) = values
+
+    base_rv_op = base_rv.owner.op
+    base_rv_inputs = base_rv.owner.inputs
+
+    logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
+    logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
+
+    if base_rv_op.name:
+        logprob.name = f"{base_rv_op.name}_logprob"
+        logcdf.name = f"{base_rv_op.name}_logcdf"
+
+    is_lower_bounded, is_upper_bounded = False, False
+    if not (
+        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
+    ):
+        is_upper_bounded = True
+
+        logccdf = at.log(1 - at.exp(logcdf))
+        # For right censored discrete RVs, we need to add an extra term
+        # corresponding to the pmf at the upper bound
+        if base_rv_op.dtype.startswith("int"):
+            logccdf = at.logaddexp(logccdf, logprob)
+
+        logprob = at.switch(
+            at.eq(value, upper_bound),
+            logccdf,
+            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
+        )
+    if not (
+        isinstance(lower_bound, TensorConstant)
+        and np.all(np.isneginf(lower_bound.value))
+    ):
+        is_lower_bounded = True
+        logprob = at.switch(
+            at.eq(value, lower_bound),
+            logcdf,
+            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
+        )
+
+    if is_lower_bounded and is_upper_bounded:
+        logprob = Assert("lower_bound <= upper_bound")(
+            logprob, at.all(at.le(lower_bound, upper_bound))
+        )
+
+    return logprob
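Written out, the piecewise density that `censor_logprob` encodes for `y = clip(x, l, u)`, where the base variable `x` has density $f$ and CDF $F$, is (a summary of the branches above, not text from the patch):

$$
p(y) =
\begin{cases}
F(l) & y = l \\
f(y) & l < y < u \\
1 - F(u) & y = u \\
0 & \text{otherwise.}
\end{cases}
$$

For a discrete base variable the $y = u$ branch becomes $1 - F(u) + P(x = u)$, which is what the `at.logaddexp(logccdf, logprob)` term computes: $F(u)$ already includes the probability mass at the bound itself, so it must be added back.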
diff --git a/tests/test_logprob.py b/tests/test_logprob.py
index 62b05175..f98dfda1 100644
--- a/tests/test_logprob.py
+++ b/tests/test_logprob.py
@@ -5,7 +5,7 @@
 import pytest
 import scipy.stats as stats
 
-from aeppl.logprob import logprob
+from aeppl.logprob import logcdf, logprob
 
 # @pytest.fixture(scope="module", autouse=True)
 # def set_aesara_flags():
@@ -33,7 +33,7 @@ def create_aesara_params(dist_params, obs, size):
 
 
 def scipy_logprob_tester(
-    rv_var, obs, dist_params, test_fn=None, check_broadcastable=True
+    rv_var, obs, dist_params, test_fn=None, check_broadcastable=True, test_logcdf=False
 ):
     """Test for correspondence between `RandomVariable` and NumPy shape and
     broadcast dimensions.
@@ -46,7 +46,10 @@ def scipy_logprob_tester(
 
         test_fn = getattr(stats, name)
 
-    aesara_res = logprob(rv_var, at.as_tensor(obs))
+    if not test_logcdf:
+        aesara_res = logprob(rv_var, at.as_tensor(obs))
+    else:
+        aesara_res = logcdf(rv_var, at.as_tensor(obs))
     aesara_res_val = aesara_res.eval(dist_params)
 
     numpy_res = np.asarray(test_fn(obs, *dist_params.values()))
@@ -83,6 +86,26 @@ def scipy_logprob(obs, l, u):
     scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
 
 
+@pytest.mark.parametrize(
+    "dist_params, obs, size",
+    [
+        ((0, 1), np.array([-1, 0, 0.5, 1, 2], dtype=np.float64), ()),
+        ((-2, -1), np.array([-3, -2, -0.5, -1, 0], dtype=np.float64), ()),
+    ],
+)
+def test_uniform_logcdf(dist_params, obs, size):
+
+    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
+    dist_params = dict(zip(dist_params_at, dist_params))
+
+    x = at.random.uniform(*dist_params_at, size=size_at)
+
+    def scipy_logcdf(obs, l, u):
+        return stats.uniform.logcdf(obs, loc=l, scale=u - l)
+
+    scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True)
+
+
 @pytest.mark.parametrize(
     "dist_params, obs, size",
     [
@@ -101,6 +124,26 @@ def test_normal_logprob(dist_params, obs, size):
     scipy_logprob_tester(x, obs, dist_params, test_fn=stats.norm.logpdf)
 
 
+@pytest.mark.parametrize(
+    "dist_params, obs, size",
+    [
+        ((0, 1), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
+        ((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), ()),
+        ((-1, 20), np.array([0, 0.5, 1, -1], dtype=np.float64), (2, 3)),
+    ],
+)
+def test_normal_logcdf(dist_params, obs, size):
+
+    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
+    dist_params = dict(zip(dist_params_at, dist_params))
+
+    x = at.random.normal(*dist_params_at, size=size_at)
+
+    scipy_logprob_tester(
+        x, obs, dist_params, test_fn=stats.norm.logcdf, test_logcdf=True
+    )
+
+
 @pytest.mark.parametrize(
     "dist_params, obs, size",
     [
@@ -620,6 +663,38 @@ def scipy_logprob(obs, mu):
     scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
 
 
+@pytest.mark.parametrize(
+    "dist_params, obs, size, error",
+    [
+        ((-1,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), True),
+        ((1.0,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (), False),
+        ((0.5,), np.array([-1, 0, 1, 100, 10000], dtype=np.int64), (3, 2), False),
+        (
+            (np.array([0.01, 0.2, 200]),),
+            np.array([-1, 1, 84], dtype=np.int64),
+            (),
+            False,
+        ),
+    ],
+)
+def test_poisson_logcdf(dist_params, obs, size, error):
+
+    dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
+    dist_params = dict(zip(dist_params_at, dist_params))
+
+    x = at.random.poisson(*dist_params_at, size=size_at)
+
+    cm = contextlib.suppress() if not error else pytest.raises(AssertionError)
+
+    def scipy_logcdf(obs, mu):
+        return stats.poisson.logcdf(obs, mu)
+
+    with cm:
+        scipy_logprob_tester(
+            x, obs, dist_params, test_fn=scipy_logcdf, test_logcdf=True
+        )
+
+
 @pytest.mark.parametrize(
     "dist_params, obs, size, error",
     [
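Putting the rewrite and the log-CDF registrations together, censoring works end to end through the ordinary `joint_logprob` interface. A small sketch mirroring the new tests below (names are illustrative):

```python
import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats as st

from aeppl import joint_logprob

# x is censored to [-2, 2]: probability mass piles up at the bounds
x_rv = at.random.normal(0.5, 1.0)
cens_x_rv = at.clip(x_rv, -2.0, 2.0)

cens_x_vv = cens_x_rv.clone()
logp_fn = aesara.function([cens_x_vv], joint_logprob({cens_x_rv: cens_x_vv}))

ref = st.norm(0.5, 1.0)
assert np.isclose(logp_fn(-2.0), ref.logcdf(-2.0))  # P(x <= lower) at the bound
assert np.isclose(logp_fn(2.0), ref.logsf(2.0))     # P(x >= upper) at the bound
assert np.isclose(logp_fn(0.0), ref.logpdf(0.0))    # plain density inside the bounds
```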
diff --git a/tests/test_truncation.py b/tests/test_truncation.py
new file mode 100644
index 00000000..49677cc9
--- /dev/null
+++ b/tests/test_truncation.py
@@ -0,0 +1,192 @@
+import aesara
+import aesara.tensor as at
+import numpy as np
+import pytest
+import scipy as sp
+import scipy.stats as st
+
+from aeppl import factorized_joint_logprob, joint_logprob
+from aeppl.transforms import LogTransform, TransformValuesOpt
+from tests.utils import assert_no_rvs
+
+
+@aesara.config.change_flags(compute_test_value="raise")
+def test_continuous_rv_censoring():
+    x_rv = at.random.normal(0.5, 1)
+    cens_x_rv = at.clip(x_rv, -2, 2)
+
+    cens_x_vv = cens_x_rv.clone()
+    cens_x_vv.tag.test_value = 0
+
+    logp = joint_logprob({cens_x_rv: cens_x_vv})
+    assert_no_rvs(logp)
+
+    logp_fn = aesara.function([cens_x_vv], logp)
+    ref_scipy = st.norm(0.5, 1)
+
+    assert logp_fn(-3) == -np.inf
+    assert logp_fn(3) == -np.inf
+
+    assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
+    assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
+    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
+
+
+def test_discrete_rv_censoring():
+    x_rv = at.random.poisson(2)
+    cens_x_rv = at.clip(x_rv, 1, 4)
+
+    cens_x_vv = cens_x_rv.clone()
+
+    logp = joint_logprob({cens_x_rv: cens_x_vv})
+    assert_no_rvs(logp)
+
+    logp_fn = aesara.function([cens_x_vv], logp)
+    ref_scipy = st.poisson(2)
+
+    assert logp_fn(0) == -np.inf
+    assert logp_fn(5) == -np.inf
+
+    assert np.isclose(logp_fn(1), ref_scipy.logcdf(1))
+    assert np.isclose(logp_fn(4), np.logaddexp(ref_scipy.logsf(4), ref_scipy.logpmf(4)))
+    assert np.isclose(logp_fn(2), ref_scipy.logpmf(2))
+
+
+def test_one_sided_censoring():
+    x_rv = at.random.normal(0, 1)
+    lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
+    ub_cens_x_rv = at.clip(x_rv, x_rv, 1)
+
+    lb_cens_x_vv = lb_cens_x_rv.clone()
+    ub_cens_x_vv = ub_cens_x_rv.clone()
+
+    lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv})
+    ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv})
+    assert_no_rvs(lb_logp)
+    assert_no_rvs(ub_logp)
+
+    logp_fn = aesara.function([lb_cens_x_vv, ub_cens_x_vv], [lb_logp, ub_logp])
+    ref_scipy = st.norm(0, 1)
+
+    assert np.all(np.array(logp_fn(-2, 2)) == -np.inf)
+    assert np.all(np.array(logp_fn(2, -2)) != -np.inf)
+    np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
+    np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))
+
+
+def test_useless_censoring():
+    x_rv = at.random.normal(0.5, 1, size=3)
+    cens_x_rv = at.clip(x_rv, x_rv, x_rv)
+
+    cens_x_vv = cens_x_rv.clone()
+
+    logp = joint_logprob({cens_x_rv: cens_x_vv}, sum=False)
+    assert_no_rvs(logp)
+
+    logp_fn = aesara.function([cens_x_vv], logp)
+    ref_scipy = st.norm(0.5, 1)
+
+    np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))
+
+
+def test_random_censoring():
+    lb_rv = at.random.normal(0, 1, size=2)
+    x_rv = at.random.normal(0, 2)
+    cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
+
+    lb_vv = lb_rv.clone()
+    cens_x_vv = cens_x_rv.clone()
+    logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}, sum=False)
+    assert_no_rvs(logp)
+
+    logp_fn = aesara.function([lb_vv, cens_x_vv], logp)
+    res = logp_fn([0, -1], [-1, -1])
+    assert res[0] == -np.inf
+    assert res[1] != -np.inf
+
+
+def test_broadcasted_censoring_constant():
+    lb_rv = at.random.uniform(0, 1)
+    x_rv = at.random.normal(0, 2)
+    cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
+
+    lb_vv = lb_rv.clone()
+    cens_x_vv = cens_x_rv.clone()
+
+    logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
+    assert_no_rvs(logp)
+
+
+def test_broadcasted_censoring_random():
+    lb_rv = at.random.normal(0, 1)
+    x_rv = at.random.normal(0, 2, size=2)
+    cens_x_rv = at.clip(x_rv, lb_rv, 1)
+
+    lb_vv = lb_rv.clone()
+    cens_x_vv = cens_x_rv.clone()
+
+    logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
+    assert_no_rvs(logp)
+
+
+def test_fail_base_and_censored_have_values():
+    """Test failure when both base_rv and clipped_rv are given value vars."""
+    x_rv = at.random.normal(0, 1)
+    cens_x_rv = at.clip(x_rv, x_rv, 1)
+    cens_x_rv.name = "cens_x"
+
+    x_vv = x_rv.clone()
+    cens_x_vv = cens_x_rv.clone()
+    with pytest.warns(UserWarning):
+        logp_terms = factorized_joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
+    assert cens_x_vv not in logp_terms
+
+
+def test_fail_multiple_censored_single_base():
+    """Test failure when multiple clipped_rvs share a single base_rv."""
+    base_rv = at.random.normal(0, 1)
+    cens_rv1 = at.clip(base_rv, -1, 1)
+    cens_rv1.name = "cens1"
+    cens_rv2 = at.clip(base_rv, -1, 1)
+    cens_rv2.name = "cens2"
+
+    cens_vv1 = cens_rv1.clone()
+    cens_vv2 = cens_rv2.clone()
+    logp_terms = factorized_joint_logprob({cens_rv1: cens_vv1, cens_rv2: cens_vv2})
+    assert cens_vv2 not in logp_terms
+
+
+def test_deterministic_clipping():
+    x_rv = at.random.normal(0, 1)
+    clip = at.clip(x_rv, 0, 0)
+    y_rv = at.random.normal(clip, 1)
+
+    x_vv = x_rv.clone()
+    y_vv = y_rv.clone()
+    logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
+    assert_no_rvs(logp)
+
+    logp_fn = aesara.function([x_vv, y_vv], logp)
+    assert np.isclose(
+        logp_fn(-1, 1),
+        st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
+    )
+
+
+@pytest.mark.xfail(reason="Transform does not work with Elemwise ops, see #60")
+def test_censored_transform():
+    x_rv = at.random.normal(0.5, 1)
+    cens_x_rv = at.clip(x_rv, 0, x_rv)
+
+    cens_x_vv = cens_x_rv.clone()
+
+    transform = TransformValuesOpt({cens_x_vv: LogTransform()})
+    logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
+
+    cens_x_vv_testval = -1
+    obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
+    exp_logp = (
+        sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_vv_testval)) + cens_x_vv_testval
+    )
+
+    assert np.isclose(obs_logp, exp_logp)
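Only the uniform, normal, and Poisson distributions register a log-CDF in this patch; censoring any other `RandomVariable` hits the new `NotImplementedError` in `_logcdf`. Extending support is just a matter of registering on the dispatcher. A hypothetical sketch for the exponential distribution, following the same pattern as the registrations above (not part of the patch; `at.random.exponential` is parametrized by `scale`):

```python
import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara.assert_op import Assert

from aeppl.logprob import _logcdf


@_logcdf.register(arb.ExponentialRV)
def exponential_logcdf(op, value, *inputs, **kwargs):
    # inputs[:3] are (rng, size, dtype); the distribution parameters follow
    (scale,) = inputs[3:]

    # F(x) = 1 - exp(-x / scale) for x >= 0, so log F(x) = log1p(-exp(-x / scale))
    res = at.log1p(-at.exp(-value / scale))
    res = at.switch(at.le(0, value), res, -np.inf)
    res = Assert("0 < scale")(res, at.all(at.lt(0.0, scale)))
    return res
```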