Skip to content

Commit

Permalink
Track valued/bound variables using an in-graph Op
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 25, 2022
1 parent 038adf5 commit 0b28a51
Show file tree
Hide file tree
Showing 19 changed files with 558 additions and 620 deletions.
2 changes: 1 addition & 1 deletion aeppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from aeppl.logprob import logprob # isort: split

from aeppl.joint_logprob import conditional_logprob, joint_logprob
from aeppl.joint_logprob import DensityNotFound, conditional_logprob, joint_logprob
from aeppl.printing import latex_pprint, pprint

# isort: off
Expand Down
60 changes: 59 additions & 1 deletion aeppl/abstract.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import abc
from copy import copy
from functools import singledispatch
from typing import Callable, List
from typing import TYPE_CHECKING, Callable, List

import aesara.tensor as at
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
from aesara.graph.utils import MetaType
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.type import TensorType

if TYPE_CHECKING:
pass


class MeasurableVariable(abc.ABC):
Expand Down Expand Up @@ -124,3 +130,55 @@ class MeasurableElemwise(Elemwise):


MeasurableVariable.register(MeasurableElemwise)


class ValuedVariable(Op):
    r"""Tie a measurable variable to the value assigned to it.

    A `ValuedVariable` node represents the pair :math:`(Y, y)`, where
    :math:`Y` is a random variable and :math:`y \sim Y`.

    Log-probability (densities) are functions over these pairs, which makes
    these nodes in a graph an intermediate form that serves to construct a
    log-probability from a model graph.

    This intermediate form can be used as the target for rewrites that
    otherwise wouldn't make sense to apply to--say--a random variable node
    directly. An example is `BroadcastTo` lifting through `RandomVariable`\s.
    """

    # The single output is a view of the first input (the random variable),
    # so this `Op` is effectively an identity on the measurable variable.
    default_output = 0
    view_map = {0: [0]}

    def make_node(self, rv, value):
        # Only tensor-typed random variables are supported.
        assert isinstance(rv.type, TensorType)

        # The output carries exactly the random variable's type.
        out_rv = rv.type()

        vv = at.as_tensor_variable(value)
        assert isinstance(vv.type, TensorType)

        # TODO: We should probably check the `Type`s of `out_rv` and `vv`
        if vv.type.dtype != rv.type.dtype:
            raise TypeError(
                f"Value type {vv.type} does not match random variable type {out_rv.type}"
            )

        return Apply(self, [rv, vv], [out_rv])

    def perform(self, node, inputs, out):
        # Pass the random variable's value straight through; the value input
        # is carried in the graph but not used at execution time.
        rv_in, _ = inputs
        out[0][0] = rv_in

    def grad(self, inputs, outputs):
        grads = []
        for idx, inp in enumerate(inputs):
            grads.append(
                grad_undefined(self, idx, inp, "No gradient defined for `ValuedVariable`")
            )
        return grads

    def infer_shape(self, fgraph, node, input_shapes):
        # The output has the random variable's shape.
        return [input_shapes[0]]


MeasurableVariable.register(ValuedVariable)

valued_variable = ValuedVariable()
11 changes: 5 additions & 6 deletions aeppl/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from aeppl.abstract import (
MeasurableElemwise,
MeasurableVariable,
ValuedVariable,
assign_custom_measurable_outputs,
)
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, logdiffexp
Expand All @@ -37,8 +38,7 @@ def find_measurable_clips(
) -> Optional[List["Variable"]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
if isinstance(node.op, MeasurableClip):
return None # pragma: no cover

clipped_var = node.outputs[0]
Expand All @@ -47,7 +47,7 @@ def find_measurable_clips(
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and not isinstance(base_var, ValuedVariable)
):
return None

Expand Down Expand Up @@ -190,8 +190,7 @@ def construct_measurable_rounding(
fgraph: FunctionGraph, node: Node, rounded_op: "Op"
) -> Optional[List["Variable"]]:

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
if isinstance(node.op, MeasurableRound):
return None # pragma: no cover

(rounded_var,) = node.outputs
Expand All @@ -200,7 +199,7 @@ def construct_measurable_rounding(
if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
and not isinstance(base_var, ValuedVariable)
# Rounding only makes sense for continuous variables
and base_var.dtype.startswith("float")
):
Expand Down
17 changes: 7 additions & 10 deletions aeppl/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.extra_ops import CumOp

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.abstract import (
MeasurableVariable,
ValuedVariable,
assign_custom_measurable_outputs,
)
from aeppl.logprob import _logprob, logprob
from aeppl.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from aeppl.rewriting import measurable_ir_rewrites_db


class MeasurableCumsum(CumOp):
Expand Down Expand Up @@ -50,20 +54,13 @@ def find_measurable_cumsums(fgraph, node) -> Optional[List[MeasurableCumsum]]:
if isinstance(node.op, MeasurableCumsum):
return None # pragma: no cover

rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)

if rv_map_feature is None:
return None # pragma: no cover

rv = node.outputs[0]

base_rv = node.inputs[0]
if not (
base_rv.owner
and isinstance(base_rv.owner.op, MeasurableVariable)
and base_rv not in rv_map_feature.rv_values
and not isinstance(base_rv, ValuedVariable)
):
return None # pragma: no cover

Expand Down
Loading

0 comments on commit 0b28a51

Please sign in to comment.