Skip to content

Commit

Permalink
Handle Scan in collect_default_updates
Browse files Browse the repository at this point in the history
This allows proper seeding in CustomDists with Scans
  • Loading branch information
ricardoV94 committed May 19, 2023
1 parent c57769c commit 5532cb2
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 22 deletions.
6 changes: 3 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
def update(self, node: Node):
op = node.op
inner_updates = collect_default_updates(
op.inner_inputs, op.inner_outputs, must_be_shared=False
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
)

# Map inner updates to outer inputs/outputs
Expand Down Expand Up @@ -668,7 +668,7 @@ def rv_op(
):
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

rv_type = type(
class_name,
Expand Down Expand Up @@ -713,7 +713,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
Expand Down
63 changes: 56 additions & 7 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
from pytensor.tensor.basic import _as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -1004,16 +1005,49 @@ def reseed_rngs(


def collect_default_updates(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
*,
inputs: Optional[Sequence[Variable]] = None,
must_be_shared: bool = True,
) -> Dict[Variable, Variable]:
"""Collect default update expression for shared-variable RNGs used by RVs between inputs and outputs.
If `must_be_shared` is False, update expressions will also be returned for non-shared input RNGs.
This can be useful to obtain the symbolic update expressions from inner graphs.
"""
Parameters
----------
outputs: list of PyTensor variables
List of variables in whose graphs default updates will be collected.
inputs: list of PyTensor variables, optional
Input nodes above which default updates should not be collected.
When not provided, search will include top level inputs (roots).
must_be_shared: bool, default True
Used internally by PyMC. When True, updates are collected only for shared RNG variables.
When False, updates are also collected for non-shared RNG input variables. This is used to collect update expressions for inner graphs.
Examples
--------
.. code:: python
import pymc as pm
from pytensor.scan import scan
from pymc.pytensorf import collect_default_updates
def scan_step(xtm1):
x = xtm1 + pm.Normal.dist()
x_update = collect_default_updates([x])
return x, x_update
x0 = pm.Normal.dist()
xs, updates = scan(
fn=scan_step,
outputs_info=[x0],
n_steps=10,
)
# PyMC makes use of the updates to seed xs properly.
# Without updates, it would raise an error.
xs_draws = pm.draw(xs, draws=10)
"""
# Avoid circular import
from pymc.distributions.distribution import SymbolicRandomVariable

Expand Down Expand Up @@ -1048,16 +1082,31 @@ def find_default_update(clients, rng: Variable) -> Union[None, Variable]:
next_rng = client.op.update(client).get(rng)
if next_rng is None:
raise ValueError(
f"No update mapping found for RNG used in SymbolicRandomVariable Op {client.op}"
f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}"
)
elif isinstance(client.op, Scan):
# Check if any shared output corresponds to the RNG
rng_idx = client.inputs.index(rng)
io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"]
out_idx = io_map.get(rng_idx, -1)
if out_idx != -1:
next_rng = client.outputs[out_idx]
else: # No break
raise ValueError(
f"No update found for at least one RNG used in Scan Op {client.op}.\n"
"You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
)
else:
# We don't know how this RNG should be updated (e.g., Scan).
# We don't know how this RNG should be updated (e.g., OpFromGraph).
# The user should provide an update manually
return None

# Recurse until we find final update for RNG
return find_default_update(clients, next_rng)

if inputs is None:
inputs = []

outputs = makeiter(outputs)
fg = FunctionGraph(outputs=outputs, clone=False)
clients = fg.clients
Expand Down Expand Up @@ -1129,7 +1178,7 @@ def compile_pymc(
"""
# Create an update mapping of RandomVariable's RNG so that it is automatically
# updated after every function call
rng_updates = collect_default_updates(inputs, outputs)
rng_updates = collect_default_updates(inputs=inputs, outputs=outputs)

# We always reseed random variables as this provides RNGs with no chances of collision
if rng_updates:
Expand Down
44 changes: 44 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest
import scipy.stats as st

from pytensor import scan
from pytensor.tensor import TensorVariable

import pymc as pm
Expand Down Expand Up @@ -51,6 +52,7 @@
from pymc.logprob.abstract import get_measurable_outputs
from pymc.logprob.basic import logcdf, logp
from pymc.model import Deterministic, Model
from pymc.pytensorf import collect_default_updates
from pymc.sampling import draw, sample
from pymc.testing import (
BaseTestDistributionRandom,
Expand Down Expand Up @@ -523,6 +525,48 @@ def old_random(size):
# New API is fine
pm.CustomDist.dist(dist=old_random, class_name="custom_dist")

def test_scan(self):
    """Check a CustomDist whose `dist` function builds a Scan-based Student-T walk.

    The draws must have shape ``(steps, batch_size)``, repeat exactly for the same
    seed and differ for another seed. The summed logp is compared against a
    reference ``pm.RandomWalk`` with Student-T innovations.
    """

    def trw(nu, sigma, steps, size):
        def step(xtm1, nu, sigma):
            innovation = pm.StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
            # Return the default RNG updates alongside the new state so that
            # the Scan carries them as its own updates.
            return innovation, collect_default_updates([innovation])

        scan_result = scan(
            fn=step,
            outputs_info=pt.zeros(size),
            non_sequences=[nu, sigma],
            n_steps=steps,
        )

        # Logprob inference cannot be derived yet https://github.com/pymc-devs/pymc/issues/6360
        # xs = swapaxes(xs, 0, -1)

        # First entry holds the scanned outputs; the second (updates) is unused here.
        return scan_result[0]

    nu, sigma = 4, 0.7
    steps, batch_size = 99, 3
    x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size)

    x_draw = pm.draw(x, random_seed=1)
    assert x_draw.shape == (steps, batch_size)

    # Same seed reproduces the draw, a different seed shares no value with it.
    same_seed_draw = pm.draw(x, random_seed=1)
    np.testing.assert_allclose(same_seed_draw, x_draw)
    other_seed_draw = pm.draw(x, random_seed=2)
    assert not np.any(other_seed_draw == x_draw)

    ref_dist = pm.RandomWalk.dist(
        init_dist=pm.Flat.dist(),
        innovation_dist=pm.StudentT.dist(nu=nu, sigma=sigma),
        steps=steps,
        size=(batch_size,),
    )
    # Prepend the zero initial state and transpose before evaluating the reference logp.
    ref_val = pt.concatenate([np.zeros((1, batch_size)), x_draw]).T

    x_logp = pm.logp(x, x_draw).eval().sum(0)
    ref_logp = pm.logp(ref_dist, ref_val).eval()
    np.testing.assert_allclose(x_logp, ref_logp)


class TestSymbolicRandomVariable:
def test_inline(self):
Expand Down
61 changes: 49 additions & 12 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest
import scipy.sparse as sps

from pytensor import shared
from pytensor import scan, shared
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable, equal_computations
from pytensor.tensor.random.basic import normal, uniform
Expand Down Expand Up @@ -465,7 +465,7 @@ def update(self, node):
],
)(rng1, rng2)
with pytest.raises(
ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable"
ValueError, match="No update found for at least one RNG used in SymbolicRandomVariable"
):
compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2])

