Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Scans in CustomDist #6696

Merged
merged 1 commit into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
def update(self, node: Node):
op = node.op
inner_updates = collect_default_updates(
op.inner_inputs, op.inner_outputs, must_be_shared=False
inputs=op.inner_inputs, outputs=op.inner_outputs, must_be_shared=False
)

# Map inner updates to outer inputs/outputs
Expand Down Expand Up @@ -668,7 +668,7 @@ def rv_op(
):
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))

rv_type = type(
class_name,
Expand Down Expand Up @@ -713,7 +713,7 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
Expand Down
63 changes: 56 additions & 7 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
from pytensor.tensor.basic import _as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -1004,16 +1005,49 @@ def reseed_rngs(


def collect_default_updates(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
*,
inputs: Optional[Sequence[Variable]] = None,
must_be_shared: bool = True,
) -> Dict[Variable, Variable]:
"""Collect default update expressions for shared-variable RNGs used by RVs between inputs and outputs.

If `must_be_shared` is False, update expressions will also be returned for non-shared input RNGs.
This can be useful to obtain the symbolic update expressions from inner graphs.
"""
Parameters
----------
outputs: list of PyTensor variables
List of variables in whose graphs default updates will be collected.
inputs: list of PyTensor variables, optional
Input nodes above which default updates should not be collected.
When not provided, search will include top level inputs (roots).
must_be_shared: bool, default True
Used internally by PyMC. Whether updates should be collected for non-shared
RNG input variables. This is used to collect update expressions for inner graphs.

Examples
--------
.. code:: python
import pymc as pm
from pytensor.scan import scan
from pymc.pytensorf import collect_default_updates

def scan_step(xtm1):
x = xtm1 + pm.Normal.dist()
x_update = collect_default_updates([x])
return x, x_update

x0 = pm.Normal.dist()

xs, updates = scan(
fn=scan_step,
outputs_info=[x0],
n_steps=10,
)

# PyMC makes use of the updates to seed xs properly.
# Without updates, it would raise an error.
xs_draws = pm.draw(xs, draws=10)

"""
# Avoid circular import
from pymc.distributions.distribution import SymbolicRandomVariable

Expand Down Expand Up @@ -1048,16 +1082,31 @@ def find_default_update(clients, rng: Variable) -> Union[None, Variable]:
next_rng = client.op.update(client).get(rng)
if next_rng is None:
raise ValueError(
f"No update mapping found for RNG used in SymbolicRandomVariable Op {client.op}"
f"No update found for at least one RNG used in SymbolicRandomVariable Op {client.op}"
)
elif isinstance(client.op, Scan):
# Check if any shared output corresponds to the RNG
rng_idx = client.inputs.index(rng)
io_map = client.op.get_oinp_iinp_iout_oout_mappings()["outer_out_from_outer_inp"]
out_idx = io_map.get(rng_idx, -1)
if out_idx != -1:
next_rng = client.outputs[out_idx]
else: # No break
raise ValueError(
f"No update found for at least one RNG used in Scan Op {client.op}.\n"
"You can use `pytensorf.collect_default_updates` inside the Scan function to return updates automatically."
)
else:
# We don't know how this RNG should be updated (e.g., Scan).
# We don't know how this RNG should be updated (e.g., OpFromGraph).
# The user should provide an update manually
return None

# Recurse until we find final update for RNG
return find_default_update(clients, next_rng)

if inputs is None:
inputs = []

outputs = makeiter(outputs)
fg = FunctionGraph(outputs=outputs, clone=False)
clients = fg.clients
Expand Down Expand Up @@ -1129,7 +1178,7 @@ def compile_pymc(
"""
# Create an update mapping of RandomVariable's RNG so that it is automatically
# updated after every function call
rng_updates = collect_default_updates(inputs, outputs)
rng_updates = collect_default_updates(inputs=inputs, outputs=outputs)

# We always reseed random variables as this provides RNGs with no chances of collision
if rng_updates:
Expand Down
44 changes: 44 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest
import scipy.stats as st

from pytensor import scan
from pytensor.tensor import TensorVariable

import pymc as pm
Expand Down Expand Up @@ -51,6 +52,7 @@
from pymc.logprob.abstract import get_measurable_outputs
from pymc.logprob.basic import logcdf, logp
from pymc.model import Deterministic, Model
from pymc.pytensorf import collect_default_updates
from pymc.sampling import draw, sample
from pymc.testing import (
BaseTestDistributionRandom,
Expand Down Expand Up @@ -523,6 +525,48 @@ def old_random(size):
# New API is fine
pm.CustomDist.dist(dist=old_random, class_name="custom_dist")

def test_scan(self):
    """Check draws and logp of a CustomDist whose graph is built with a Scan.

    The custom distribution is a Student-T random walk started at zero. Its
    draws must be reproducible under a fixed seed, differ across seeds, and
    its summed logp must match the one of an equivalent ``pm.RandomWalk``.
    """

    def trw(nu, sigma, steps, size):
        def step(xtm1, nu, sigma):
            x = pm.StudentT.dist(nu=nu, mu=xtm1, sigma=sigma, shape=size)
            # Hand the RNG updates of this step back to Scan alongside the value
            rng_updates = collect_default_updates([x])
            return x, rng_updates

        xs, _ = scan(
            fn=step,
            outputs_info=pt.zeros(size),
            non_sequences=[nu, sigma],
            n_steps=steps,
        )

        # Logprob inference cannot be derived yet https://github.com/pymc-devs/pymc/issues/6360
        # xs = swapaxes(xs, 0, -1)

        return xs

    nu, sigma, steps, batch_size = 4, 0.7, 99, 3
    x = CustomDist.dist(nu, sigma, steps, dist=trw, size=batch_size)

    # Same seed -> identical draws; different seed -> no coinciding entry
    x_draw = pm.draw(x, random_seed=1)
    assert x_draw.shape == (steps, batch_size)
    np.testing.assert_allclose(pm.draw(x, random_seed=1), x_draw)
    assert not np.any(pm.draw(x, random_seed=2) == x_draw)

    ref_dist = pm.RandomWalk.dist(
        init_dist=pm.Flat.dist(),
        innovation_dist=pm.StudentT.dist(nu=nu, sigma=sigma),
        steps=steps,
        size=(batch_size,),
    )
    # Prepend the zero starting row and transpose, so that the reference
    # value has shape (batch_size, steps + 1)
    zero_start = np.zeros((1, batch_size))
    ref_val = pt.concatenate([zero_start, x_draw]).T

    custom_logp = pm.logp(x, x_draw).eval().sum(0)
    reference_logp = pm.logp(ref_dist, ref_val).eval()
    np.testing.assert_allclose(custom_logp, reference_logp)


class TestSymbolicRandomVariable:
def test_inline(self):
Expand Down
61 changes: 49 additions & 12 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest
import scipy.sparse as sps

from pytensor import shared
from pytensor import scan, shared
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable, equal_computations
from pytensor.tensor.random.basic import normal, uniform
Expand Down Expand Up @@ -465,7 +465,7 @@ def update(self, node):
],
)(rng1, rng2)
with pytest.raises(
ValueError, match="No update mapping found for RNG used in SymbolicRandomVariable"
ValueError, match="No update found for at least one RNG used in SymbolicRandomVariable"
):
compile_pymc(inputs=[], outputs=[dummy_x1, dummy_x2])

Expand Down Expand Up @@ -531,7 +531,7 @@ def test_nested_updates(self):
next_rng2, y = pt.random.normal(rng=next_rng1).owner.outputs
next_rng3, z = pt.random.normal(rng=next_rng2).owner.outputs

collect_default_updates([], [x, y, z]) == {rng: next_rng3}
collect_default_updates(inputs=[], outputs=[x, y, z]) == {rng: next_rng3}

fn = compile_pymc([], [x, y, z], random_seed=514)
assert not set(list(np.array(fn()))) & set(list(np.array(fn())))
Expand All @@ -540,19 +540,56 @@ def test_nested_updates(self):
fn = pytensor.function([], [x, y, z], updates={rng: next_rng1})
assert set(list(np.array(fn()))) & set(list(np.array(fn())))

def test_collect_default_updates_must_be_shared(self):
shared_rng = pytensor.shared(np.random.default_rng())
nonshared_rng = shared_rng.type()

def test_collect_default_updates_must_be_shared():
shared_rng = pytensor.shared(np.random.default_rng())
nonshared_rng = shared_rng.type()
next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs

next_rng_of_shared, x = pt.random.normal(rng=shared_rng).owner.outputs
next_rng_of_nonshared, y = pt.random.normal(rng=nonshared_rng).owner.outputs
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
assert res == {shared_rng: next_rng_of_shared}

res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y])
assert res == {shared_rng: next_rng_of_shared}
res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}

