-
-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Ricardo
committed
Oct 27, 2021
1 parent
b98d51c
commit 6a16fb6
Showing
3 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,5 +13,6 @@ | |
# Add optimizations to the DBs | ||
import aeppl.mixture | ||
import aeppl.scan | ||
import aeppl.truncation | ||
|
||
# isort: on |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import warnings | ||
from typing import List, Optional | ||
|
||
import aesara.tensor as at | ||
import numpy as np | ||
from aesara.assert_op import Assert | ||
from aesara.graph.basic import Node | ||
from aesara.graph.fg import FunctionGraph | ||
from aesara.graph.opt import local_optimizer | ||
from aesara.scalar.basic import Clip | ||
from aesara.scalar.basic import clip as scalar_clip | ||
from aesara.tensor.elemwise import Elemwise | ||
from aesara.tensor.var import TensorConstant | ||
|
||
from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs | ||
from aeppl.logprob import _logcdf, _logprob | ||
from aeppl.opt import rv_sinking_db | ||
|
||
|
||
class CensoredRV(Elemwise):
    """Marker `Elemwise` op standing in for a censored random variable.

    The class adds nothing to `Elemwise`; it exists so that a log-likelihood
    implementation can be dispatched on it for censored RV sub-graphs.
    """


# Declare `CensoredRV` to be a `MeasurableVariable` via registration rather
# than inheritance.
MeasurableVariable.register(CensoredRV)
|
||
|
||
@local_optimizer(tracks=[Elemwise])
def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[CensoredRV]]:
    """Swap a value-bound ``clip`` of a measurable variable for a `CensoredRV`.

    The rewrite applies only when all of the following hold:

    * ``fgraph`` exposes a ``preserve_rv_mappings`` attribute;
    * ``node`` is an `Elemwise` whose scalar op is a `Clip` (and is not
      already a `CensoredRV`);
    * the clip output appears in ``preserve_rv_mappings.rv_values``;
    * the clipped input is produced by a `MeasurableVariable` op and does not
      itself appear in ``rv_values`` (a warning is emitted in that case).

    Returns
    -------
    A one-element list holding the output of the new `CensoredRV` node, or
    ``None`` when the rewrite does not apply.
    """
    # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return None  # pragma: no cover

    node_op = node.op
    if isinstance(node_op, CensoredRV):
        return None
    if not isinstance(node_op, Elemwise) or not isinstance(node_op.scalar_op, Clip):
        return None

    clipped_var = node.outputs[0]
    if clipped_var not in rv_map_feature.rv_values:
        return None

    base_var, lower_bound, upper_bound = node.inputs

    if not base_var.owner or not isinstance(base_var.owner.op, MeasurableVariable):
        return None

    if base_var in rv_map_feature.rv_values:
        warnings.warn(
            f"Value variables were assigned to both the input ({base_var}) and "
            f"output ({clipped_var}) of a censored random variable."
        )
        return None

    # `y = clip(x, x, ?)` and `y = clip(x, ?, x)` denote one-sided censoring.
    # The bound that is the base variable itself is swapped for an infinite
    # constant so that `censor_logprob` can produce a more succinct logprob
    # graph for one-sided censored random variables.
    if lower_bound is base_var:
        lower_bound = at.constant(-np.inf)
    if upper_bound is base_var:
        upper_bound = at.constant(np.inf)

    # The base variable is made unmeasurable before being fed to the new node.
    censored_rv = (
        CensoredRV(scalar_clip)
        .make_node(
            assign_custom_measurable_outputs(base_var.owner),
            lower_bound,
            upper_bound,
        )
        .outputs[0]
    )
    censored_rv.name = clipped_var.name

    return [censored_rv]


