-
-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement RV censoring logprob and opt
- Loading branch information
Ricardo
committed
Jun 29, 2021
1 parent
e3930a2
commit ca27863
Showing
3 changed files
with
265 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from typing import List, Optional | ||
|
||
import aesara.tensor as at | ||
import numpy as np | ||
from aesara.assert_op import Assert | ||
from aesara.graph.basic import Node | ||
from aesara.graph.fg import FunctionGraph | ||
from aesara.graph.opt import local_optimizer | ||
from aesara.scalar.basic import Clip | ||
from aesara.tensor.elemwise import Elemwise | ||
from aesara.tensor.random.op import RandomVariable | ||
from aesara.tensor.var import TensorConstant | ||
|
||
from aeppl.logprob import _logcdf, _logprob | ||
|
||
|
||
# TODO: Add interval transform | ||
class CensoredRV(RandomVariable): | ||
r"""A base class for censored `RandomVariable`\s.""" | ||
|
||
def __init__(self, *args, base_op, **kwargs): | ||
self.base_op = base_op | ||
super().__init__(*args, **kwargs) | ||
|
||
|
||
@_logprob.register(CensoredRV) | ||
def censor_logprob(op, value, *inputs, name=None, **kwargs): | ||
# TODO: Exact bound check might be problematic | ||
|
||
*rv_params, lower_bound, upper_bound = inputs | ||
logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs) | ||
logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs) | ||
|
||
is_lower_bounded, is_upper_bounded = False, False | ||
if not ( | ||
isinstance(upper_bound, TensorConstant) and np.all(upper_bound.value == np.inf) | ||
): | ||
is_upper_bounded = True | ||
logprob = at.switch( | ||
at.eq(value, upper_bound), | ||
at.log(1 - at.exp(logcdf)), | ||
at.switch(at.gt(value, upper_bound), -np.inf, logprob), | ||
) | ||
|
||
# Lower censored logp | ||
if not ( | ||
isinstance(lower_bound, TensorConstant) and np.all(lower_bound.value == -np.inf) | ||
): | ||
is_lower_bounded = True | ||
logprob = at.switch( | ||
at.eq(value, lower_bound), | ||
logcdf, | ||
at.switch(at.lt(value, lower_bound), -np.inf, logprob), | ||
) | ||
|
||
if is_lower_bounded and is_upper_bounded: | ||
logprob = Assert("lower_bound <= upper_bound")( | ||
logprob, at.all(at.le(lower_bound, upper_bound)) | ||
) | ||
|
||
if name: | ||
logprob.name = f"{name}_censored_logprob" | ||
|
||
return logprob | ||
|
||
|
||
@local_optimizer(tracks=[Elemwise]) | ||
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: | ||
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub) | ||
|
||
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) | ||
|
||
if rv_map_feature is None: | ||
return | ||
|
||
if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip): | ||
return | ||
|
||
clipped_var = node.outputs[0] | ||
if clipped_var not in rv_map_feature.rv_values: | ||
return | ||
|
||
base_var, lower_bound, upper_bound = node.inputs | ||
base_op = base_var.owner.op | ||
|
||
if not isinstance(base_op, RandomVariable): | ||
return | ||
|
||
# base_var var already has a direct value var | ||
# TODO: Raise informative Error | ||
if base_var in rv_map_feature.rv_values: | ||
return | ||
|
||
is_lower_bounded, is_upper_bounded = False, False | ||
if lower_bound is not base_var: | ||
# y = clip(x, x, ?) | ||
is_lower_bounded = True | ||
|
||
if upper_bound is not base_var: | ||
# y = clip(x, ?, x) | ||
is_upper_bounded = True | ||
|
||
censored_rv = CensoredRV( | ||
"censored", | ||
base_op.ndim_supp, | ||
list(base_op.ndims_params) + [0, 0], | ||
base_op.dtype, | ||
inplace=False, | ||
base_op=base_op, | ||
) | ||
|
||
lower_bound = lower_bound if is_lower_bounded else -np.inf | ||
upper_bound = upper_bound if is_upper_bounded else np.inf | ||
# _, lower_bound, upper_bound = at.broadcast_arrays(clipped_var, lower_bound, upper_bound) | ||
|
||
censored_node = censored_rv.make_node( | ||
*base_var.owner.inputs, | ||
lower_bound, | ||
upper_bound, | ||
) | ||
|
||
censored_node_out = censored_node.outputs[1] | ||
censored_node_out.name = clipped_var.name | ||
|
||
if not censored_node_out.type == clipped_var.type: | ||
return | ||
|
||
return [censored_node_out] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import aesara | ||
import aesara.tensor as at | ||
import numpy as np | ||
import pytest | ||
import scipy.stats as st | ||
|
||
from aeppl import joint_logprob | ||
from tests.utils import assert_no_rvs | ||
|
||
|
||
def test_uniform_censoring(): | ||
x_rv = at.random.uniform(-4, 5) | ||
cens_x_rv = at.clip(x_rv, -1, 1) | ||
|
||
cens_x = cens_x_rv.type() | ||
|
||
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x}) | ||
assert_no_rvs(logp) | ||
|
||
logp_fn = aesara.function([cens_x], logp) | ||
ref_scipy = st.uniform(-4, 9) | ||
|
||
assert logp_fn(-5) == -np.inf | ||
assert logp_fn(6) == -np.inf | ||
|
||
assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1)) | ||
assert np.isclose(logp_fn(5), ref_scipy.logsf(5)) | ||
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0)) | ||
|
||
|
||
def test_normal_censoring(): | ||
x_rv = at.random.normal(0.5, 1) | ||
cens_x_rv = at.clip(x_rv, -2, 2) | ||
|
||
cens_x = cens_x_rv.type() | ||
|
||
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x}) | ||
assert_no_rvs(logp) | ||
|
||
logp_fn = aesara.function([cens_x], logp) | ||
ref_scipy = st.norm(0.5, 1) | ||
|
||
assert logp_fn(-3) == -np.inf | ||
assert logp_fn(3) == -np.inf | ||
|
||
assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2)) | ||
assert np.isclose(logp_fn(2), ref_scipy.logsf(2)) | ||
assert np.isclose(logp_fn(0), ref_scipy.logpdf(0)) | ||
|
||
|
||
def test_one_sided_censoring(): | ||
|
||
x_rv = at.random.normal(0, 1) | ||
lb_cens_x_rv = at.clip(x_rv, -1, x_rv) | ||
ub_cens_x_rv = at.clip(x_rv, x_rv, 1) | ||
|
||
lb_cens_x = lb_cens_x_rv.type() | ||
ub_cens_x = ub_cens_x_rv.type() | ||
|
||
lb_logp = joint_logprob(lb_cens_x_rv, {lb_cens_x_rv: lb_cens_x}) | ||
ub_logp = joint_logprob(ub_cens_x_rv, {ub_cens_x_rv: ub_cens_x}) | ||
assert_no_rvs(lb_logp) | ||
assert_no_rvs(ub_logp) | ||
|
||
logp_fn = aesara.function([lb_cens_x, ub_cens_x], [lb_logp, ub_logp]) | ||
ref_scipy = st.norm(0, 1) | ||
|
||
assert np.all(np.array(logp_fn(-2, 2)) == -np.inf) | ||
assert np.all(np.array(logp_fn(2, -2)) != -np.inf) | ||
np.testing.assert_almost_equal(logp_fn(-1, 1), ref_scipy.logcdf(-1)) | ||
np.testing.assert_almost_equal(logp_fn(1, -1), ref_scipy.logpdf(-1)) | ||
|
||
|
||
def test_failed_censoring(): | ||
|
||
x_rv = at.random.normal(0, 1) | ||
cens_x_rv = at.clip(x_rv, x_rv, x_rv) | ||
|
||
x = x_rv.type() | ||
cens_x = cens_x_rv.type() | ||
with pytest.raises(NotImplementedError): | ||
joint_logprob(cens_x_rv, {cens_x_rv: cens_x, x_rv: x}) | ||
|
||
|
||
def test_random_censoring(): | ||
|
||
lb_rv = at.random.normal(0, 1, size=2) | ||
x_rv = at.random.normal(0, 2) | ||
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1]) | ||
|
||
lb = lb_rv.type() | ||
cens_x = cens_x_rv.type() | ||
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb}) | ||
assert_no_rvs(logp) | ||
|
||
logp_fn = aesara.function([lb, cens_x], logp) | ||
res = logp_fn([0, -1], [-1, -1]) | ||
assert res[0] == -np.inf | ||
assert res[1] != -np.inf | ||
|
||
|
||
@pytest.mark.xfail(reason="Broadcasting not properly handled yet") | ||
def test_broadcasted_censoring(): | ||
|
||
lb_rv = at.random.uniform(0, 1, name="lb_rv") | ||
x_rv = at.random.normal(0, 2, name="x_rv") | ||
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1]) | ||
cens_x_rv.name = "cens_x_rv" | ||
|
||
lb = lb_rv.type() | ||
lb.name = "lb" | ||
cens_x = cens_x_rv.type() | ||
cens_x.name = "cens_x" | ||
|
||
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb}) | ||
assert_no_rvs(logp) | ||
|
||
|
||
@pytest.mark.xfail(reason="Broadcasting not properly handled yet") | ||
def test_broadcasted_censoring2(): | ||
|
||
lb_rv = at.random.normal(0, 1, name="lb_rv") | ||
x_rv = at.random.normal(0, 2, size=2, name="x_rv") | ||
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1]) | ||
cens_x_rv.name = "cens_x_rv" | ||
|
||
lb = lb_rv.type() | ||
lb.name = "lb" | ||
cens_x = cens_x_rv.type() | ||
cens_x.name = "cens_x" | ||
|
||
logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb}) | ||
assert_no_rvs(logp) |