From 3483eb5355a0b8be7a55b7f90cc4956d395c5a8f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 May 2023 18:04:26 +0200 Subject: [PATCH 1/2] Avoid cloning of Minibatch values --- pymc/data.py | 6 +++++ pymc/logprob/basic.py | 22 +++++++++++++---- pymc/logprob/rewriting.py | 6 ++--- tests/variational/test_minibatch_rv.py | 34 ++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 8 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 71ca2439faa..b8db7dba4b5 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -36,6 +36,7 @@ import pymc as pm +from pymc.logprob.abstract import _get_measurable_outputs from pymc.pytensorf import convert_observed_data __all__ = [ @@ -134,6 +135,11 @@ def make_node(self, rng, *args, **kwargs): return super().make_node(rng, *args, **kwargs) +@_get_measurable_outputs.register(MinibatchIndexRV) +def minibatch_index_rv_measuarable_outputs(op, node): + return [] + + minibatch_index = MinibatchIndexRV() diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index a8d4221f060..75ec51ba198 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -44,7 +44,13 @@ import pytensor.tensor as pt from pytensor import config -from pytensor.graph.basic import Variable, graph_inputs, io_toposort +from pytensor.graph.basic import ( + Constant, + Variable, + ancestors, + graph_inputs, + io_toposort, +) from pytensor.graph.op import compute_test_value from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter from pytensor.tensor.random.op import RandomVariable @@ -231,10 +237,16 @@ def factorized_joint_logprob( # node. replacements = updated_rv_values.copy() - # To avoid cloning the value variables, we map them to themselves in the - # `replacements` `dict` (i.e. 
entries already existing in `replacements` - # aren't cloned) - replacements.update({v: v for v in rv_values.values()}) + # To avoid cloning the value variables (or ancestors of value variables), + # we map them to themselves in the `replacements` `dict` + # (i.e. entries already existing in `replacements` aren't cloned) + replacements.update( + { + v: v + for v in ancestors(rv_values.values()) + if (not isinstance(v, Constant) and v not in replacements) + } + ) # Walk the graph from its inputs to its outputs and construct the # log-probability diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py index 70279efda59..a721e12f914 100644 --- a/pymc/logprob/rewriting.py +++ b/pymc/logprob/rewriting.py @@ -39,7 +39,7 @@ import pytensor.tensor as pt from pytensor.compile.mode import optdb -from pytensor.graph.basic import Variable +from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.features import Feature from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter @@ -316,8 +316,8 @@ def construct_ir_fgraph( # the old nodes to the new ones; otherwise, we won't be able to use # `rv_values`. # We start the `dict` with mappings from the value variables to themselves, - # to prevent them from being cloned. - memo = {v: v for v in rv_values.values()} + # to prevent them from being cloned. 
This also includes ancestors + memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)} # We add `ShapeFeature` because it will get rid of references to the old # `RandomVariable`s that have been lifted; otherwise, it will be difficult diff --git a/tests/variational/test_minibatch_rv.py b/tests/variational/test_minibatch_rv.py index 8246c16ca39..87df8c7228b 100644 --- a/tests/variational/test_minibatch_rv.py +++ b/tests/variational/test_minibatch_rv.py @@ -20,6 +20,7 @@ import pymc as pm from pymc import Normal, draw +from pymc.data import minibatch_index from pymc.testing import select_by_precision from pymc.variational.minibatch_rv import create_minibatch_rv from tests.test_data import gen1, gen2 @@ -155,3 +156,36 @@ def test_random(self): mx = create_minibatch_rv(x, total_size=(10,)) assert mx is not x np.testing.assert_array_equal(draw(mx, random_seed=1), draw(x, random_seed=1)) + + @pytest.mark.filterwarnings("error") + def test_minibatch_parameter_and_value(self): + rng = np.random.default_rng(161) + total_size = 1000 + + with pm.Model(check_bounds=False) as m: + AD = pm.MutableData("AD", np.arange(total_size, dtype="float64")) + TD = pm.MutableData("TD", np.arange(total_size, dtype="float64")) + + minibatch_idx = minibatch_index(0, 10, size=(9,)) + AD_mt = AD[minibatch_idx] + TD_mt = TD[minibatch_idx] + + pm.Normal( + "AD_predicted", + mu=TD_mt, + observed=AD_mt, + total_size=1000, + ) + + logp_fn = m.compile_logp() + + ip = m.initial_point() + np.testing.assert_allclose(logp_fn(ip), st.norm.logpdf(0) * 1000) + + with m: + pm.set_data({"AD": np.arange(total_size) + 1}) + np.testing.assert_allclose(logp_fn(ip), st.norm.logpdf(1) * 1000) + + with m: + pm.set_data({"AD": rng.normal(size=1000)}) + assert logp_fn(ip) != logp_fn(ip) From fc9d72fc5c5318041b408ca4ccb8de9224c9136a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 23 May 2023 19:20:56 +0200 Subject: [PATCH 2/2] Fix Minibatch for multiple variables --- pymc/data.py | 9 
+++++---- tests/test_data.py | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index b8db7dba4b5..7f484ae9129 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -170,13 +170,11 @@ def assert_all_scalars_equal(scalar, *scalars): else: return Assert( "All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code" - )(scalar, pt.all([scalar == s for s in scalars])) + )(scalar, pt.all([pt.eq(scalar, s) for s in scalars])) def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int): - """ - Get random slices from variables from the leading dimension. - + """Get random slices from variables from the leading dimension. Parameters ---------- @@ -191,6 +189,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: >>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10) """ + if not isinstance(batch_size, int): + raise TypeError("batch_size must be an integer") + tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables))) upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)]) slc = minibatch_index(0, upper, size=batch_size) diff --git a/tests/test_data.py b/tests/test_data.py index 6db3b508759..0d004cd0a1b 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -14,6 +14,7 @@ import io import itertools as it +import re import cloudpickle import numpy as np @@ -614,3 +615,23 @@ def test_assert(self): ): d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20) d1.eval() + + def test_multiple_vars(self): + A = np.arange(1000) + B = np.arange(1000) + mA, mB = pm.Minibatch(A, B, batch_size=10) + + [draw_mA, draw_mB] = pm.draw([mA, mB]) + assert draw_mA.shape == (10,) + np.testing.assert_allclose(draw_mA, draw_mB) + + # Check invalid dims + A = np.arange(1000) + C = np.arange(999) + mA, mC = pm.Minibatch(A, C, batch_size=10) + + with pytest.raises( + AssertionError, + 
match=re.escape("All variables shape[0] in Minibatch should be equal"), + ): + pm.draw([mA, mC])