Skip to content

Commit

Permalink
Implement censored RVs logprob
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Nov 2, 2021
1 parent 9e9d530 commit d7745ac
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 0 deletions.
1 change: 1 addition & 0 deletions aeppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
# Add optimizations to the DBs
import aeppl.mixture
import aeppl.scan
import aeppl.truncation

# isort: on
125 changes: 125 additions & 0 deletions aeppl/truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import warnings
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.scalar.basic import Clip
from aesara.scalar.basic import clip as scalar_clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.var import TensorConstant

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import _logcdf, _logprob
from aeppl.opt import rv_sinking_db


class CensoredRV(Elemwise):
"""A placeholder used to specify a log-likelihood for a censored RV sub-graph."""


MeasurableVariable.register(CensoredRV)


@local_optimizer(tracks=[Elemwise])
def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, CensoredRV):
return None # pragma: no cover

if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
return None

clipped_var = node.outputs[0]
if clipped_var not in rv_map_feature.rv_values:
return None

base_var, lower_bound, upper_bound = node.inputs

if not (base_var.owner and isinstance(base_var.owner.op, MeasurableVariable)):
return None

if base_var in rv_map_feature.rv_values:
warnings.warn(
f"Value variables were assigned to both the input ({base_var}) and "
f"output ({clipped_var}) of a censored random variable."
)
return None

# Replace bounds by `+-inf` if `y = clip(x, x, ?)` or `y=clip(x, ?, x)`
# This is used in `censor_logprob` to generate a more succint logprob graph
# for one-sided censored random variables
lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)

censored_op = CensoredRV(scalar_clip)
# Make base_var unmeasurable
unmeasurable_base_var = assign_custom_measurable_outputs(base_var.owner)
censored_rv_node = censored_op.make_node(
unmeasurable_base_var, lower_bound, upper_bound
)
censored_rv = censored_rv_node.outputs[0]

censored_rv.name = clipped_var.name

return [censored_rv]


rv_sinking_db.register("find_censored_rvs", find_censored_rvs, -5, "basic")


@_logprob.register(CensoredRV)
def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
(value,) = values

base_rv_op = base_rv.owner.op
base_rv_inputs = base_rv.owner.inputs

logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)

if base_rv_op.name:
logprob.name = f"{base_rv_op}_logprob"
logcdf.name = f"{base_rv_op}_logcdf"

is_lower_bounded, is_upper_bounded = False, False
if not (
isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
):
is_upper_bounded = True

logccdf = at.log(1 - at.exp(logcdf))
# For right censored discrete RVs, we need to add an extra term
# corresponding to the pmf at the upper bound
if base_rv_op.dtype == "int64":
logccdf = at.logaddexp(logccdf, logprob)

logprob = at.switch(
at.eq(value, upper_bound),
logccdf,
at.switch(at.gt(value, upper_bound), -np.inf, logprob),
)
if not (
isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
):
is_lower_bounded = True
logprob = at.switch(
at.eq(value, lower_bound),
logcdf,
at.switch(at.lt(value, lower_bound), -np.inf, logprob),
)

if is_lower_bounded and is_upper_bounded:
logprob = Assert("lower_bound <= upper_bound")(
logprob, at.all(at.le(lower_bound, upper_bound))
)

return logprob
191 changes: 191 additions & 0 deletions tests/test_truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
import scipy as sp
import scipy.stats as st

from aeppl import joint_logprob
from aeppl.transforms import LogTransform, TransformValuesOpt
from tests.utils import assert_no_rvs


@aesara.config.change_flags(compute_test_value="raise")
def test_continuous_rv_censoring():
x_rv = at.random.normal(0.5, 1)
cens_x_rv = at.clip(x_rv, -2, 2)

cens_x_vv = cens_x_rv.clone()
cens_x_vv.tag.test_value = 0

logp = joint_logprob({cens_x_rv: cens_x_vv})
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x_vv], logp)
ref_scipy = st.norm(0.5, 1)

assert logp_fn(-3) == -np.inf
assert logp_fn(3) == -np.inf

assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_discrete_rv_censoring():
x_rv = at.random.poisson(2)
cens_x_rv = at.clip(x_rv, 1, 4)

cens_x_vv = cens_x_rv.clone()

logp = joint_logprob({cens_x_rv: cens_x_vv})
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x_vv], logp)
ref_scipy = st.poisson(2)

assert logp_fn(0) == -np.inf
assert logp_fn(5) == -np.inf

