Implement RV censoring logprob and opt
Ricardo committed Jun 29, 2021
1 parent e3930a2 commit 673e518
Showing 3 changed files with 262 additions and 2 deletions.
6 changes: 4 additions & 2 deletions aeppl/opt.py
@@ -5,7 +5,7 @@
from aesara.compile.mode import optdb
from aesara.graph.features import Feature
from aesara.graph.op import compute_test_value
-from aesara.graph.opt import EquilibriumOptimizer, local_optimizer
+from aesara.graph.opt import EquilibriumOptimizer, local_optimizer, out2in
from aesara.graph.optdb import SequenceDB
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
@@ -21,6 +21,7 @@
)
from aesara.tensor.var import TensorVariable

+from aeppl.truncation import censor_rvs
from aeppl.utils import indices_from_subtensor

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -182,4 +183,5 @@ def naive_bcast_rv_lift(fgraph, node):


logprob_canonicalize.register("canonicalize", optdb["canonicalize"], -10, "basic")
logprob_canonicalize.register("rvsinker", RVSinker(), -1, "basic")
logprob_canonicalize.register("rvsinker", RVSinker(), -5, "basic")
logprob_canonicalize.register("censor_rvs", out2in(censor_rvs), -1, "basic")
128 changes: 128 additions & 0 deletions aeppl/truncation.py
@@ -0,0 +1,128 @@
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from aeppl.logprob import _logcdf, _logprob


# TODO: Add interval transform
class CensoredRV(RandomVariable):
r"""A base class for censored `RandomVariable`\s."""

def __init__(self, *args, base_op, **kwargs):
self.base_op = base_op
super().__init__(*args, **kwargs)


@_logprob.register(CensoredRV)
def censor_logprob(op, value, *inputs, name=None, **kwargs):
# TODO: Exact bound check might be problematic

*rv_params, lower_bound, upper_bound = inputs
logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs)

is_lower_bounded, is_upper_bounded = False, False
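    # Upper censored logp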
if not (
isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
):
is_upper_bounded = True
logprob = at.switch(
at.eq(value, upper_bound),
at.log(1 - at.exp(logcdf)),
at.switch(at.gt(value, upper_bound), -np.inf, logprob),
)

# Lower censored logp
if not (
isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
):
is_lower_bounded = True
logprob = at.switch(
at.eq(value, lower_bound),
logcdf,
at.switch(at.lt(value, lower_bound), -np.inf, logprob),
)

if is_lower_bounded and is_upper_bounded:
logprob = Assert("lower_bound <= upper_bound")(
logprob, at.all(at.le(lower_bound, upper_bound))
)

if name:
logprob.name = f"{name}_censored_logprob"

return logprob
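For reference, the piecewise density censor_logprob computes for y = clip(x, l, u), where the base variable has density f and CDF F, is the standard censored form:

$$
p(y) =
\begin{cases}
F(l) & y = l \\
f(y) & l < y < u \\
1 - F(u) & y = u \\
0 & \text{otherwise.}
\end{cases}
$$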


@local_optimizer(tracks=[Elemwise])
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return

if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
return

clipped_var = node.outputs[0]
if clipped_var not in rv_map_feature.rv_values:
return

base_var, lower_bound, upper_bound = node.inputs
base_op = base_var.owner.op

if not isinstance(base_op, RandomVariable):
return

    # base_var already has a direct value var
# TODO: Raise informative Error
if base_var in rv_map_feature.rv_values:
return

is_lower_bounded, is_upper_bounded = False, False
if lower_bound is not base_var:
# y = clip(x, x, ?)
is_lower_bounded = True

if upper_bound is not base_var:
# y = clip(x, ?, x)
is_upper_bounded = True

censored_rv = CensoredRV(
"censored",
base_op.ndim_supp,
list(base_op.ndims_params) + [0, 0],
base_op.dtype,
inplace=False,
base_op=base_op,
)

lower_bound = lower_bound if is_lower_bounded else -np.inf
upper_bound = upper_bound if is_upper_bounded else np.inf
# _, lower_bound, upper_bound = at.broadcast_arrays(clipped_var, lower_bound, upper_bound)

censored_node = censored_rv.make_node(
*base_var.owner.inputs,
lower_bound,
upper_bound,
)

censored_node_out = censored_node.outputs[1]
censored_node_out.name = clipped_var.name

if not censored_node_out.type == clipped_var.type:
return

return [censored_node_out]
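In effect, the pass swaps the Clip node for a CensoredRV that reuses the base variable's inputs and appends the two bounds, so that censor_logprob above can dispatch on it. An illustrative sketch, assuming the pass fires on a clipped normal:

# before: clip(normal_rv(rng, size, dtype, loc, scale), lower, upper)
# after:  censored_rv(rng, size, dtype, loc, scale, lower, upper)
# The output at index 1 is the drawn value; index 0 is the updated RNG state.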
130 changes: 130 additions & 0 deletions tests/test_truncation.py
@@ -0,0 +1,130 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
import scipy.stats as st

from aeppl import joint_logprob
from tests.utils import assert_no_rvs


def test_uniform_censoring():
x_rv = at.random.uniform(-4, 5)
cens_x_rv = at.clip(x_rv, -1, 1)

cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x], logp)
ref_scipy = st.uniform(-4, 9)
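    # st.uniform takes (loc, scale), so this is U(-4, 5), matching x_rv above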

assert logp_fn(-5) == -np.inf
assert logp_fn(6) == -np.inf

assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1))
    assert np.isclose(logp_fn(1), ref_scipy.logsf(1))
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_normal_censoring():
x_rv = at.random.normal(0.5, 1)
cens_x_rv = at.clip(x_rv, -2, 2)

cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x], logp)
ref_scipy = st.norm(0.5, 1)

assert logp_fn(-3) == -np.inf
assert logp_fn(3) == -np.inf

assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_one_sided_censoring():

x_rv = at.random.normal(0, 1)
lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
ub_cens_x_rv = at.clip(x_rv, x_rv, 1)

lb_cens_x = lb_cens_x_rv.type()
ub_cens_x = ub_cens_x_rv.type()

lb_logp = joint_logprob(lb_cens_x_rv, {lb_cens_x_rv: lb_cens_x})
ub_logp = joint_logprob(ub_cens_x_rv, {ub_cens_x_rv: ub_cens_x})
assert_no_rvs(lb_logp)
assert_no_rvs(ub_logp)

logp_fn = aesara.function([lb_cens_x, ub_cens_x], [lb_logp, ub_logp])
ref_scipy = st.norm(0, 1)

assert np.all(np.array(logp_fn(-2, 2)) == -np.inf)
assert np.all(np.array(logp_fn(2, -2)) != -np.inf)
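    # The standard normal is symmetric: logcdf(-1) == logsf(1) and
    # logpdf(-1) == logpdf(1), so one reference value checks both logps.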
np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))


def test_failed_censoring():

x_rv = at.random.normal(0, 1)
cens_x_rv = at.clip(x_rv, x_rv, x_rv)

x = x_rv.type()
cens_x = cens_x_rv.type()
with pytest.raises(NotImplementedError):
joint_logprob(cens_x_rv, {cens_x_rv: cens_x, x_rv: x})


def test_random_censoring():

lb_rv = at.random.normal(0, 1, size=2)
x_rv = at.random.normal(0, 2)
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

lb = lb_rv.type()
cens_x = cens_x_rv.type()
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)

logp_fn = aesara.function([lb, cens_x], logp)
res = logp_fn([0, -1], [-1, -1])
assert res[0] == -np.inf
assert res[1] != -np.inf


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring():

lb_rv = at.random.uniform(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
cens_x_rv.name = "cens_x_rv"

lb = lb_rv.type()
lb.name = "lb"
cens_x = cens_x_rv.type()
cens_x.name = "cens_x"

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring2():

lb_rv = at.random.normal(0, 1)
x_rv = at.random.normal(0, 2, size=2)
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

lb = lb_rv.type()
cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)