rv_sinking_db.register("find_censored_rvs", find_censored_rvs, -5, "basic")
|
||
|
||
@_logprob.register(CensoredRV)
def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
    """Build the log-probability graph of a censored (clipped) random variable.

    Parameters
    ----------
    op
        The `CensoredRV` op being dispatched on (not otherwise used).
    values
        A one-element sequence holding the value variable of the censored RV.
    base_rv
        The variable being censored; its owner's op and inputs are used to
        obtain the base log-probability and log-CDF.
    lower_bound, upper_bound
        Censoring bounds. A `TensorConstant` equal to ``-inf`` (respectively
        ``+inf``) everywhere means "no bound on that side".
    **kwargs
        Forwarded to the base `_logprob` and `_logcdf` implementations.

    Returns
    -------
    A graph that evaluates to

    * the base log-probability strictly inside the bounds,
    * the log of the censored tail mass at a bound, and
    * ``-inf`` outside the bounds.

    When both bounds are active, the result is wrapped in an `Assert` that
    checks ``lower_bound <= upper_bound``.
    """
    (value,) = values

    base_rv_op = base_rv.owner.op
    base_rv_inputs = base_rv.owner.inputs

    logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
    logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)

    if base_rv_op.name:
        logprob.name = f"{base_rv_op}_logprob"
        logcdf.name = f"{base_rv_op}_logcdf"

    is_lower_bounded, is_upper_bounded = False, False
    if not (
        isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
    ):
        is_upper_bounded = True

        # NOTE(review): log(1 - exp(logcdf)) loses precision when logcdf is
        # close to 0; a dedicated log1mexp op would be more stable if the
        # installed Aesara version provides one -- confirm before switching.
        logccdf = at.log(1 - at.exp(logcdf))
        # For right censored discrete RVs, we need to add an extra term
        # corresponding to the pmf at the upper bound.
        # Any integer dtype counts as discrete: comparing against "int64"
        # alone silently dropped this term for e.g. int32 or uint8 variables.
        if str(base_rv.dtype).startswith(("int", "uint")):
            logccdf = at.logaddexp(logccdf, logprob)

        logprob = at.switch(
            at.eq(value, upper_bound),
            logccdf,
            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
        )
    if not (
        isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
    ):
        is_lower_bounded = True
        # At the lower bound the censored mass is the base CDF at that point;
        # below the bound the value is impossible.
        logprob = at.switch(
            at.eq(value, lower_bound),
            logcdf,
            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
        )

    if is_lower_bounded and is_upper_bounded:
        logprob = Assert("lower_bound <= upper_bound")(
            logprob, at.all(at.le(lower_bound, upper_bound))
        )

    return logprob
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import aesara | ||
import aesara.tensor as at | ||
import numpy as np | ||
import pytest | ||
import scipy as sp | ||
import scipy.stats as st | ||
|
||
from aeppl import joint_logprob | ||
from aeppl.transforms import LogTransform, TransformValuesOpt | ||
from tests.utils import assert_no_rvs | ||
|
||
|
||
@aesara.config.change_flags(compute_test_value="raise")
def test_continuous_rv_censoring():
    """Two-sided clipping of a normal RV, compared against SciPy."""
    base_rv = at.random.normal(0.5, 1)
    clipped_rv = at.clip(base_rv, -2, 2)

    clipped_vv = clipped_rv.clone()
    clipped_vv.tag.test_value = 0

    logp = joint_logprob({clipped_rv: clipped_vv})
    assert_no_rvs(logp)

    logp_fn = aesara.function([clipped_vv], logp)
    scipy_dist = st.norm(0.5, 1)

    # Values outside the clipping interval are impossible
    for outside_value in (-3, 3):
        assert logp_fn(outside_value) == -np.inf

    # Lower bound -> CDF, upper bound -> survival function, interior -> pdf
    assert np.isclose(logp_fn(-2), scipy_dist.logcdf(-2))
    assert np.isclose(logp_fn(2), scipy_dist.logsf(2))
    assert np.isclose(logp_fn(0), scipy_dist.logpdf(0))
