-
-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement RV censoring logprob and opt
- Loading branch information
Ricardo
committed
Jun 29, 2021
1 parent
e3930a2
commit 9dd9052
Showing
3 changed files
with
291 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import warnings | ||
from typing import List, Optional | ||
|
||
import aesara.tensor as at | ||
import numpy as np | ||
from aesara.assert_op import Assert | ||
from aesara.graph.basic import Node | ||
from aesara.graph.fg import FunctionGraph | ||
from aesara.graph.opt import local_optimizer | ||
from aesara.scalar.basic import Clip | ||
from aesara.tensor.elemwise import Elemwise | ||
from aesara.tensor.random.op import RandomVariable | ||
from aesara.tensor.var import TensorConstant | ||
|
||
from aeppl.logprob import _logcdf, _logprob | ||
|
||
|
||
# TODO: Add interval transform | ||
class CensoredRV(RandomVariable):
    r"""A `RandomVariable` `Op` standing for a censored (clipped) variable.

    The `Op` of the uncensored variable is kept on the instance as
    ``base_op`` so that it can be looked up later from the censored `Op`.
    """

    def __init__(self, *args, base_op, **kwargs):
        """Initialize the censored `Op`.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``RandomVariable.__init__``.
        base_op
            Keyword-only. The `Op` of the variable being censored; stored
            as ``self.base_op``.
        """
        super().__init__(*args, **kwargs)
        self.base_op = base_op
|
||
|
||
@_logprob.register(CensoredRV)
def censor_logprob(op, value, *inputs, name=None, **kwargs):
    """Build the log-probability graph of a censored random variable.

    Parameters
    ----------
    op
        The `CensoredRV` `Op`; its ``base_op`` attribute is used to dispatch
        ``_logprob`` and ``_logcdf`` for the uncensored variable.
    value
        The value variable at which the log-probability is evaluated.
    *inputs
        The parameters forwarded to the base `Op`'s log-probability,
        followed by the lower and upper censoring bounds (the last two
        entries).
    name, **kwargs
        Forwarded to ``_logprob`` and ``_logcdf``.

    Returns
    -------
    A graph that evaluates to:

    * the base log-density strictly inside the bounds,
    * ``logcdf`` at the lower bound (``P(X <= lb)``),
    * ``log P(X >= ub)`` at the upper bound,
    * ``-inf`` outside the bounds.

    A bound given as a constant ``-inf`` / ``+inf`` is treated as absent and
    adds no ``switch`` to the graph.  When both bounds are present the result
    is wrapped in an ``Assert`` that ``lower_bound <= upper_bound``.
    """
    *rv_params, lower_bound, upper_bound = inputs
    base_op = op.base_op

    logprob = _logprob(base_op, value, *rv_params, name=name, **kwargs)
    logcdf = _logcdf(base_op, value, *rv_params, name=name, **kwargs)
    if base_op.name:
        logprob.name = f"{base_op.name}_logprob"
        logcdf.name = f"{base_op.name}_logcdf"

    # Keep a handle on the uncensored log-density; `logprob` is rebound below.
    base_logprob = logprob

    is_lower_bounded, is_upper_bounded = False, False
    if not (
        isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf)
    ):
        is_upper_bounded = True

        # Mass collected at the upper bound is P(X >= ub).  `1 - CDF(ub)` is
        # only P(X > ub); for integer-valued variables the point mass
        # P(X == ub) is non-zero and has to be added back.
        upper_mass = 1 - at.exp(logcdf)
        if str(getattr(base_op, "dtype", "")).startswith(("int", "uint")):
            upper_mass = upper_mass + at.exp(base_logprob)

        logprob = at.switch(
            at.eq(value, upper_bound),
            at.log(upper_mass),
            at.switch(at.gt(value, upper_bound), -np.inf, logprob),
        )
    if not (
        isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf)
    ):
        is_lower_bounded = True
        # Mass collected at the lower bound is P(X <= lb) = CDF(lb), which is
        # correct for both continuous and discrete variables.
        logprob = at.switch(
            at.eq(value, lower_bound),
            logcdf,
            at.switch(at.lt(value, lower_bound), -np.inf, logprob),
        )

    if is_lower_bounded and is_upper_bounded:
        logprob = Assert("lower_bound <= upper_bound")(
            logprob, at.all(at.le(lower_bound, upper_bound))
        )

    return logprob
