Implement RV censoring logprob and opt
Ricardo committed Jun 29, 2021
1 parent e3930a2 commit 9dd9052
Showing 3 changed files with 291 additions and 2 deletions.
6 changes: 4 additions & 2 deletions aeppl/opt.py
@@ -5,7 +5,7 @@
from aesara.compile.mode import optdb
from aesara.graph.features import Feature
from aesara.graph.op import compute_test_value
-from aesara.graph.opt import EquilibriumOptimizer, local_optimizer
+from aesara.graph.opt import EquilibriumOptimizer, local_optimizer, out2in
from aesara.graph.optdb import SequenceDB
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
@@ -21,6 +21,7 @@
)
from aesara.tensor.var import TensorVariable

+from aeppl.truncation import censor_rvs
from aeppl.utils import indices_from_subtensor

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -182,4 +183,5 @@ def naive_bcast_rv_lift(fgraph, node):


logprob_canonicalize.register("canonicalize", optdb["canonicalize"], -10, "basic")
logprob_canonicalize.register("rvsinker", RVSinker(), -1, "basic")
logprob_canonicalize.register("rvsinker", RVSinker(), -5, "basic")
logprob_canonicalize.register("censor_rvs", out2in(censor_rvs), -1, "basic")
126 changes: 126 additions & 0 deletions aeppl/truncation.py
@@ -0,0 +1,126 @@
import warnings
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from aeppl.logprob import _logcdf, _logprob


# TODO: Add interval transform
class CensoredRV(RandomVariable):
r"""A base class for censored `RandomVariable`\s."""

def __init__(self, *args, base_op, **kwargs):
self.base_op = base_op
super().__init__(*args, **kwargs)


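# For a censored variable y = clip(x, lower, upper), with f the density and F
# the CDF of x, the piecewise logprob implemented below is:
#   log F(lower)       if y == lower
#   log(1 - F(upper))  if y == upper
#   log f(y)           if lower < y < upper
#   -inf               otherwise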
@_logprob.register(CensoredRV)
def censor_logprob(op, value, *inputs, name=None, **kwargs):

*rv_params, lower_bound, upper_bound = inputs
logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs)
if op.base_op.name:
logprob.name = f"{op.base_op.name}_logprob"
logcdf.name = f"{op.base_op.name}_logcdf"

is_lower_bounded, is_upper_bounded = False, False
if not (
isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
):
is_upper_bounded = True
logprob = at.switch(
at.eq(value, upper_bound),
at.log(1 - at.exp(logcdf)),
at.switch(at.gt(value, upper_bound), -np.inf, logprob),
)
if not (
isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
):
is_lower_bounded = True
logprob = at.switch(
at.eq(value, lower_bound),
logcdf,
at.switch(at.lt(value, lower_bound), -np.inf, logprob),
)

if is_lower_bounded and is_upper_bounded:
logprob = Assert("lower_bound <= upper_bound")(
logprob, at.all(at.le(lower_bound, upper_bound))
)

return logprob


@local_optimizer(tracks=[Elemwise])
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
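    # Replace `clip(x_rv, lower, upper)` nodes that were assigned a value
    # variable with a `CensoredRV`, so that `censor_logprob` above applies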

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return

if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
return

clipped_var = node.outputs[0]
if clipped_var not in rv_map_feature.rv_values:
return

base_var, lower_bound, upper_bound = node.inputs
base_op = base_var.owner.op

if not isinstance(base_op, RandomVariable):
return

if base_var in rv_map_feature.rv_values:
warnings.warn(
f"Value variables were assigned to both the input ({base_var}) and "
f"output ({clipped_var}) of a censored random variable."
)
return

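    # The bounds are appended as two extra parameters, each with as many
    # dimensions as the base RV's support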
censored_rv = CensoredRV(
"censored",
base_op.ndim_supp,
list(base_op.ndims_params) + [base_op.ndim_supp] * 2,
base_op.dtype,
inplace=False,
base_op=base_op,
)

    # Replace bounds by +/-inf if `y = clip(x, x, ?)` or `y = clip(x, ?, x)`
    # This is used in `censor_logprob` to generate a more succinct logprob graph
    # for one-sided censored random variables
# TODO: This will probably fail for multivariate variables
lower_bound = lower_bound if (lower_bound is not base_var) else -np.inf
upper_bound = upper_bound if (upper_bound is not base_var) else np.inf

censored_node = censored_rv.make_node(
*base_var.owner.inputs,
lower_bound,
upper_bound,
)

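    # A `RandomVariable` node outputs `(rng, draws)`; the censored draws are
    # the second output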
censored_node_out = censored_node.outputs[1]

    if censored_node_out.type != clipped_var.type:
# TODO: issue warning?
return

if clipped_var.name:
censored_node_out.name = clipped_var.name
elif base_var.name:
censored_node_out.name = f"{base_var.name}_censored"

return [censored_node_out]
161 changes: 161 additions & 0 deletions tests/test_truncation.py
@@ -0,0 +1,161 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
import scipy.stats as st

from aeppl import joint_logprob
from tests.utils import assert_no_rvs


def test_uniform_censoring():
x_rv = at.random.uniform(-4, 5)
cens_x_rv = at.clip(x_rv, -1, 1)
cens_x_rv.name = "cens_x_rv"

cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x], logp)
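    # SciPy's uniform is parametrized by (loc, scale): loc=-4, scale=5-(-4)=9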
ref_scipy = st.uniform(-4, 9)

assert logp_fn(-5) == -np.inf
assert logp_fn(6) == -np.inf

assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1))
    assert np.isclose(logp_fn(1), ref_scipy.logsf(1))
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_normal_censoring():
x_rv = at.random.normal(0.5, 1, name="x_rv")
cens_x_rv = at.clip(x_rv, -2, 2)

cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x], logp)
ref_scipy = st.norm(0.5, 1)

assert logp_fn(-3) == -np.inf
assert logp_fn(3) == -np.inf

assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_one_sided_censoring():
x_rv = at.random.normal(0, 1)
lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
ub_cens_x_rv = at.clip(x_rv, x_rv, 1)

lb_cens_x = lb_cens_x_rv.type()
ub_cens_x = ub_cens_x_rv.type()

lb_logp = joint_logprob(lb_cens_x_rv, {lb_cens_x_rv: lb_cens_x})
ub_logp = joint_logprob(ub_cens_x_rv, {ub_cens_x_rv: ub_cens_x})
assert_no_rvs(lb_logp)
assert_no_rvs(ub_logp)

logp_fn = aesara.function([lb_cens_x, ub_cens_x], [lb_logp, ub_logp])
ref_scipy = st.norm(0, 1)

assert np.all(np.array(logp_fn(-2, 2)) == -np.inf)
assert np.all(np.array(logp_fn(2, -2)) != -np.inf)
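    # Standard-normal symmetry: logsf(1) == logcdf(-1) and logpdf(1) == logpdf(-1),
    # so a single reference value checks both the lower- and upper-bounded cases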
np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))


def test_useless_censoring():
x_rv = at.random.normal(0.5, 1, size=3)
cens_x_rv = at.clip(x_rv, x_rv, x_rv)

cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x], logp)
ref_scipy = st.norm(0.5, 1)

np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))


def test_random_censoring():
lb_rv = at.random.normal(0, 1, size=2)
x_rv = at.random.normal(0, 2)
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
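    # The lower bound is itself a random variable with its own value variable,
    # so its logprob also enters the joint graph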

lb = lb_rv.type()
cens_x = cens_x_rv.type()
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)

logp_fn = aesara.function([lb, cens_x], logp)
res = logp_fn([0, -1], [-1, -1])
assert res[0] == -np.inf
assert res[1] != -np.inf


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring_constant():
lb_rv = at.random.uniform(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

lb = lb_rv.type()
lb.name = "lb"
cens_x = cens_x_rv.type()
cens_x.name = "cens_x"

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring_random():
lb_rv = at.random.normal(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, size=2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

lb = lb_rv.type()
lb.name = "lb"
cens_x = cens_x_rv.type()
cens_x.name = "cens_x"

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)


def test_failed_censoring():
# Test that `joint_logprob` fails when both base_rv and clipped_rv are given
# value vars
x_rv = at.random.normal(0, 1)
cens_x_rv = at.clip(x_rv, x_rv, 1)

x = x_rv.type()
cens_x = cens_x_rv.type()
with pytest.raises(NotImplementedError):
joint_logprob(cens_x_rv, {cens_x_rv: cens_x, x_rv: x})


def test_deterministic_clipping():
x_rv = at.random.normal(0, 1)
clip = at.clip(x_rv, 0, 0)
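    # `clip` is not assigned a value variable here, so it remains a plain
    # deterministic transform rather than a censored RV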
y_rv = at.random.normal(clip, 1)

x = x_rv.type()
y = y_rv.type()
logp = joint_logprob(y_rv, {x_rv: x, y_rv: y})
assert_no_rvs(logp)

logp_fn = aesara.function([x, y], logp)
assert np.isclose(
logp_fn(-1, 1),
st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
)