|
||
|
||
def test_discrete_rv_censoring():
    """Two-sided clipping of a Poisson RV, compared against SciPy."""
    base_rv = at.random.poisson(2)
    clipped_rv = at.clip(base_rv, 1, 4)

    clipped_vv = clipped_rv.clone()

    logp = joint_logprob({clipped_rv: clipped_vv})
    assert_no_rvs(logp)

    logp_fn = aesara.function([clipped_vv], logp)
    scipy_dist = st.poisson(2)

    # Values outside the clipping interval are impossible
    for outside_value in (0, 5):
        assert logp_fn(outside_value) == -np.inf

    assert np.isclose(logp_fn(1), scipy_dist.logcdf(1))
    # At the upper bound the pmf at that point is added to the survival mass
    upper_bound_logp = np.logaddexp(scipy_dist.logsf(4), scipy_dist.logpmf(4))
    assert np.isclose(logp_fn(4), upper_bound_logp)
    assert np.isclose(logp_fn(2), scipy_dist.logpmf(2))
|
||
|
||
def test_one_sided_censoring():
    """Lower-only and upper-only clipping of a standard normal RV."""
    base_rv = at.random.normal(0, 1)
    lower_clipped_rv = at.clip(base_rv, -1, base_rv)
    upper_clipped_rv = at.clip(base_rv, base_rv, 1)

    lower_clipped_vv = lower_clipped_rv.clone()
    upper_clipped_vv = upper_clipped_rv.clone()

    lower_logp = joint_logprob({lower_clipped_rv: lower_clipped_vv})
    upper_logp = joint_logprob({upper_clipped_rv: upper_clipped_vv})
    assert_no_rvs(lower_logp)
    assert_no_rvs(upper_logp)

    logp_fn = aesara.function(
        [lower_clipped_vv, upper_clipped_vv],
        [lower_logp, upper_logp],
    )
    scipy_dist = st.norm(0, 1)

    # Both values beyond their respective bound
    beyond_bounds = np.array(logp_fn(-2, 2))
    assert np.all(beyond_bounds == -np.inf)
    # Both values on the unbounded side
    unbounded_side = np.array(logp_fn(2, -2))
    assert np.all(unbounded_side != -np.inf)

    np.testing.assert_almost_equal(logp_fn(-1, 1), scipy_dist.logcdf(-1))
    np.testing.assert_almost_equal(logp_fn(1, -1), scipy_dist.logpdf(-1))
|
||
|
||
def test_useless_censoring():
    """Clipping an RV by itself on both sides yields the plain normal logpdf."""
    base_rv = at.random.normal(0.5, 1, size=3)
    clipped_rv = at.clip(base_rv, base_rv, base_rv)

    clipped_vv = clipped_rv.clone()

    logp = joint_logprob({clipped_rv: clipped_vv}, sum=False)
    assert_no_rvs(logp)

    logp_fn = aesara.function([clipped_vv], logp)
    scipy_dist = st.norm(0.5, 1)

    observed = logp_fn([-2, 0, 2])
    expected = scipy_dist.logpdf([-2, 0, 2])
    np.testing.assert_allclose(observed, expected)
|
||
|
||
def test_random_censoring():
    """Clipping with a random lower bound that has its own value variable."""
    lower_rv = at.random.normal(0, 1, size=2)
    base_rv = at.random.normal(0, 2)
    clipped_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_vv = lower_rv.clone()
    clipped_vv = clipped_rv.clone()
    logp = joint_logprob({clipped_rv: clipped_vv, lower_rv: lower_vv}, sum=False)
    assert_no_rvs(logp)

    logp_fn = aesara.function([lower_vv, clipped_vv], logp)
    elementwise_logp = logp_fn([0, -1], [-1, -1])
    # First entry: value -1 lies below its lower bound 0
    assert elementwise_logp[0] == -np.inf
    # Second entry: value -1 equals its lower bound -1
    assert elementwise_logp[1] != -np.inf