|
||
|
||
@local_optimizer(tracks=[Elemwise])
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
    """Replace ``clip(rv, lower, upper)`` with a `CensoredRV` output.

    The rewrite only applies when all of the following hold:

    * ``fgraph`` has a ``preserve_rv_mappings`` feature,
    * ``node`` is an `Elemwise` whose scalar `Op` is `Clip`,
    * the clipped output has a value variable in ``rv_values``,
    * the clipped input is produced by a `RandomVariable` `Op`,
    * the clipped input does not itself have a value variable (a warning is
      issued in that case),
    * the replacement output has the same type as the clipped output.

    Parameters
    ----------
    fgraph
        The graph being rewritten.
    node
        The candidate node.

    Returns
    -------
    A one-element list holding the replacement for the clipped output, or
    ``None`` when the rewrite does not apply.
    """
    # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None:
        return None

    if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
        return None

    clipped_var = node.outputs[0]
    if clipped_var not in rv_map_feature.rv_values:
        return None

    base_var, lower_bound, upper_bound = node.inputs

    # The clipped input may be a graph input or a constant, in which case it
    # has no owner and there is no `Op` to inspect.
    if base_var.owner is None:
        return None

    base_op = base_var.owner.op

    if not isinstance(base_op, RandomVariable):
        return None

    if base_var in rv_map_feature.rv_values:
        warnings.warn(
            f"Value variables were assigned to both the input ({base_var}) and "
            f"output ({clipped_var}) of a censored random variable."
        )
        return None

    censored_rv = CensoredRV(
        "censored",
        base_op.ndim_supp,
        list(base_op.ndims_params) + [base_op.ndim_supp] * 2,
        base_op.dtype,
        inplace=False,
        base_op=base_op,
    )

    # Replace bounds by +-inf if `y = clip(x, x, ?)` or `y=clip(x, ?, x)`
    # This is used in `censor_logprob` to generate a more succinct logprob graph
    # for one-sided censored random variables
    # TODO: This will probably fail for multivariate variables
    lower_bound = lower_bound if (lower_bound is not base_var) else -np.inf
    upper_bound = upper_bound if (upper_bound is not base_var) else np.inf

    # The new node takes the base node's inputs followed by the two bounds.
    censored_node = censored_rv.make_node(
        *base_var.owner.inputs,
        lower_bound,
        upper_bound,
    )

    # NOTE(review): output 1 is presumably the draw (output 0 being the RNG
    # state) -- confirm against `RandomVariable.make_node`.
    censored_node_out = censored_node.outputs[1]

    if not censored_node_out.type == clipped_var.type:
        # TODO: issue warning?
        return None

    if clipped_var.name:
        censored_node_out.name = clipped_var.name
    elif base_var.name:
        censored_node_out.name = f"{base_var.name}_censored"

    return [censored_node_out]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import aesara | ||
import aesara.tensor as at | ||
import numpy as np | ||
import pytest | ||
import scipy.stats as st | ||
|
||
from aeppl import joint_logprob | ||
from tests.utils import assert_no_rvs | ||
|
||
|
||
def test_uniform_censoring():
    """Two-sided censoring of a uniform variable at constant bounds.

    The base variable is uniform on [-4, 5] and is clipped to [-1, 1].
    """
    x_rv = at.random.uniform(-4, 5)
    cens_x_rv = at.clip(x_rv, -1, 1)
    cens_x_rv.name = "cens_x_rv"

    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    assert_no_rvs(logp)

    logp_fn = aesara.function([cens_x], logp)
    # scipy parametrizes the uniform as (loc, scale): [-4, -4 + 9] = [-4, 5]
    ref_scipy = st.uniform(-4, 9)

    # Values outside the censoring interval are impossible
    assert logp_fn(-5) == -np.inf
    assert logp_fn(6) == -np.inf

    assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1))
    # Evaluate at the upper censoring bound (1).  Evaluating at 5 would compare
    # -inf with -inf and never exercise the censored upper mass.
    assert np.isclose(logp_fn(1), ref_scipy.logsf(1))
    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
|
||
|
||
def test_normal_censoring():
    """Two-sided censoring of a normal variable at constant bounds."""
    base_rv = at.random.normal(0.5, 1, name="x_rv")
    clipped_rv = at.clip(base_rv, -2, 2)
    clipped_vv = clipped_rv.type()

    logp = joint_logprob(clipped_rv, {clipped_rv: clipped_vv})
    assert_no_rvs(logp)

    eval_logp = aesara.function([clipped_vv], logp)
    expected = st.norm(0.5, 1)

    # Values outside the censoring interval are impossible
    for outside in (-3, 3):
        assert eval_logp(outside) == -np.inf

    # At the bounds the censored mass is used, inside them the plain density
    assert np.isclose(eval_logp(-2), expected.logcdf(-2))
    assert np.isclose(eval_logp(2), expected.logsf(2))
    assert np.isclose(eval_logp(0), expected.logpdf(0))
