Implement RV censoring logprob and opt
Ricardo committed Jun 22, 2021
1 parent e3930a2 commit e16ef75
Showing 3 changed files with 141 additions and 2 deletions.
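In effect, after this change a clipped `RandomVariable` gets a censored log-probability. A minimal usage sketch, mirroring the tests added below (`joint_logprob` is the existing aeppl entry point; the rewrite and logprob it relies on are introduced in this commit):

import aesara
import aesara.tensor as at
from aeppl import joint_logprob

x_rv = at.random.normal(0.5, 1)   # latent normal variable
cens_x_rv = at.clip(x_rv, -2, 2)  # observations censored to [-2, 2]
cens_x = cens_x_rv.type()         # value variable for the censored draw

# censor_rvs (below) rewrites the clip into a CensoredRV with a known logprob
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
logp_fn = aesara.function([cens_x], logp)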
6 changes: 4 additions & 2 deletions aeppl/opt.py
@@ -5,7 +5,7 @@
 from aesara.compile.mode import optdb
 from aesara.graph.features import Feature
 from aesara.graph.op import compute_test_value
-from aesara.graph.opt import EquilibriumOptimizer, local_optimizer
+from aesara.graph.opt import EquilibriumOptimizer, local_optimizer, out2in
 from aesara.graph.optdb import SequenceDB
 from aesara.tensor.extra_ops import BroadcastTo
 from aesara.tensor.random.op import RandomVariable
@@ -21,6 +21,7 @@
 )
 from aesara.tensor.var import TensorVariable
 
+from aeppl.truncation import censor_rvs
 from aeppl.utils import indices_from_subtensor
 
 inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -182,4 +183,5 @@ def naive_bcast_rv_lift(fgraph, node):
 
 
 logprob_canonicalize.register("canonicalize", optdb["canonicalize"], -10, "basic")
-logprob_canonicalize.register("rvsinker", RVSinker(), -1, "basic")
+logprob_canonicalize.register("rvsinker", RVSinker(), -5, "basic")
+logprob_canonicalize.register("censor_rvs", out2in(censor_rvs), -1, "basic")
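A note on ordering: in a `SequenceDB`, entries are applied in order of increasing position, so `canonicalize` (-10) and `RVSinker` (-5) run before the censoring rewrite (-1); the clip-to-`CensoredRV` replacement therefore sees the graph after RVs have been sunk. `out2in` wraps the node-level `censor_rvs` rewrite into a pass over the whole graph, traversing from outputs to inputs. A hypothetical further rewrite (`my_local_opt` is a placeholder, not part of this commit) would be registered the same way:

# hypothetical: a position between -1 and 0 would run it after censor_rvs
logprob_canonicalize.register("my_rewrite", out2in(my_local_opt), -0.5, "basic")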
95 changes: 95 additions & 0 deletions aeppl/truncation.py
@@ -0,0 +1,95 @@
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable

from aeppl.logprob import _logcdf, _logprob


# TODO: Add interval transform
class CensoredRV(RandomVariable):
    r"""A base class for censored `RandomVariable`\s."""

    def __init__(self, *args, base_op, **kwargs):
        self.base_op = base_op
        super().__init__(*args, **kwargs)


@_logprob.register(CensoredRV)
def censor_logprob(op, value, *inputs, name=None, **kwargs):

    *rv_params, lower_bound, upper_bound = inputs
    logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
    logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs)

    # TODO: Check constant -inf
    # TODO: Exact bound check might be problematic
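    # Branch logic for a censored, continuous base RV: the point mass at the
    # bounds is P(Y = lower) = F(lower) and P(Y = upper) = 1 - F(upper), while
    # the base density applies strictly inside (lower, upper)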
    res = at.switch(
        at.eq(value, lower_bound),
        logcdf,
        at.switch(
            at.eq(value, upper_bound),
            at.log(1 - at.exp(logcdf)),
            logprob,
        ),
    )
    res = at.switch(
        at.or_(at.lt(value, lower_bound), at.gt(value, upper_bound)),
        -np.inf,
        res,
    )

    res = Assert("lower_bound <= upper_bound")(
        res, at.all(at.le(lower_bound, upper_bound))
    )

    if name:
        res.name = f"{name}_censored_logprob"

    return res


@local_optimizer(tracks=[Elemwise])
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:

    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return

    # TODO: Allow for one-sided truncation with a set_subtensor graph (x[x > ub] = ub)
    if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
        return

    base_var, lower_bound, upper_bound = node.inputs

    # Guard against clipped variables without an owner (e.g. inputs or constants)
    if base_var.owner is None or not isinstance(base_var.owner.op, RandomVariable):
        return

    base_op = base_var.owner.op

    # clipped_var = node.outputs[0]
    # clipped_value = rv_map_feature.rv_values.pop(clipped_var)
    censored_rv = CensoredRV(
        "censored",
        base_op.ndim_supp,
        list(base_op.ndims_params) + [0, 0],
        base_op.dtype,
        inplace=False,
        base_op=base_op,
    )
    censored_node = censored_rv.make_node(
        *base_var.owner.inputs,
        lower_bound,
        upper_bound,
    )
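    # `RandomVariable.make_node` returns an Apply with outputs (rng, draw);
    # the draw at outputs[1] is what replaces the clipped variable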

    # assert censored_node.outputs[1].type == clipped_var.type
    return [censored_node.outputs[1]]
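For reference, `censor_logprob` above encodes the standard density of a censored continuous variable: with Y = clip(X, l, u), base density f, and CDF F,

p_Y(y) =
\begin{cases}
F(l) & y = l \\
1 - F(u) & y = u \\
f(y) & l < y < u \\
0 & \text{otherwise,}
\end{cases}

which is exactly the nested `at.switch` plus the out-of-bounds branch.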
42 changes: 42 additions & 0 deletions tests/test_truncation.py
@@ -0,0 +1,42 @@
import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats as st

from aeppl import joint_logprob


def test_uniform_censoring():
    x_rv = at.random.uniform(-4, 5)
    cens_x_rv = at.clip(x_rv, -1, 1)
    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    logp_fn = aesara.function([cens_x], logp)

    ref_scipy = st.uniform(-4, 9)
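    # scipy convention: st.uniform(loc, scale) has support [loc, loc + scale],
    # here [-4, 5], matching at.random.uniform(-4, 5)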

    assert logp_fn(-5) == -np.inf
    assert logp_fn(6) == -np.inf

    assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1))
    assert np.isclose(logp_fn(1), ref_scipy.logsf(1))
    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_normal_censoring():
    x_rv = at.random.normal(0.5, 1)
    cens_x_rv = at.clip(x_rv, -2, 2)
    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    logp_fn = aesara.function([cens_x], logp)

    ref_scipy = st.norm(0.5, 1)

    assert logp_fn(-3) == -np.inf
    assert logp_fn(3) == -np.inf

    assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
    assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
