Skip to content

Commit

Permalink
Rename _replace_rvs_in_graphs and fix bug when replacing input
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and michaelosthege committed May 17, 2023
1 parent e1060de commit c57769c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def ignore_logprob_multiple_vars(
making each "unmeasurable", whereas a sequential call to `ignore_logprob`
would not do this correctly.
"""
from pymc.pytensorf import _replace_rvs_in_graphs
from pymc.pytensorf import _replace_vars_in_graphs

measurable_vars_to_unmeasurable_vars = {
measurable_var: ignore_logprob(measurable_var) for measurable_var in vars
Expand All @@ -353,5 +353,5 @@ def replacement_fn(var, replacements):

return []

unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn)
unmeasurable_vars, _ = _replace_vars_in_graphs(graphs=vars, replacement_fn=replacement_fn)
return unmeasurable_vars
14 changes: 9 additions & 5 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,22 @@ def expand(var):
yield from walk(graphs, expand, bfs=False)


def _replace_rvs_in_graphs(
def _replace_vars_in_graphs(
graphs: Iterable[TensorVariable],
replacement_fn: Callable[[TensorVariable], Dict[TensorVariable, TensorVariable]],
**kwargs,
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
"""Replace random variables in graphs
"""Replace variables in graphs.
This will *not* recompute test values.
Parameters
----------
graphs
The graphs in which random variables are to be replaced.
replacement_fn
A callable called on each graph output that populates a replacement dictionary and returns
nodes that should be investigated further.
Returns
-------
Expand Down Expand Up @@ -256,7 +259,8 @@ def expand_replace(var):
toposort = fg.toposort()
sorted_replacements = sorted(
tuple(replacements.items()),
key=lambda pair: toposort.index(pair[0].owner),
# Root inputs don't have owner, we give them negative priority -1
key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner is not None else -1,
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
Expand Down Expand Up @@ -317,7 +321,7 @@ def populate_replacements(
equiv = clone_get_equiv(inputs, graphs, False, False, {})
graphs = [equiv[n] for n in graphs]

graphs, _ = _replace_rvs_in_graphs(
graphs, _ = _replace_vars_in_graphs(
graphs,
replacement_fn=populate_replacements,
**kwargs,
Expand Down Expand Up @@ -385,7 +389,7 @@ def poulate_replacements(rv, replacements):
# replacements if that is not a simple input variable
return [value]

graphs, _ = _replace_rvs_in_graphs(
graphs, _ = _replace_vars_in_graphs(
graphs,
replacement_fn=poulate_replacements,
**kwargs,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest
import scipy.sparse as sps

from pytensor import shared
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable, equal_computations
from pytensor.tensor.random.basic import normal, uniform
Expand All @@ -40,6 +41,7 @@
from pymc.exceptions import NotConstantValueError
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
_replace_vars_in_graphs,
collect_default_updates,
compile_pymc,
constant_fold,
Expand Down Expand Up @@ -821,3 +823,21 @@ def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
),
[expected_x, expected_y, expected_z, expected_w],
)

def test_replace_input(self):
inp = shared(0.0, name="inp")
x = pm.Normal.dist(inp)

assert x.eval() < 50

new_inp = inp + 100

def replacement_fn(var, replacements):
if var is x:
replacements[x.owner.inputs[3]] = new_inp

return []

[new_x], _ = _replace_vars_in_graphs([x], replacement_fn=replacement_fn)

assert new_x.eval() > 50

0 comments on commit c57769c

Please sign in to comment.