Skip to content

Commit

Permalink
Add one-sided censoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Jun 25, 2021
1 parent e16ef75 commit c5e7da1
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 28 deletions.
79 changes: 53 additions & 26 deletions aeppl/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from aeppl.logprob import _logcdf, _logprob

Expand All @@ -24,58 +25,81 @@ def __init__(self, *args, base_op, **kwargs):

@_logprob.register(CensoredRV)
def censor_logprob(op, value, *inputs, name=None, **kwargs):
    """Build the log-probability graph for a censored random variable.

    The last two entries of ``inputs`` are the lower and upper censoring
    bounds; everything before them is forwarded unchanged to the
    log-probability and log-CDF of ``op.base_op``.

    A bound is considered absent only when it is a ``TensorConstant`` whose
    entries are all ``-inf`` (lower) or ``inf`` (upper). For every bound that
    is present:

    * upper bound: ``value == upper_bound`` yields ``log(1 - exp(logcdf))``
      and ``value > upper_bound`` yields ``-inf``;
    * lower bound: ``value == lower_bound`` yields ``logcdf`` and
      ``value < lower_bound`` yields ``-inf``.

    All other values keep the base log-probability. When both bounds are
    present the result is wrapped in an ``Assert`` that
    ``lower_bound <= upper_bound`` holds for all entries.

    Parameters
    ----------
    op
        The ``CensoredRV`` op; only its ``base_op`` attribute is read here.
    value
        The value variable at which the log-probability is evaluated.
    *inputs
        Parameters of the base random variable followed by ``lower_bound``
        and ``upper_bound``.
    name
        Optional prefix; when truthy the result is named
        ``f"{name}_censored_logprob"``.
    **kwargs
        Passed through to ``_logprob`` and ``_logcdf``.

    Returns
    -------
    The log-probability graph of the censored variable.
    """
    # TODO: Exact bound check might be problematic

    *rv_params, lower_bound, upper_bound = inputs
    logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
    logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs)

    # Only a constant bound that is infinite everywhere can be recognised as
    # "no censoring on this side"; any other bound is treated as a real one.
    is_upper_bounded = not (
        isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
    )
    is_lower_bounded = not (
        isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
    )

    # Upper censored logp
    if is_upper_bounded:
        # NOTE(review): log(1 - exp(logcdf)) may lose precision when logcdf
        # is close to 0 -- confirm whether a stabler form is needed.
        logprob = at.switch(
            at.eq(value, upper_bound),
            at.log(1 - at.exp(logcdf)),
            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
        )

    # Lower censored logp
    if is_lower_bounded:
        logprob = at.switch(
            at.eq(value, lower_bound),
            logcdf,
            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
        )

    # The ordering check only makes sense when both bounds are real.
    if is_lower_bounded and is_upper_bounded:
        logprob = Assert("lower_bound <= upper_bound")(
            logprob, at.all(at.le(lower_bound, upper_bound))
        )

    if name:
        logprob.name = f"{name}_censored_logprob"

    return logprob


@local_optimizer(tracks=[Elemwise])
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return

# TODO: Allow for one sided-truncation with set_subtensor graph (x[x>ub] = ub)
if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
return

clipped_var = node.out
if clipped_var not in rv_map_feature.rv_values:
return

base_var, lower_bound, upper_bound = node.inputs
base_op = base_var.owner.op

if not isinstance(base_op, RandomVariable):
return

# clipped_var = node.ouputs[0]
# clipped_value = rv_map_feature.rv_values.pop(clipped_var)
# base_var var already has a direct value var
# TODO: Raise informative Error
if base_var in rv_map_feature.rv_values:
return

is_lower_bounded, is_upper_bounded = False, False
if lower_bound is not base_var:
# y = clip(x, x, ?)
is_lower_bounded = True

if upper_bound is not base_var:
# y = clip(x, ?, x)
is_upper_bounded = True

censored_rv = CensoredRV(
"censored",
Expand All @@ -85,11 +109,14 @@ def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
inplace=False,
base_op=base_op,
)

censored_node = censored_rv.make_node(
*base_var.owner.inputs,
lower_bound,
upper_bound,
lower_bound if is_lower_bounded else -np.inf,
upper_bound if is_upper_bounded else np.inf,
)

# assert censored_node.outputs[1].type == clipped_var.type
if not censored_node.outputs[1].type == clipped_var.type:
return

return [censored_node.outputs[1]]
86 changes: 84 additions & 2 deletions tests/test_truncation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
import scipy.stats as st

from aeppl import joint_logprob
from tests.utils import assert_no_rvs


def test_uniform_censoring():
x_rv = at.random.uniform(-4, 5)
cens_x_rv = at.clip(x_rv, -1, 1)

cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
logp_fn = aesara.function([cens_x], logp)
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x], logp)
ref_scipy = st.uniform(-4, 9)

assert logp_fn(-5) == -np.inf
Expand All @@ -27,11 +31,13 @@ def test_uniform_censoring():
def test_normal_censoring():
x_rv = at.random.normal(0.5, 1)
cens_x_rv = at.clip(x_rv, -2, 2)

cens_x = cens_x_rv.type()

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
logp_fn = aesara.function([cens_x], logp)
assert_no_rvs(logp)

logp_fn = aesara.function([cens_x], logp)
ref_scipy = st.norm(0.5, 1)

assert logp_fn(-3) == -np.inf
Expand All @@ -40,3 +46,79 @@ def test_normal_censoring():
assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))


def test_one_sided_censoring():
    """Check log-probabilities of a normal RV censored on one side only.

    One graph clips the variable from below at -1 (upper bound is the
    variable itself), the other clips it from above at 1 (lower bound is the
    variable itself). Values beyond the active bound must give ``-inf``;
    values on the unbounded side must not.
    """
    base_rv = at.random.normal(0, 1)
    lower_cens_rv = at.clip(base_rv, -1, base_rv)
    upper_cens_rv = at.clip(base_rv, base_rv, 1)

    lower_value = lower_cens_rv.type()
    upper_value = upper_cens_rv.type()

    lower_logp = joint_logprob(lower_cens_rv, {lower_cens_rv: lower_value})
    upper_logp = joint_logprob(upper_cens_rv, {upper_cens_rv: upper_value})
    for graph in (lower_logp, upper_logp):
        assert_no_rvs(graph)

    logp_fn = aesara.function(
        [lower_value, upper_value],
        [lower_logp, upper_logp],
    )

    # Both values lie outside their respective censoring bound.
    outside = np.array(logp_fn(-2, 2))
    assert np.all(outside == -np.inf)

    # Both values lie on the side that is not censored.
    inside = np.array(logp_fn(2, -2))
    assert np.all(inside != -np.inf)


def test_failed_censoring():
    """Expect ``NotImplementedError`` when censoring cannot be derived.

    The clipped variable uses the base variable for both bounds, and values
    are supplied for both the clipped variable and the base variable itself.
    """
    base_rv = at.random.normal(0, 1)
    censored_rv = at.clip(base_rv, base_rv, base_rv)

    base_value = base_rv.type()
    censored_value = censored_rv.type()
    rv_values = {censored_rv: censored_value, base_rv: base_value}

    with pytest.raises(NotImplementedError):
        joint_logprob(censored_rv, rv_values)


def test_random_censoring():
    """Check censoring when the lower bound is itself a random variable.

    The lower bound is a size-2 normal RV with its own value variable; the
    upper bound is the constant ``[1, 1]``. With lower-bound values
    ``[0, -1]`` and censored values ``[-1, -1]``, the first entry is below
    its bound (``-inf``) and the second sits exactly on it (not ``-inf``).
    """
    lower_rv = at.random.normal(0, 1, size=2)
    base_rv = at.random.normal(0, 2)
    censored_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_value = lower_rv.type()
    censored_value = censored_rv.type()
    rv_values = {censored_rv: censored_value, lower_rv: lower_value}

    logp = joint_logprob(censored_rv, rv_values)
    assert_no_rvs(logp)

    logp_fn = aesara.function([lower_value, censored_value], logp)
    result = logp_fn([0, -1], [-1, -1])

    below_bound, at_bound = result[0], result[1]
    assert below_bound == -np.inf
    assert at_bound != -np.inf


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring():
    """Censoring where the constant upper bound ``[1, 1]`` forces broadcasting.

    Neither the lower-bound RV nor the base RV is given a ``size``. The test
    is marked as an expected failure.
    """
    lower_rv = at.random.normal(0, 1)
    base_rv = at.random.normal(0, 2)
    censored_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_value = lower_rv.type()
    censored_value = censored_rv.type()
    rv_values = {censored_rv: censored_value, lower_rv: lower_value}

    logp = joint_logprob(censored_rv, rv_values)
    assert_no_rvs(logp)


@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring2():
    """Censoring of a size-2 base RV by a lower-bound RV created without ``size``.

    The upper bound is the constant ``[1, 1]``. The test is marked as an
    expected failure.
    """
    lower_rv = at.random.normal(0, 1)
    base_rv = at.random.normal(0, 2, size=2)
    censored_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_value = lower_rv.type()
    censored_value = censored_rv.type()
    rv_values = {censored_rv: censored_value, lower_rv: lower_value}

    logp = joint_logprob(censored_rv, rv_values)
    assert_no_rvs(logp)

0 comments on commit c5e7da1

Please sign in to comment.