From e16ef750dacc4d80462c97b5d511a03baccd522c Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Tue, 22 Jun 2021 13:01:27 +0200
Subject: [PATCH] Implement RV censoring logprob and opt

---
 aeppl/opt.py             |   6 ++-
 aeppl/truncation.py      | 116 +++++++++++++++++++++++++++++++++++++++
 tests/test_truncation.py |  44 +++++++++++++++
 3 files changed, 164 insertions(+), 2 deletions(-)
 create mode 100644 aeppl/truncation.py
 create mode 100644 tests/test_truncation.py

diff --git a/aeppl/opt.py b/aeppl/opt.py
index b42893c7..89a80b6a 100644
--- a/aeppl/opt.py
+++ b/aeppl/opt.py
@@ -5,7 +5,7 @@
 from aesara.compile.mode import optdb
 from aesara.graph.features import Feature
 from aesara.graph.op import compute_test_value
-from aesara.graph.opt import EquilibriumOptimizer, local_optimizer
+from aesara.graph.opt import EquilibriumOptimizer, local_optimizer, out2in
 from aesara.graph.optdb import SequenceDB
 from aesara.tensor.extra_ops import BroadcastTo
 from aesara.tensor.random.op import RandomVariable
@@ -21,6 +21,7 @@
 )
 from aesara.tensor.var import TensorVariable
 
+from aeppl.truncation import censor_rvs
 from aeppl.utils import indices_from_subtensor
 
 inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -182,4 +183,5 @@ def naive_bcast_rv_lift(fgraph, node):
 
 
 logprob_canonicalize.register("canonicalize", optdb["canonicalize"], -10, "basic")
-logprob_canonicalize.register("rvsinker", RVSinker(), -1, "basic")
+logprob_canonicalize.register("rvsinker", RVSinker(), -5, "basic")
+logprob_canonicalize.register("censor_rvs", out2in(censor_rvs), -1, "basic")
diff --git a/aeppl/truncation.py b/aeppl/truncation.py
new file mode 100644
index 00000000..dc6315b9
--- /dev/null
+++ b/aeppl/truncation.py
@@ -0,0 +1,116 @@
+from typing import List, Optional
+
+import aesara.tensor as at
+import numpy as np
+from aesara.assert_op import Assert
+from aesara.graph.basic import Node
+from aesara.graph.fg import FunctionGraph
+from aesara.graph.opt import local_optimizer
+from aesara.scalar.basic import Clip
+from aesara.tensor.elemwise import Elemwise
+from aesara.tensor.random.op import RandomVariable
+
+from aeppl.logprob import _logcdf, _logprob
+
+
+# TODO: Add interval transform
+class CensoredRV(RandomVariable):
+    r"""A base class for censored `RandomVariable`\s.
+
+    The `Op` of the uncensored variable is stored in ``base_op``.
+    """
+
+    def __init__(self, *args, base_op, **kwargs):
+        self.base_op = base_op
+        super().__init__(*args, **kwargs)
+
+
+@_logprob.register(CensoredRV)
+def censor_logprob(op, value, *inputs, name=None, **kwargs):
+    """Build the log-probability graph of a censored random variable.
+
+    The last two entries of `inputs` are the lower and upper bounds; the
+    remaining ones are passed to the log-probability and log-CDF of
+    ``op.base_op``. The result is the log-CDF where `value` equals the lower
+    bound, the log of one minus the CDF where it equals the upper bound, the
+    base log-probability in between, and ``-inf`` outside the bounds.
+    """
+    *rv_params, lower_bound, upper_bound = inputs
+    logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
+    logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs)
+
+    # TODO: Check constant -inf
+    # TODO: Exact bound check might be problematic
+    res = at.switch(
+        at.eq(value, lower_bound),
+        logcdf,
+        at.switch(
+            at.eq(value, upper_bound),
+            # log(1 - CDF); log1p keeps precision when the CDF is tiny
+            at.log1p(-at.exp(logcdf)),
+            logprob,
+        ),
+    )
+    res = at.switch(
+        at.or_(at.lt(value, lower_bound), at.gt(value, upper_bound)),
+        -np.inf,
+        res,
+    )
+
+    res = Assert("lower_bound <= upper_bound")(
+        res, at.all(at.le(lower_bound, upper_bound))
+    )
+
+    if name:
+        res.name = f"{name}_censored_logprob"
+
+    return res
+
+
+@local_optimizer(tracks=[Elemwise])
+def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
+    """Replace the `clip` of a `RandomVariable` output with a `CensoredRV`.
+
+    Returns ``None`` (no replacement) when `fgraph` has no
+    ``preserve_rv_mappings`` attribute, when `node` is not an `Elemwise`
+    `Clip`, or when the clipped input was not produced by a `RandomVariable`.
+    """
+    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
+
+    if rv_map_feature is None:
+        return None
+
+    # TODO: Allow for one sided-truncation with set_subtensor graph (x[x>ub] = ub)
+    if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
+        return None
+
+    base_var, lower_bound, upper_bound = node.inputs
+
+    # Graph inputs and constants have no owner, so they cannot be RV outputs
+    if base_var.owner is None:
+        return None
+
+    base_op = base_var.owner.op
+
+    if not isinstance(base_op, RandomVariable):
+        return None
+
+    # clipped_var = node.outputs[0]
+    # clipped_value = rv_map_feature.rv_values.pop(clipped_var)
+
+    censored_rv = CensoredRV(
+        "censored",
+        base_op.ndim_supp,
+        list(base_op.ndims_params) + [0, 0],
+        base_op.dtype,
+        inplace=False,
+        base_op=base_op,
+    )
+    censored_node = censored_rv.make_node(
+        *base_var.owner.inputs,
+        lower_bound,
+        upper_bound,
+    )
+
+    # assert censored_node.outputs[1].type == clipped_var.type
+    return [censored_node.outputs[1]]
diff --git a/tests/test_truncation.py b/tests/test_truncation.py
new file mode 100644
index 00000000..d7701c94
--- /dev/null
+++ b/tests/test_truncation.py
@@ -0,0 +1,44 @@
+import aesara
+import aesara.tensor as at
+import numpy as np
+import scipy.stats as st
+
+from aeppl import joint_logprob
+
+
+def test_uniform_censoring():
+    """Check the censored log-probability of a clipped uniform variable."""
+    x_rv = at.random.uniform(-4, 5)
+    cens_x_rv = at.clip(x_rv, -1, 1)
+    cens_x = cens_x_rv.type()
+
+    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
+    logp_fn = aesara.function([cens_x], logp)
+
+    ref_scipy = st.uniform(-4, 9)
+
+    assert logp_fn(-5) == -np.inf
+    assert logp_fn(6) == -np.inf
+
+    assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1))
+    assert np.isclose(logp_fn(1), ref_scipy.logsf(1))
+    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
+
+
+def test_normal_censoring():
+    """Check the censored log-probability of a clipped normal variable."""
+    x_rv = at.random.normal(0.5, 1)
+    cens_x_rv = at.clip(x_rv, -2, 2)
+    cens_x = cens_x_rv.type()
+
+    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
+    logp_fn = aesara.function([cens_x], logp)
+
+    ref_scipy = st.norm(0.5, 1)
+
+    assert logp_fn(-3) == -np.inf
+    assert logp_fn(3) == -np.inf
+
+    assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
+    assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
+    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))