Skip to content

Commit

Permalink
Investigate using OpFromGraph for CensoredRVs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Sep 9, 2021
1 parent 1a46bb9 commit 3600f40
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 40 deletions.
68 changes: 28 additions & 40 deletions aeppl/truncation.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,50 @@
import warnings
from typing import List, Optional, Type
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.assert_op import Assert
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import local_optimizer
from aesara.graph.utils import MetaType
from aesara.scalar.basic import Clip
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from aeppl.abstract import MeasurableVariable
from aeppl.logprob import _logcdf, _logprob


class CensoredRVMeta(MetaType):
pass
class CensoredRV(OpFromGraph):
"""A placeholder used to specify a log-likelihood for a censored RV sub-graph."""

@classmethod
def create_node(cls, node, base_rv, lb, ub):
out_var = node.default_output()
inputs = [base_rv, lb, ub]

class CensoredRV(RandomVariable, metaclass=CensoredRVMeta):
r"""A base class for censored `RandomVariable`\s."""

def __init__(self):
super().__init__(
"censored",
self.base_op.ndim_supp,
list(self.base_op.ndims_params) + [self.base_op.ndim_supp] * 2,
self.base_op.dtype,
inplace=False,
censored_op = cls(
inputs,
[out_var],
inline=True,
on_unused_input="ignore",
)

op_name = base_rv.owner.op.name
if op_name:
censored_op.name = f"censored_{op_name}"

def _create_censored_rv_op(rv_op: Op) -> Type[CensoredRV]:
"""Create a new `CensoredRV` given a base `RandomVariable` `Op`
Parameters
==========
rv_op
The `RandomVariable` for which we want to construct a `TransformedRV`.
"""
# new_node = mixture_op.make_node(None, None, None, *inputs)
new_node = censored_op(*inputs)
return new_node.owner

rv_type_name = type(rv_op).__name__
cls_dict = type(rv_op).__dict__.copy()
rv_name = cls_dict.get("name", "")
if rv_name:
cls_dict["name"] = f"Censored{rv_name}"
cls_dict["base_op"] = rv_op
def get_non_shared_inputs(self, inputs):
return inputs[: -len(self.shared_inputs)]

new_op_type = type(f"censored_{rv_type_name}", (CensoredRV,), cls_dict)

return new_op_type
MeasurableVariable.register(CensoredRV)


@local_optimizer(tracks=[Elemwise])
Expand Down Expand Up @@ -93,13 +85,8 @@ def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Censor
lower_bound = lower_bound if (lower_bound is not base_var) else at.constant(-np.inf)
upper_bound = upper_bound if (upper_bound is not base_var) else at.constant(np.inf)

censored_rv = _create_censored_rv_op(base_var.owner.op)()
censored_node = censored_rv.make_node(
*base_var.owner.inputs,
lower_bound,
upper_bound,
)
censored_rv = censored_node.outputs[1]
censored_rv_node = CensoredRV.create_node(node, base_var, lower_bound, upper_bound)
censored_rv = censored_rv_node.default_output()

if clipped_var.name:
censored_rv.name = clipped_var.name
Expand All @@ -115,8 +102,9 @@ def find_censored_rvs(fgraph: FunctionGraph, node: Node) -> Optional[List[Censor
def censor_logprob(op: CensoredRV, values, *inputs, **kwargs):
(value,) = values

*base_rv_inputs, lower_bound, upper_bound = inputs
base_rv_op = op.base_op
base_rv, lower_bound, upper_bound = op.get_non_shared_inputs(inputs)
base_rv_op = base_rv.owner.op
base_rv_inputs = base_rv.owner.inputs

logprob = _logprob(base_rv_op, (value,), *base_rv_inputs, **kwargs)
logcdf = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def test_TransformValuesMapping():
assert fg._features[-1] is tvm


@pytest.mark.xfail(reason="Transform does not work with OpFromGraph")
def test_censored_transform():
x_rv = at.random.normal(0.5, 1, name="x_rv")
cens_x_rv = at.clip(x_rv, 0, x_rv)
Expand Down
1 change: 1 addition & 0 deletions tests/test_truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_deterministic_clipping():
)


@pytest.mark.xfail(reason="unclear")
@aesara.config.change_flags(compute_test_value="raise")
def test_censored_test_value():
x_rv = at.random.normal(0, 1)
Expand Down

0 comments on commit 3600f40

Please sign in to comment.