Implement RV censoring logprob and opt
Ricardo committed on Jun 22, 2021 · 1 parent e3930a2 · commit e16ef75
Showing 3 changed files with 141 additions and 2 deletions.
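In outline: the new module defines a `CensoredRV` op that carries a reference to its uncensored `base_op`, registers a `_logprob` implementation for it, and adds a `local_optimizer` that rewrites `at.clip(rv, lower, upper)` graphs into `CensoredRV` nodes; a new test file checks the resulting log-probabilities against scipy references.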
New file: @@ -0,0 +1,95 @@
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable

from aeppl.logprob import _logcdf, _logprob
# TODO: Add interval transform
class CensoredRV(RandomVariable):
    r"""A base class for censored `RandomVariable`\s."""

    def __init__(self, *args, base_op, **kwargs):
        # `base_op` is the uncensored `RandomVariable` whose logprob/logcdf
        # implementations are reused below
        self.base_op = base_op
        super().__init__(*args, **kwargs)
@_logprob.register(CensoredRV)
def censor_logprob(op, value, *inputs, name=None, **kwargs):

    *rv_params, lower_bound, upper_bound = inputs
    logprob = _logprob(op.base_op, value, *rv_params, name=name, **kwargs)
    logcdf = _logcdf(op.base_op, value, *rv_params, name=name, **kwargs)

    # TODO: Check constant -inf
    # TODO: Exact bound check might be problematic
    # Point mass P(X <= lb) at the lower bound, P(X >= ub) at the upper
    # bound, and the base density strictly in between
    res = at.switch(
        at.eq(value, lower_bound),
        logcdf,
        at.switch(
            at.eq(value, upper_bound),
            at.log(1 - at.exp(logcdf)),
            logprob,
        ),
    )
    # Values outside the bounds are impossible under censoring
    res = at.switch(
        at.or_(at.lt(value, lower_bound), at.gt(value, upper_bound)),
        -np.inf,
        res,
    )

    res = Assert("lower_bound <= upper_bound")(
        res, at.all(at.le(lower_bound, upper_bound))
    )

    if name:
        res.name = f"{name}_censored_logprob"

    return res
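For reference, the nested `at.switch` above encodes the following piecewise log-density, where $F$ and $f$ are the CDF and density of the base variable and $lb \le ub$ are the clipping bounds:

$$
\log p(y) =
\begin{cases}
\log F(lb) & \text{if } y = lb \\
\log\left(1 - F(ub)\right) & \text{if } y = ub \\
\log f(y) & \text{if } lb < y < ub \\
-\infty & \text{otherwise}
\end{cases}
$$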
@local_optimizer(tracks=[Elemwise])
def censor_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:

    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return

    # TODO: Allow for one-sided truncation with set_subtensor graph (x[x > ub] = ub)
    if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, Clip):
        return

    base_var, lower_bound, upper_bound = node.inputs

    # Guard against clipped variables with no owner (e.g. graph inputs)
    if base_var.owner is None:
        return

    base_op = base_var.owner.op

    if not isinstance(base_op, RandomVariable):
        return

    # clipped_var = node.outputs[0]
    # clipped_value = rv_map_feature.rv_values.pop(clipped_var)

    # The two bounds are appended as extra scalar (ndim-0) parameters
    censored_rv = CensoredRV(
        "censored",
        base_op.ndim_supp,
        list(base_op.ndims_params) + [0, 0],
        base_op.dtype,
        inplace=False,
        base_op=base_op,
    )
    censored_node = censored_rv.make_node(
        *base_var.owner.inputs,
        lower_bound,
        upper_bound,
    )

    # `make_node` returns (rng, rv); the censored variable is the second output
    # assert censored_node.outputs[1].type == clipped_var.type
    return [censored_node.outputs[1]]
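A minimal sketch of the graph shape this optimizer matches (assuming only that `aesara` itself is importable; the variable names are illustrative): `at.clip` builds an `Elemwise` node whose scalar op is `Clip`, with the base variable and both bounds as its three inputs.

import aesara.tensor as at
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise

x_rv = at.random.normal(0, 1)   # a RandomVariable output
clipped = at.clip(x_rv, -1, 1)  # the kind of graph `censor_rvs` rewrites

node = clipped.owner
assert isinstance(node.op, Elemwise)
assert isinstance(node.op.scalar_op, Clip)
base_var, lower, upper = node.inputs
assert base_var is x_rv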
New test file: @@ -0,0 +1,42 @@
import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats as st

from aeppl import joint_logprob
def test_uniform_censoring():
    x_rv = at.random.uniform(-4, 5)
    cens_x_rv = at.clip(x_rv, -1, 1)
    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    logp_fn = aesara.function([cens_x], logp)

    # scipy.stats.uniform is parametrized as (loc, scale), so Uniform(-4, 5)
    # corresponds to st.uniform(-4, 9)
    ref_scipy = st.uniform(-4, 9)

    # Values outside the clipping bounds are impossible
    assert logp_fn(-5) == -np.inf
    assert logp_fn(6) == -np.inf

    # Point masses at the clipping bounds, base density in between
    assert np.isclose(logp_fn(-1), ref_scipy.logcdf(-1))
    assert np.isclose(logp_fn(1), ref_scipy.logsf(1))
    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
def test_normal_censoring():
    x_rv = at.random.normal(0.5, 1)
    cens_x_rv = at.clip(x_rv, -2, 2)
    cens_x = cens_x_rv.type()

    logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x})
    logp_fn = aesara.function([cens_x], logp)

    ref_scipy = st.norm(0.5, 1)

    # Values outside the clipping bounds are impossible
    assert logp_fn(-3) == -np.inf
    assert logp_fn(3) == -np.inf

    # Point masses at the clipping bounds, base density in between
    assert np.isclose(logp_fn(-2), ref_scipy.logcdf(-2))
    assert np.isclose(logp_fn(2), ref_scipy.logsf(2))
    assert np.isclose(logp_fn(0), ref_scipy.logpdf(0))
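As a quick numerical cross-check of the piecewise density these tests exercise (a standalone sketch, independent of aeppl): for X ~ Uniform(-4, 5) censored to [-1, 1], the two point masses and the interior mass must sum to one.

import numpy as np
import scipy.stats as st

d = st.uniform(-4, 9)  # scipy's (loc, scale) form of Uniform(-4, 5)

p_lower = d.cdf(-1)                # P(Y = -1) = P(X <= -1) = 3/9
p_upper = d.sf(1)                  # P(Y = 1)  = P(X >= 1)  = 4/9
p_interior = d.cdf(1) - d.cdf(-1)  # mass left to the density = 2/9

assert np.isclose(p_lower + p_upper + p_interior, 1.0)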