Expand Down Expand Up @@ -531,7 +531,7 @@ def test_nested_updates(self):
next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs
next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs

collect_default_updates([], [x, y, z]) == {rng: next_rng3}
collect_default_updates(inputs=[], outputs=[x, y, z]) == {rng: next_rng3}

fn = compile_pymc([], [x, y, z], random_seed=514)
assert not set(list(np.array(fn()))) & set(list(np.array(fn())))
Expand All @@ -540,19 +540,56 @@ def test_nested_updates(self):
fn = pytensor.function([], [x, y, z], updates={rng: next_rng1})
assert set(list(np.array(fn()))) & set(list(np.array(fn())))

def test_collect_default_updates_must_be_shared(self):
shared_rng = pytensor.shared(np.random.default_rng())
nonshared_rng = shared_rng.type()

def test_collect_default_updates_must_be_shared():
shared_rng = pytensor.shared(np.random.default_rng())
nonshared_rng = shared_rng.type()
next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs

next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
assert res == {shared_rng: next_rng_of_shared}

res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
assert res == {shared_rng: next_rng_of_shared}
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}

res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}
def test_scan_updates(self):
    """Check how `collect_default_updates` treats a shared RNG consumed by a Scan.

    A Scan whose step function returns no RNG update must raise a ValueError.
    A Scan whose step function returns the update must yield that update for
    the shared RNG, and a function compiled from it must not repeat values
    across two calls.
    """

    def step_with_update(x, rng):
        rv_node = pm.Normal.dist(x, rng=rng).owner
        next_rng, draw = rv_node.outputs
        return draw, {rng: next_rng}

    def step_wo_update(x, rng):
        # Same draw as above, but the RNG update is dropped.
        return step_with_update(x, rng)[0]

    rng = pytensor.shared(np.random.default_rng())

    def run_scan(step_fn, **extra_kwargs):
        # Both scans share everything except the step function and the optional name.
        return scan(
            fn=step_fn,
            outputs_info=[pt.zeros(())],
            non_sequences=[rng],
            n_steps=10,
            **extra_kwargs,
        )

    xs, missing_updates = run_scan(step_wo_update, name="test_scan")
    assert not missing_updates

    expected_error = r"No update found for at least one RNG used in Scan Op for\{cpu,test_scan\}"
    with pytest.raises(ValueError, match=expected_error):
        collect_default_updates([xs])

    ys, scan_updates = run_scan(step_with_update)
    expected_next_rng = tuple(scan_updates.values())[0]
    assert collect_default_updates([ys]) == {rng: expected_next_rng}

    fn = compile_pymc([], ys, random_seed=1)
    first_call = set(fn())
    second_call = set(fn())
    assert not (first_call & second_call)


def test_replace_rng_nodes():
Expand Down

0 comments on commit 5532cb2

Please sign in to comment.