Skip to content

Commit

Permalink
Do not replace valued RVs in naive_bcast_rv_lift
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo authored and brandonwillard committed Feb 9, 2022
1 parent da834b5 commit e0d5438
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
7 changes: 7 additions & 0 deletions aeppl/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ def naive_bcast_rv_lift(fgraph, node):
if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
return None # pragma: no cover

# Do not replace RV if it is associated with a value variable
rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)
if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
return None

if not bcast_shape:
# The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
assert rv_var.ndim == 0
Expand Down
25 changes: 24 additions & 1 deletion tests/test_opt.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats as st
from aesara.graph.opt import in2out
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.subtensor import Subtensor

from aeppl import factorized_joint_logprob
from aeppl.dists import DiracDelta, dirac_delta
from aeppl.opt import local_lift_DiracDelta, naive_bcast_rv_lift


def test_naive_bcast_rv_lift():
r"""Make sure `test_naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
X_rv = at.random.normal()
Z_at = BroadcastTo()(X_rv, ())

Expand All @@ -22,6 +25,26 @@ def test_naive_bcast_rv_lift():
assert res is X_rv


def test_naive_bcast_rv_lift_valued_var():
r"""Check that `naive_bcast_rv_lift` won't touch valued variables"""

x_rv = at.random.normal(name="x")
broadcasted_x_rv = at.broadcast_to(x_rv, (2,))

y_rv = at.random.normal(broadcasted_x_rv, name="y")

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
assert x_vv in logp_map
assert y_vv in logp_map
assert len(logp_map) == 2
assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
assert np.allclose(
logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
)


def test_local_lift_DiracDelta():
c_at = at.vector()
dd_at = dirac_delta(c_at)
Expand Down

0 comments on commit e0d5438

Please sign in to comment.