Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not replace valued RVs in naive_bcast_rv_lift #116

Merged
merged 2 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions aeppl/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,13 @@ def naive_bcast_rv_lift(fgraph, node):
if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
return None # pragma: no cover

# Do not replace RV if it is associated with a value variable
rv_map_feature: Optional[PreserveRVMappings] = getattr(
fgraph, "preserve_rv_mappings", None
)
if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
return None

if not bcast_shape:
# The `BroadcastTo` is broadcasting a scalar to a scalar (i.e. doing nothing)
assert rv_var.ndim == 0
Expand Down
5 changes: 0 additions & 5 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import sys
import os
import pathlib


# import local version of library instead of installed one
sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve().parent.parent / "src"))
import aeppl

# -- Project information
Expand Down
25 changes: 24 additions & 1 deletion tests/test_opt.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import aesara
import aesara.tensor as at
import numpy as np
import scipy.stats as st
from aesara.graph.opt import in2out
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.subtensor import Subtensor

from aeppl import factorized_joint_logprob
from aeppl.dists import DiracDelta, dirac_delta
from aeppl.opt import local_lift_DiracDelta, naive_bcast_rv_lift


def test_naive_bcast_rv_lift():
r"""Make sure `test_naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `BroadcastTo`\s."""
X_rv = at.random.normal()
Z_at = BroadcastTo()(X_rv, ())

Expand All @@ -22,6 +25,26 @@ def test_naive_bcast_rv_lift():
assert res is X_rv


def test_naive_bcast_rv_lift_valued_var():
r"""Check that `naive_bcast_rv_lift` won't touch valued variables"""

x_rv = at.random.normal(name="x")
broadcasted_x_rv = at.broadcast_to(x_rv, (2,))

y_rv = at.random.normal(broadcasted_x_rv, name="y")

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp_map = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
assert x_vv in logp_map
assert y_vv in logp_map
assert len(logp_map) == 2
assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
assert np.allclose(
logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0])
)


def test_local_lift_DiracDelta():
c_at = at.vector()
dd_at = dirac_delta(c_at)
Expand Down