diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 86293d8194e..6cfac770c05 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -596,7 +596,7 @@ class CustomSymbolicDistRV(SymbolicRandomVariable): def update(self, node: Node): op = node.op inner_updates = collect_default_updates( - op.inner_inputs, op.inner_outputs, must_be_shared=False + inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False ) # Map inner updates to outer inputs/outputs @@ -668,7 +668,7 @@ def rv_op( ): dummy_rv = dist(*dummy_dist_params, dummy_size_param) dummy_params = [dummy_size_param] + dummy_dist_params - dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,)) + dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) rv_type = type( class_name, @@ -713,7 +713,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand): dummy_dist_params = [dist_param.type() for dist_param in old_dist_params] dummy_rv = dist(*dummy_dist_params, dummy_size_param) dummy_params = [dummy_size_param] + dummy_dist_params - dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,)) + dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,)) new_rv_op = rv_type( inputs=dummy_params, outputs=[*dummy_updates_dict.values(), dummy_rv], diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index d675b6c040a..5274c01eeab 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -47,6 +47,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.scalar.basic import Cast +from pytensor.scan.op import Scan from pytensor.tensor.basic import _as_tensor_variable from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.op import RandomVariable @@ -1004,16 +1005,49 @@ def reseed_rngs( def collect_default_updates( - inputs: Sequence[Variable], outputs: Sequence[Variable], + *, + inputs: Optional[Sequence[Variable]] = None, must_be_shared: bool = True, ) -> Dict[Variable, Variable]: """Collect default update expression for shared-variable RNGs used by RVs between inputs and outputs. - If `must_be_shared` is False, update expressions will also be returned for non-shared input RNGs. - This can be useful to obtain the symbolic update expressions from inner graphs. - """ + Parameters + ---------- + outputs: list of PyTensor variables + List of variables in which graphs default updates will be collected. + inputs: list of PyTensor variables, optional + Input nodes above which default updates should not be collected. + When not provided, search will include top level inputs (roots). + must_be_shared: bool, default True + Used internally by PyMC. Whether updates should be collected for non-shared + RNG input variables. This is used to collect update expressions for inner graphs. + + Examples + -------- + .. code:: python + import pymc as pm + from pytensor.scan import scan + from pymc.pytensorf import collect_default_updates + + def scan_step(xtm1): + x = xtm1 + pm.Normal.dist() + x_update = collect_default_updates([x]) + return x, x_update + x0 = pm.Normal.dist() + + xs, updates = scan( + fn=scan_step, + outputs_info=[x0], + n_steps=10, + ) + + # PyMC makes use of the updates to seed xs properly. + # Without updates, it would raise an error. + xs_draws = pm.draw(xs, draws=10) + + """ # Avoid circular import from pymc.distributions.distribution import SymbolicRandomVariable @@ -1048,16 +1082,31 @@ def find_default_update(clients, rng: Variable) -> Union[None, Variable]: next_rng = client.op.update(client).get(rng) if next_rng is None: raise ValueError( - f"No update mapping found for RNG used in SymbolicRandomVariable Op {client.op}" + f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}" + ) + elif isinstance(client.op, Scan): + # Check if any shared output corresponds to the RNG + rng_idx = client.inputs.index(rng) + io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"] + out_idx = io_map.get(rng_idx, -1) + if out_idx != -1: + next_rng = client.outputs[out_idx] + else: # No break + raise ValueError( + f"No update found for at least one RNG used in Scan Op {client.op}.\n" + "You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically." ) else: - # We don't know how this RNG should be updated (e.g., Scan). + # We don't know how this RNG should be updated (e.g., OpFromGraph). # The user should provide an update manually return None # Recurse until we find final update for RNG return find_default_update(clients, next_rng) + if inputs is None: + inputs = [] + outputs = makeiter(outputs) fg = FunctionGraph(outputs=outputs, clone=False) clients = fg.clients @@ -1129,7 +1178,7 @@ def compile_pymc( """ # Create an update mapping of RandomVariable's RNG so that it is automatically # updated after every function call - rng_updates = collect_default_updates(inputs, outputs) + rng_updates = collect_default_updates(inputs=inputs, outputs=outputs) # We always reseed random variables as this provides RNGs with no chances of collision if rng_updates: diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py index d55ee79826f..04b9bf94cd3 100644 --- a/tests/distributions/test_distribution.py +++ b/tests/distributions/test_distribution.py @@ -23,6 +23,7 @@ import pytest import scipy.stats as st +from pytensor import scan from pytensor.tensor import TensorVariable import pymc as pm @@ -51,6 +52,7 @@ from pymc.logprob.abstract import get_measurable_outputs from pymc.logprob.basic import logcdf, logp from pymc.model import Deterministic, Model +from pymc.pytensorf import collect_default_updates from pymc.sampling import draw, sample from pymc.testing import ( BaseTestDistributionRandom, @@ -523,6 +525,48 @@ def old_random(size): # New API is fine pm.CustomDist.dist(dist=old_random, class_name="custom_dist") + def test_scan(self): + def trw(nu, sigma, steps, size): + def step(xtm1, nu, sigma): + x = pm.StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size) + return x, collect_default_updates([x]) + + xs, _ = scan( + fn=step, + outputs_info=pt.zeros(size), + non_sequences=[nu, sigma], + n_steps=steps, + ) + + # Logprob inference cannot be derived yet https://github.com/pymc-devs/pymc/issues/6360 + # xs = swapaxes(xs, 0, -1) + + return xs + + nu = 4 + sigma = 0.7 + steps = 99 + batch_size = 3 + x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size) + + x_draw = pm.draw(x, random_seed=1) + assert x_draw.shape == (steps, batch_size) + np.testing.assert_allclose(pm.draw(x, random_seed=1), x_draw) + assert not np.any(pm.draw(x, random_seed=2) == x_draw) + + ref_dist = pm.RandomWalk.dist( + init_dist=pm.Flat.dist(), + innovation_dist=pm.StudentT.dist(nu=nu, sigma=sigma), + steps=steps, + size=(batch_size,), + ) + ref_val = pt.concatenate([np.zeros((1, batch_size)), x_draw]).T + + np.testing.assert_allclose( + pm.logp(x, x_draw).eval().sum(0), + pm.logp(ref_dist, ref_val).eval(), + ) + class TestSymbolicRandomVariable: def test_inline(self): diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 73912668a11..dc1852966da 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -24,7 +24,7 @@ import pytest import scipy.sparse as sps -from pytensor import shared +from pytensor import scan, shared from pytensor.compile.builders import OpFromGraph from pytensor.graph.basic import Variable, equal_computations from pytensor.tensor.random.basic import normal, uniform @@ -465,7 +465,7 @@ def update(self, node): ], )(rng1, rng2) with pytest.raises( - ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable" + ValueError, match="No update found for at least one RNG used in SymbolicRandomVariable" ): compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2]) @@ -531,7 +531,7 @@ def test_nested_updates(self): next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs - collect_default_updates([], [x, y, z]) == {rng: next_rng3} + collect_default_updates(inputs=[], outputs=[x, y, z]) == {rng: next_rng3} fn = compile_pymc([], [x, y, z], random_seed=514) assert not set(list(np.array(fn()))) & set(list(np.array(fn()))) @@ -540,19 +540,56 @@ def test_nested_updates(self): fn = pytensor.function([], [x, y, z], updates={rng: next_rng1}) assert set(list(np.array(fn()))) & set(list(np.array(fn()))) + def test_collect_default_updates_must_be_shared(self): + shared_rng = pytensor.shared(np.random.default_rng()) + nonshared_rng = shared_rng.type() -def test_collect_default_updates_must_be_shared(): - shared_rng = pytensor.shared(np.random.default_rng()) - nonshared_rng = shared_rng.type() + next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs + next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs - next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs - next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs + res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y]) + assert res == {shared_rng: next_rng_of_shared} - res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y]) - assert res == {shared_rng: next_rng_of_shared} + res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False) + assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared} - res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False) - assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared} + def test_scan_updates(self): + def step_with_update(x, rng): + next_rng, x = pm.Normal.dist(x, rng=rng).owner.outputs + return x, {rng: next_rng} + + def step_wo_update(x, rng): + return step_with_update(x, rng)[0] + + rng = pytensor.shared(np.random.default_rng()) + + xs, next_rng = scan( + fn=step_wo_update, + outputs_info=[pt.zeros(())], + non_sequences=[rng], + n_steps=10, + name="test_scan", + ) + + assert not next_rng + + with pytest.raises( + ValueError, + match=r"No update found for at least one RNG used in Scan Op for\{cpu,test_scan\}", + ): + collect_default_updates([xs]) + + ys, next_rng = scan( + fn=step_with_update, + outputs_info=[pt.zeros(())], + non_sequences=[rng], + n_steps=10, + ) + + assert collect_default_updates([ys]) == {rng: tuple(next_rng.values())[0]} + + fn = compile_pymc([], ys, random_seed=1) + assert not (set(fn()) & set(fn())) def test_replace_rng_nodes():