|
||
|
||
def test_one_sided_censoring():
    """Censoring on one side only, written as ``clip(x, lb, x)`` / ``clip(x, x, ub)``."""
    base_rv = at.random.normal(0, 1)
    lower_clipped_rv = at.clip(base_rv, -1, base_rv)
    upper_clipped_rv = at.clip(base_rv, base_rv, 1)

    lower_vv = lower_clipped_rv.type()
    upper_vv = upper_clipped_rv.type()

    lower_logp = joint_logprob(lower_clipped_rv, {lower_clipped_rv: lower_vv})
    upper_logp = joint_logprob(upper_clipped_rv, {upper_clipped_rv: upper_vv})
    for graph in (lower_logp, upper_logp):
        assert_no_rvs(graph)

    eval_logps = aesara.function([lower_vv, upper_vv], [lower_logp, upper_logp])
    expected = st.norm(0, 1)

    # Beyond the censored side both are impossible; on the open side neither is
    beyond_bounds = np.array(eval_logps(-2, 2))
    open_side = np.array(eval_logps(2, -2))
    assert np.all(beyond_bounds == -np.inf)
    assert np.all(open_side != -np.inf)

    # The expected values coincide for both variables by symmetry of N(0, 1)
    np.testing.assert_almost_equal(eval_logps(-1, 1), expected.logcdf(-1))
    np.testing.assert_almost_equal(eval_logps(1, -1), expected.logpdf(-1))
|
||
|
||
def test_useless_censoring():
    """``clip(x, x, x)`` must yield the uncensored log-density."""
    base_rv = at.random.normal(0.5, 1, size=3)
    clipped_rv = at.clip(base_rv, base_rv, base_rv)
    clipped_vv = clipped_rv.type()

    logp = joint_logprob(clipped_rv, {clipped_rv: clipped_vv})
    assert_no_rvs(logp)

    eval_logp = aesara.function([clipped_vv], logp)
    expected = st.norm(0.5, 1)

    actual = eval_logp([-2, 0, 2])
    desired = expected.logpdf([-2, 0, 2])
    np.testing.assert_allclose(actual, desired)
|
||
|
||
def test_random_censoring():
    """Censoring where the lower bound is itself a random variable."""
    lower_rv = at.random.normal(0, 1, size=2)
    base_rv = at.random.normal(0, 2)
    clipped_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_vv = lower_rv.type()
    clipped_vv = clipped_rv.type()
    logp = joint_logprob(
        clipped_rv,
        {clipped_rv: clipped_vv, lower_rv: lower_vv},
    )
    assert_no_rvs(logp)

    eval_logp = aesara.function([lower_vv, clipped_vv], logp)
    first, second = eval_logp([0, -1], [-1, -1])[:2]
    # First entry: value (-1) lies below its lower bound (0)
    assert first == -np.inf
    # Second entry: value (-1) sits exactly on its lower bound (-1)
    assert second != -np.inf
|
||
|
||
@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring_constant():
    """Scalar variable and scalar random lower bound, vector constant upper bound."""
    lower_rv = at.random.uniform(0, 1, name="lb_rv")
    base_rv = at.random.normal(0, 2, name="x_rv")
    clipped_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_vv = lower_rv.type()
    clipped_vv = clipped_rv.type()
    lower_vv.name = "lb"
    clipped_vv.name = "cens_x"

    value_map = {clipped_rv: clipped_vv, lower_rv: lower_vv}
    logp = joint_logprob(clipped_rv, value_map)
    assert_no_rvs(logp)
|
||
|
||
@pytest.mark.xfail(reason="Broadcasting not properly handled yet")
def test_broadcasted_censoring_random():
    """Vector variable censored from below by a scalar random bound."""
    lower_rv = at.random.normal(0, 1, name="lb_rv")
    base_rv = at.random.normal(0, 2, size=2, name="x_rv")
    clipped_rv = at.clip(base_rv, lower_rv, [1, 1])

    lower_vv = lower_rv.type()
    clipped_vv = clipped_rv.type()
    lower_vv.name = "lb"
    clipped_vv.name = "cens_x"

    value_map = {clipped_rv: clipped_vv, lower_rv: lower_vv}
    logp = joint_logprob(clipped_rv, value_map)
    assert_no_rvs(logp)
|
||
|
||
def test_failed_censoring():
    """`joint_logprob` must fail when both the base and the clipped variable are valued."""
    base_rv = at.random.normal(0, 1)
    clipped_rv = at.clip(base_rv, base_rv, 1)

    base_vv = base_rv.type()
    clipped_vv = clipped_rv.type()
    value_map = {clipped_rv: clipped_vv, base_rv: base_vv}

    with pytest.raises(NotImplementedError):
        joint_logprob(clipped_rv, value_map)
|
||
|
||
def test_deterministic_clipping():
    """A clip without its own value variable acts as a plain deterministic."""
    x_rv = at.random.normal(0, 1)
    # clip(x, 0, 0) is always 0, so `y_rv` is effectively N(0, 1)
    clipped = at.clip(x_rv, 0, 0)
    y_rv = at.random.normal(clipped, 1)

    x_vv = x_rv.type()
    y_vv = y_rv.type()
    logp = joint_logprob(y_rv, {x_rv: x_vv, y_rv: y_vv})
    assert_no_rvs(logp)

    eval_logp = aesara.function([x_vv, y_vv], logp)
    std_normal = st.norm(0, 1)
    expected = std_normal.logpdf(-1) + std_normal.logpdf(1)
    assert np.isclose(eval_logp(-1, 1), expected)