res = collect_default_updates(inputs=[nonshared_rng], outputs=[x, y], must_be_shared=False)
assert res == {shared_rng: next_rng_of_shared, nonshared_rng: next_rng_of_nonshared}
def test_scan_updates(self):
    """Check how `collect_default_updates` treats RNGs that are used inside a Scan.

    A Scan whose inner function does not return the RNG update must raise,
    whereas a Scan that does return it must yield that update as the default.
    """

    def step_with_update(x, rng):
        next_rng, x = pm.Normal.dist(x, rng=rng).owner.outputs
        return x, {rng: next_rng}

    def step_wo_update(x, rng):
        # Same step, but the update dictionary is dropped
        x, _ = step_with_update(x, rng)
        return x

    rng = pytensor.shared(np.random.default_rng())

    # Case 1: the inner function returns no update for the RNG
    xs, missing_updates = scan(
        fn=step_wo_update,
        outputs_info=[pt.zeros(())],
        non_sequences=[rng],
        n_steps=10,
        name="test_scan",
    )

    assert not missing_updates

    with pytest.raises(
        ValueError,
        match=r"No update found for at least one RNG used in Scan Op for\{cpu,test_scan\}",
    ):
        collect_default_updates([xs])

    # Case 2: the inner function returns the RNG update
    ys, scan_updates = scan(
        fn=step_with_update,
        outputs_info=[pt.zeros(())],
        non_sequences=[rng],
        n_steps=10,
    )

    expected_next_rng = list(scan_updates.values())[0]
    assert collect_default_updates([ys]) == {rng: expected_next_rng}

    # Two consecutive calls of the compiled function must not share any value
    fn = compile_pymc([], ys, random_seed=1)
    first_call = set(fn())
    second_call = set(fn())
    assert not (first_call & second_call)


def test_replace_rng_nodes():
Expand Down