assert np.isclose(logp_fn(1), ref_scipy.logcdf(1))
assert np.isclose(logp_fn(4), np.logaddexp(ref_scipy.logsf(4), ref_scipy.logpmf(4)))
assert np.isclose(logp_fn(2), ref_scipy.logpmf(2))


def test_one_sided_censoring():
x_rv = at.random.normal(0, 1)
lb_cens_x_rv = at.clip(x_rv, -1, x_rv)
ub_cens_x_rv = at.clip(x_rv, x_rv, 1)

lb_cens_x_vv = lb_cens_x_rv.clone()
ub_cens_x_vv = ub_cens_x_rv.clone()

lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv})
ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv})
assert_no_rvs(lb_logp)
assert_no_rvs(ub_logp)

logp_fn = aesara.function([lb_cens_x_vv, ub_cens_x_vv], [lb_logp, ub_logp])
ref_scipy = st.norm(0, 1)

assert np.all(np.array(logp_fn(-2, 2)) == -np.inf)
assert np.all(np.array(logp_fn(2, -2)) != -np.inf)
np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1))
np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1))


def test_useless_censoring():
x_rv = at.random.normal(0.5, 1, size=3)
cens_x_rv = at.clip(x_rv, x_rv, x_rv)

cens_x_vv = cens_x_rv.clone()

logp = joint_logprob({cens_x_rv: cens_x_vv}, sum=False)
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x_vv], logp)
ref_scipy = st.norm(0.5, 1)

np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))


def test_random_censoring():
lb_rv = at.random.normal(0, 1, size=2)
x_rv = at.random.normal(0, 2)
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

lb_vv = lb_rv.clone()
cens_x_vv = cens_x_rv.clone()
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}, sum=False)
assert_no_rvs(logp)

logp_fn = aesara.function([lb_vv, cens_x_vv], logp)
res = logp_fn([0, -1], [-1, -1])
assert res[0] == -np.inf
assert res[1] != -np.inf


def test_broadcasted_censoring_constant():
lb_rv = at.random.uniform(0, 1)
x_rv = at.random.normal(0, 2)
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])

lb_vv = lb_rv.clone()
cens_x_vv = cens_x_rv.clone()

logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
assert_no_rvs(logp)


def test_broadcasted_censoring_random():
lb_rv = at.random.normal(0, 1)
x_rv = at.random.normal(0, 2, size=2)
cens_x_rv = at.clip(x_rv, lb_rv, 1)

lb_vv = lb_rv.clone()
cens_x_vv = cens_x_rv.clone()

logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
assert_no_rvs(logp)


def test_fail_base_and_censored_have_values():
"""Test failure when both base_rv and clipped_rv are given value vars"""
x_rv = at.random.normal(0, 1)
cens_x_rv = at.clip(x_rv, x_rv, 1)
cens_x_rv.name = "cens_x"

x_vv = x_rv.clone()
cens_x_vv = cens_x_rv.clone()
with pytest.raises(RuntimeError, match="{cens_x}"):
joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})


def test_fail_multiple_censored_single_base():
"""Test failure when multiple clipped_rvs share a single base_rv"""
base_rv = at.random.normal(0, 1)
cens_rv1 = at.clip(base_rv, -1, 1)
cens_rv1.name = "cens1"
cens_rv2 = at.clip(base_rv, -1, 1)
cens_rv2.name = "cens2"

cens_vv1 = cens_rv1.clone()
cens_vv2 = cens_rv2.clone()
with pytest.raises(RuntimeError, match="{cens2}"):
joint_logprob({cens_rv1: cens_vv1, cens_rv2: cens_vv2})


def test_deterministic_clipping():
x_rv = at.random.normal(0, 1)
clip = at.clip(x_rv, 0, 0)
y_rv = at.random.normal(clip, 1)

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
assert_no_rvs(logp)

logp_fn = aesara.function([x_vv, y_vv], logp)
assert np.isclose(
logp_fn(-1, 1),
st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
)


@pytest.mark.xfail(reason="Transform does not work with Elemwise ops, see #60")
def test_censored_transform():
x_rv = at.random.normal(0.5, 1)
cens_x_rv = at.clip(x_rv, 0, x_rv)

cens_x_vv = cens_x_rv.clone()

transform = TransformValuesOpt({cens_x_vv: LogTransform()})
logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)

cens_x_vv_testval = -1
obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
exp_logp = (
sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_vv_testval)) + cens_x_vv_testval
)

assert np.isclose(obs_logp, exp_logp)

0 comments on commit d7745ac

Please sign in to comment.