Introduce ValuedVariable
This `Op` adds the value variable to the graph so that `PreserveRVMappings` is
no longer needed.

It also clarifies the definition and behavior of rewrites that truly apply
to a `MeasurableVariable` and its value variable simultaneously.
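
As a rough sketch of the association this `Op` expresses (illustrative only; it mirrors the new tests below and assumes the `aeppl.abstract` module introduced in this commit):

import aesara.tensor as at

from aeppl.abstract import valued_variable

# A measurable (random) variable and a value variable of the same type
Y_rv = at.random.normal(0, 1, size=3, name="Y")
y_vv = at.vector("y")

# Pair them explicitly in the graph instead of tracking the mapping
# externally (e.g. via `PreserveRVMappings`)
Y_valued = valued_variable(Y_rv, y_vv)

assert Y_valued.owner.inputs[0] is Y_rv
assert Y_valued.owner.inputs[1] is y_vv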
brandonwillard committed Oct 24, 2021
1 parent 580c887 commit a6a9cb4
Showing 5 changed files with 171 additions and 17 deletions.
37 changes: 37 additions & 0 deletions aeppl/abstract.py
@@ -3,6 +3,7 @@
from functools import singledispatch
from typing import Callable, List

from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import Op
from aesara.tensor.random.op import RandomVariable
@@ -82,3 +83,39 @@ def assign_custom_measurable_outputs(
_get_measurable_outputs.register(new_op_type)(measurable_outputs_fn)

return new_node


class ValuedVariable(Op):
r"""Represents the association of a measurable variable and its value.
A `ValuedVariable` node represents the pair :math:`(Y, y)`, where
:math:`Y` is a random variable and :math:`y \sim Y`.
Log-probability (densities) are functions over these pairs, which makes
these nodes in a graph an intermediate form that serves to construct a
log-probability from a model graph.
This intermediate form can be used as the target for rewrites that
otherwise wouldn't make sense to apply to--say--a random variable node
directly. An example is broadcast operation lifting; a broadcasting operation
cannot be lifted...TODO
"""

default_output = 0
view_map = {0: [1]}

def make_node(self, rv, value):
output = rv.type()
return Apply(self, [rv, value], [output])

def perform(self, node, inputs, out):
out[0][0] = inputs[1]

def grad(self, inputs, output_grads):
return [
grad_undefined(self, k, inp, "No gradient defined for `ValuedVariable`")
for k, inp in enumerate(inputs)
]


valued_variable = ValuedVariable()
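
A short check of the semantics described in the docstring above (a sketch, not part of the commit): because `perform` returns its second input and the output views that input, evaluating a `ValuedVariable` simply yields the value.

import aesara
import aesara.tensor as at
import numpy as np

from aeppl.abstract import valued_variable

Y_rv = at.random.normal(0, 1, size=3)
y_vv = at.vector("y")

Y_valued = valued_variable(Y_rv, y_vv)

# The compiled function passes the value straight through
fn = aesara.function([y_vv], Y_valued)
y_val = np.array([0.2, 0.1, -2.4], dtype=aesara.config.floatX)
print(fn(y_val))  # [ 0.2  0.1 -2.4]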
4 changes: 2 additions & 2 deletions aeppl/mixture.py
Expand Up @@ -15,7 +15,7 @@

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import _logprob, logprob
from aeppl.opt import naive_bcast_rv_lift, rv_sinking_db, subtensor_ops
from aeppl.opt import rv_sinking_db, subtensor_ops, valued_var_bcast_lift
from aeppl.utils import get_constant_value, indices_from_subtensor


@@ -28,7 +28,7 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
[
local_dimshuffle_rv_lift,
local_subtensor_rv_lift,
naive_bcast_rv_lift,
valued_var_bcast_lift,
],
x,
)
71 changes: 56 additions & 15 deletions aeppl/opt.py
Expand Up @@ -24,7 +24,7 @@
)
from aesara.tensor.var import TensorVariable

from aeppl.abstract import MeasurableVariable
from aeppl.abstract import MeasurableVariable, ValuedVariable, valued_variable
from aeppl.utils import indices_from_subtensor

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -157,11 +157,8 @@ def incsubtensor_rv_replace(fgraph, node):


@local_optimizer([BroadcastTo])
def naive_bcast_rv_lift(fgraph, node):
"""Lift a ``BroadcastTo`` through a ``RandomVariable`` ``Op``.
XXX: This implementation simply broadcasts the ``RandomVariable``'s
parameters, which won't always work (e.g. multivariate distributions).
def valued_var_bcast_lift(fgraph, node):
r"""Lift a `BroadcastTo` through a `ValuedVariable` with a `RandomVariable`.
TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to
determine which dimensions of each parameter need to be broadcasted.
@@ -172,15 +169,20 @@ def naive_bcast_rv_lift(fgraph, node):
if not (
isinstance(node.op, BroadcastTo)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, RandomVariable)
and isinstance(node.inputs[0].owner.op, ValuedVariable)
):
return None # pragma: no cover

bcast_shape = node.inputs[1:]

assert len(bcast_shape) > 0

rv_var = node.inputs[0]
valued_var = node.inputs[0]

rv_var = valued_var.owner.inputs[0]

if not (rv_var.owner and isinstance(rv_var.owner.op, RandomVariable)):
return None
rv_node = rv_var.owner

if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
@@ -204,6 +206,9 @@
)
for param in dist_params
]

# TODO: Replace `lifted_node.op` with a clone type that doesn't draw samples.
# Also, remove the `rng` object
bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params)

if aesara.config.compute_test_value != "off":
@@ -212,14 +217,30 @@
return [bcasted_node.outputs[1]]


logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register(
"pre-canonicalize", optdb.query("+canonicalize"), -10, "basic"
)
@local_optimizer(subtensor_ops)
def valued_variable_subtensor_lift(fgraph, node):

if not isinstance(node.op, subtensor_ops):
return None

indexed_var = node.inputs[0]

if not (indexed_var.owner and isinstance(indexed_var.owner.op, ValuedVariable)):
return None

indices = node.inputs[1:]

variable_inp, value_inp = indexed_var.owner.inputs

# Apply the same indexing operation to the measurable variable and its value
var_subtens_inp = node.clone_with_new_inputs([variable_inp] + indices).outputs[0]
value_subtens_inp = node.clone_with_new_inputs([value_inp] + indices).outputs[0]

new_vv = valued_variable(var_subtens_inp, value_subtens_inp)

return [new_vv]
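
The intended effect of this rewrite, as a hypothetical usage sketch (it reuses the `EquilibriumOptimizer`/`optimize_graph` pattern from the new tests below):

import aesara.tensor as at
from aesara.graph.opt import EquilibriumOptimizer
from aesara.graph.opt_utils import optimize_graph

from aeppl.abstract import valued_variable
from aeppl.opt import valued_variable_subtensor_lift

Y_rv = at.random.normal(0, 1, size=4)
y_vv = at.vector("y")

# Indexing a `ValuedVariable` pair...
graph = valued_variable(Y_rv, y_vv)[1:]

subtensor_lift_opt = EquilibriumOptimizer(
    [valued_variable_subtensor_lift], ignore_newtrees=False, max_use_ratio=100
)
new_graph = optimize_graph(graph, custom_opt=subtensor_lift_opt)

# ...should now index the measurable variable and its value separately,
# i.e. something like `valued_variable(Y_rv[1:], y_vv[1:])`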


class NoCallbackEquilibriumDB(EquilibriumDB):
r"""This `EquilibriumDB` doesn't hide its exceptions.
By setting `failure_callback` to ``None`` in the `EquilibriumOptimizer`\s
@@ -233,13 +254,33 @@ def query(self, *tags, **kwtags):
return res


class RVSinkingDB(NoCallbackEquilibriumDB):
r"""A DB with optimizations that pertain to `RandomVariable`s"""


rv_sinking_db = RVSinkingDB()

rv_sinking_db.name = "rv_sinking_db"
rv_sinking_db.register("dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic")
rv_sinking_db.register("subtensor_lift", local_subtensor_rv_lift, -5, "basic")
rv_sinking_db.register("broadcast_to_lift", naive_bcast_rv_lift, -5, "basic")
rv_sinking_db.register("broadcast_to_lift", valued_var_bcast_lift, -5, "basic")
rv_sinking_db.register("incsubtensor_lift", incsubtensor_rv_replace, -5, "basic")


class ValueMapsDB(NoCallbackEquilibriumDB):
r"""A DB with optimizations that pertain to `ValuedVariable`\s"""


value_maps_db = ValueMapsDB()
value_maps_db.name = "value_maps_db"
value_maps_db.register("subtensor_lift", valued_variable_subtensor_lift, -5, "basic")

logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
logprob_rewrites_db.register(
"pre-canonicalize", optdb.query("+canonicalize"), -10, "basic"
)
logprob_rewrites_db.register("value_maps", value_maps_db, -20, "basic")
logprob_rewrites_db.register("sinking", rv_sinking_db, -10, "basic")
logprob_rewrites_db.register(
"post-canonicalize", optdb.query("+canonicalize"), 10, "basic"
34 changes: 34 additions & 0 deletions tests/test_abstract.py
@@ -0,0 +1,34 @@
import aesara
import aesara.tensor as at
import numpy as np
from aesara.gradient import NullTypeGradError, grad
from pytest import raises

from aeppl.abstract import valued_variable


def test_observed():
rv_var = at.random.normal(0, 1, size=3)
obs_var = valued_variable(
rv_var, np.array([0.2, 0.1, -2.4], dtype=aesara.config.floatX)
)

assert obs_var.owner.inputs[0] is rv_var

with raises(TypeError):
valued_variable(rv_var, np.array([1, 2], dtype=int))

with raises(TypeError):
valued_variable(rv_var, np.array([[1.0, 2.0]], dtype=rv_var.dtype))

# obs_rv = valued_variable(None, np.array([0.2, 0.1, -2.4], dtype=aesara.config.floatX))
#
# assert isinstance(obs_rv.owner.inputs[0].type, NoneTypeT)

rv_val = at.vector()
rv_val.tag.test_value = np.array([0.2, 0.1, -2.4], dtype=aesara.config.floatX)

obs_var = valued_variable(rv_var, rv_val)

with raises(NullTypeGradError):
grad(obs_var.sum(), [rv_val])
42 changes: 42 additions & 0 deletions tests/test_opt.py
@@ -0,0 +1,42 @@
import aesara.tensor as at
import pytest
from aesara.graph.opt import EquilibriumOptimizer
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.extra_ops import BroadcastTo

from aeppl.opt import valued_var_bcast_lift

bcast_lift_opt = EquilibriumOptimizer(
[valued_var_bcast_lift], ignore_newtrees=False, max_use_ratio=1000
)


@pytest.mark.parametrize(
"rv_params, rv_size, bcast_shape, should_rewrite",
[
# The `BroadcastTo` shouldn't be lifted, because it would imply that there
# are 10 independent samples, when there's really only one
pytest.param(
(0, 1),
None,
(10,),
False,
marks=pytest.mark.xfail(reason="Not implemented"),
),
# These should work, under the assumption that `size == 10`, of course.
((0, 1), at.iscalar("size"), (10,), True),
((0, 1), at.iscalar("size"), (1, 10, 1), True),
((at.zeros((at.iscalar("size"),)), 1), None, (10,), True),
],
)
def test_naive_bcast_rv_lift(rv_params, rv_size, bcast_shape, should_rewrite):
graph = at.broadcast_to(at.random.normal(*rv_params, size=rv_size), bcast_shape)

assert isinstance(graph.owner.op, BroadcastTo)

new_graph = optimize_graph(graph, custom_opt=bcast_lift_opt)

if should_rewrite:
assert not isinstance(new_graph.owner.op, BroadcastTo)
else:
assert isinstance(new_graph.owner.op, BroadcastTo)
