Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix minibatch bugs #6730

Merged
merged 2 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import pymc as pm

from pymc.logprob.abstract import _get_measurable_outputs
from pymc.pytensorf import convert_observed_data

__all__ = [
Expand Down Expand Up @@ -134,6 +135,11 @@ def make_node(self, rng, *args, **kwargs):
return super().make_node(rng, *args, **kwargs)


@_get_measurable_outputs.register(MinibatchIndexRV)
def minibatch_index_rv_measuarable_outputs(op, node):
    """Register ``MinibatchIndexRV`` as having no measurable outputs.

    Returning an empty list tells the logprob machinery that draws of the
    minibatch index are subsampling bookkeeping, not model random variables,
    so no log-probability term is derived for them.
    """
    # NOTE(review): "measuarable" is a typo in the function name; harmless,
    # since the function is only ever reached through the dispatch
    # registration in the decorator above.
    return []


# Module-level singleton op instance used to draw random minibatch indices.
minibatch_index = MinibatchIndexRV()


Expand Down Expand Up @@ -164,13 +170,11 @@ def assert_all_scalars_equal(scalar, *scalars):
else:
return Assert(
"All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
)(scalar, pt.all([scalar == s for s in scalars]))
)(scalar, pt.all([pt.eq(scalar, s) for s in scalars]))


def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):
"""
Get random slices from variables from the leading dimension.

"""Get random slices from variables from the leading dimension.

Parameters
----------
Expand All @@ -185,6 +189,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
>>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10)
"""

if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")

tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables)))
upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
slc = minibatch_index(0, upper, size=batch_size)
Expand Down
22 changes: 17 additions & 5 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@
import pytensor.tensor as pt

from pytensor import config
from pytensor.graph.basic import Variable, graph_inputs, io_toposort
from pytensor.graph.basic import (
Constant,
Variable,
ancestors,
graph_inputs,
io_toposort,
)
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -231,10 +237,16 @@ def factorized_joint_logprob(
# node.
replacements = updated_rv_values.copy()

# To avoid cloning the value variables, we map them to themselves in the
# `replacements` `dict` (i.e. entries already existing in `replacements`
# aren't cloned)
replacements.update({v: v for v in rv_values.values()})
# To avoid cloning the value variables (or ancestors of value variables),
# we map them to themselves in the `replacements` `dict`
# (i.e. entries already existing in `replacements` aren't cloned)
replacements.update(
{
v: v
for v in ancestors(rv_values.values())
if (not isinstance(v, Constant) and v not in replacements)
}
)

# Walk the graph from its inputs to its outputs and construct the
# log-probability
Expand Down
6 changes: 3 additions & 3 deletions pymc/logprob/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import pytensor.tensor as pt

from pytensor.compile.mode import optdb
from pytensor.graph.basic import Variable
from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.features import Feature
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter
Expand Down Expand Up @@ -316,8 +316,8 @@ def construct_ir_fgraph(
# the old nodes to the new ones; otherwise, we won't be able to use
# `rv_values`.
# We start the `dict` with mappings from the value variables to themselves,
# to prevent them from being cloned.
memo = {v: v for v in rv_values.values()}
# to prevent them from being cloned. This also includes ancestors
memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}

# We add `ShapeFeature` because it will get rid of references to the old
# `RandomVariable`s that have been lifted; otherwise, it will be difficult
Expand Down
21 changes: 21 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import io
import itertools as it
import re

import cloudpickle
import numpy as np
Expand Down Expand Up @@ -614,3 +615,23 @@ def test_assert(self):
):
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
d1.eval()

def test_multiple_vars(self):
    """Minibatch slices several equally-sized variables with one shared
    random index, and rejects variables whose leading dimensions disagree."""
    first = np.arange(1000)
    second = np.arange(1000)
    mini_first, mini_second = pm.Minibatch(first, second, batch_size=10)

    drawn_first, drawn_second = pm.draw([mini_first, mini_second])
    assert drawn_first.shape == (10,)
    # Same index is used for every variable, so the slices must match.
    np.testing.assert_allclose(drawn_first, drawn_second)

    # Mismatched leading dimensions must fail when the graph is evaluated.
    long_var = np.arange(1000)
    short_var = np.arange(999)
    mini_long, mini_short = pm.Minibatch(long_var, short_var, batch_size=10)

    expected_msg = re.escape("All variables shape[0] in Minibatch should be equal")
    with pytest.raises(AssertionError, match=expected_msg):
        pm.draw([mini_long, mini_short])
34 changes: 34 additions & 0 deletions tests/variational/test_minibatch_rv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pymc as pm

from pymc import Normal, draw
from pymc.data import minibatch_index
from pymc.testing import select_by_precision
from pymc.variational.minibatch_rv import create_minibatch_rv
from tests.test_data import gen1, gen2
Expand Down Expand Up @@ -155,3 +156,36 @@ def test_random(self):
mx = create_minibatch_rv(x, total_size=(10,))
assert mx is not x
np.testing.assert_array_equal(draw(mx, random_seed=1), draw(x, random_seed=1))

@pytest.mark.filterwarnings("error")
def test_minibatch_parameter_and_value(self):
    """Observed values and distribution parameters can both be minibatched
    through one shared random index, and respond to ``pm.set_data`` updates."""
    random_gen = np.random.default_rng(161)
    n_rows = 1000

    with pm.Model(check_bounds=False) as model:
        observed_data = pm.MutableData("AD", np.arange(n_rows, dtype="float64"))
        param_data = pm.MutableData("TD", np.arange(n_rows, dtype="float64"))

        # One index shared by both the observed slice and the mu slice.
        shared_idx = minibatch_index(0, 10, size=(9,))
        observed_slice = observed_data[shared_idx]
        param_slice = param_data[shared_idx]

        pm.Normal(
            "AD_predicted",
            mu=param_slice,
            observed=observed_slice,
            total_size=1000,
        )

    logp_fn = model.compile_logp()

    point = model.initial_point()
    # Identical data and parameters: every residual is zero.
    np.testing.assert_allclose(logp_fn(point), st.norm.logpdf(0) * 1000)

    with model:
        # Shift the observations by one: every residual becomes one.
        pm.set_data({"AD": np.arange(n_rows) + 1})
        np.testing.assert_allclose(logp_fn(point), st.norm.logpdf(1) * 1000)

    with model:
        # With random data the logp varies across calls, because the
        # minibatch index is redrawn on every evaluation.
        pm.set_data({"AD": random_gen.normal(size=1000)})
        assert logp_fn(point) != logp_fn(point)