|
||
|
||
def test_broadcasted_censoring_constant():
    """A scalar RV clipped against a length-2 constant upper bound."""
    lower_rv = at.random.uniform(0, 1)
    base_rv = at.random.normal(0, 2)
    clipped_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_vv = lower_rv.clone()
    clipped_vv = clipped_rv.clone()

    value_map = {clipped_rv: clipped_vv, lower_rv: lower_vv}
    logp = joint_logprob(value_map)
    assert_no_rvs(logp)
|
||
|
||
def test_broadcasted_censoring_random():
    """A length-2 RV clipped against a scalar random lower bound."""
    lower_rv = at.random.normal(0, 1)
    base_rv = at.random.normal(0, 2, size=2)
    clipped_rv = at.clip(base_rv, lower_rv, 1)

    lower_vv = lower_rv.clone()
    clipped_vv = clipped_rv.clone()

    value_map = {clipped_rv: clipped_vv, lower_rv: lower_vv}
    logp = joint_logprob(value_map)
    assert_no_rvs(logp)
|
||
|
||
def test_fail_base_and_censored_have_values():
    """Giving value variables to both the base RV and its clipped version must fail."""
    base_rv = at.random.normal(0, 1)
    clipped_rv = at.clip(base_rv, base_rv, 1)
    clipped_rv.name = "cens_x"

    base_vv = base_rv.clone()
    clipped_vv = clipped_rv.clone()

    value_map = {clipped_rv: clipped_vv, base_rv: base_vv}
    with pytest.raises(RuntimeError, match="{cens_x}"):
        joint_logprob(value_map)
|
||
|
||
def test_fail_multiple_censored_single_base():
    """Two clipped variables built on the same base RV must fail."""
    shared_base_rv = at.random.normal(0, 1)

    first_clipped_rv = at.clip(shared_base_rv, -1, 1)
    first_clipped_rv.name = "cens1"
    second_clipped_rv = at.clip(shared_base_rv, -1, 1)
    second_clipped_rv.name = "cens2"

    first_clipped_vv = first_clipped_rv.clone()
    second_clipped_vv = second_clipped_rv.clone()

    value_map = {
        first_clipped_rv: first_clipped_vv,
        second_clipped_rv: second_clipped_vv,
    }
    with pytest.raises(RuntimeError, match="{cens2}"):
        joint_logprob(value_map)
|
||
|
||
def test_deterministic_clipping():
    """A clip with no value variable of its own acts as a plain deterministic."""
    base_rv = at.random.normal(0, 1)
    clipped = at.clip(base_rv, 0, 0)
    dependent_rv = at.random.normal(clipped, 1)

    base_vv = base_rv.clone()
    dependent_vv = dependent_rv.clone()
    logp = joint_logprob({base_rv: base_vv, dependent_rv: dependent_vv})
    assert_no_rvs(logp)

    logp_fn = aesara.function([base_vv, dependent_vv], logp)

    # The clip pins the second normal's mean to 0, so both terms are
    # standard-normal densities.
    expected_logp = st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1)
    assert np.isclose(logp_fn(-1, 1), expected_logp)
|
||
|
||
@pytest.mark.xfail(reason="Transform does not work with Elemwise ops, see #60")
def test_censored_transform():
    """Log-transformed value variable on a lower-clipped normal RV (expected to fail)."""
    base_rv = at.random.normal(0.5, 1)
    clipped_rv = at.clip(base_rv, 0, base_rv)

    clipped_vv = clipped_rv.clone()

    log_transform_rewrite = TransformValuesOpt({clipped_vv: LogTransform()})
    logp = joint_logprob(
        {clipped_rv: clipped_vv},
        extra_rewrites=log_transform_rewrite,
    )

    test_point = -1
    observed_logp = logp.eval({clipped_vv: test_point})
    # Base density at the back-transformed point plus the test point itself
    # (the additive term of the log transform).
    base_logpdf = sp.stats.norm(0.5, 1).logpdf(np.exp(test_point))
    expected_logp = base_logpdf + test_point

    assert np.isclose(observed_logp, expected_logp)