From 170f7691341f5d7d46bab765caca634e35432c52 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Mon, 31 Jan 2022 19:36:24 +0100
Subject: [PATCH] Do not replace valued RVs in `naive_bcast_rv_lift`

---
 aeppl/opt.py      |  7 +++++++
 tests/test_opt.py | 25 ++++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/aeppl/opt.py b/aeppl/opt.py
index c7b858f4..e1c91fbe 100644
--- a/aeppl/opt.py
+++ b/aeppl/opt.py
@@ -219,6 +219,13 @@ def naive_bcast_rv_lift(fgraph, node):
     if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
         return None  # pragma: no cover
 
+    # Do not replace RV if it is associated with a value variable
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(
+        fgraph, "preserve_rv_mappings", None
+    )
+    if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
+        return None
+
     if not bcast_shape:
         # The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
         assert rv_var.ndim == 0
diff --git a/tests/test_opt.py b/tests/test_opt.py
index bd60d3f2..24836717 100644
--- a/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -1,17 +1,20 @@
 import aesara
 import aesara.tensor as at
+import numpy as np
+import scipy.stats as st
 from aesara.graph.opt import in2out
 from aesara.graph.opt_utils import optimize_graph
 from aesara.tensor.elemwise import DimShuffle, Elemwise
 from aesara.tensor.extra_ops import BroadcastTo
 from aesara.tensor.subtensor import Subtensor
 
+from aeppl import factorized_joint_logprob
 from aeppl.dists import DiracDelta, dirac_delta
 from aeppl.opt import local_lift_DiracDelta, naive_bcast_rv_lift
 
 
 def test_naive_bcast_rv_lift():
-    r"""Make sure `test_naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
+    r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
 
     X_rv = at.random.normal()
     Z_at = BroadcastTo()(X_rv, ())
@@ -22,6 +25,26 @@ def test_naive_bcast_rv_lift():
     assert res is X_rv
 
 
+def test_naive_bcast_rv_lift_valued_var():
+    r"""Check that `naive_bcast_rv_lift` won't touch valued variables"""
+
+    x_rv = at.random.normal(name="x")
+    broadcasted_x_rv = at.broadcast_to(x_rv, (2,))
+
+    y_rv = at.random.normal(broadcasted_x_rv, name="y")
+
+    x_vv = x_rv.clone()
+    y_vv = y_rv.clone()
+    logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
+    assert x_vv in logp_map
+    assert y_vv in logp_map
+    assert len(logp_map) == 2
+    assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
+    assert np.allclose(
+        logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
+    )
+
+
 def test_local_lift_DiracDelta():
     c_at = at.vector()
     dd_at = dirac_delta(c_at)