-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow Minibatch of derived RVs and deprecate generators as data #7480
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -156,19 +156,25 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: | |
TypeError | ||
|
||
""" | ||
# TODO: These data functions should be in data.py or model/core.py | ||
from pymc.data import MinibatchOp | ||
|
||
if isinstance(x, Constant): | ||
return x.data | ||
if isinstance(x, SharedVariable): | ||
return x.get_value() | ||
if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast): | ||
array_data = extract_obs_data(x.owner.inputs[0]) | ||
return array_data.astype(x.type.dtype) | ||
if x.owner and isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1): | ||
array_data = extract_obs_data(x.owner.inputs[0]) | ||
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:]) | ||
mask = np.zeros_like(array_data) | ||
mask[mask_idx] = 1 | ||
return np.ma.MaskedArray(array_data, mask) | ||
if x.owner is not None: | ||
if isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast): | ||
array_data = extract_obs_data(x.owner.inputs[0]) | ||
return array_data.astype(x.type.dtype) | ||
if isinstance(x.owner.op, MinibatchOp): | ||
return extract_obs_data(x.owner.inputs[x.owner.outputs.index(x)]) | ||
if isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1): | ||
array_data = extract_obs_data(x.owner.inputs[0]) | ||
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:]) | ||
mask = np.zeros_like(array_data) | ||
mask[mask_idx] = 1 | ||
return np.ma.MaskedArray(array_data, mask) | ||
|
||
raise TypeError(f"Data cannot be extracted from {x}") | ||
|
||
|
@@ -666,6 +672,9 @@ class GeneratorOp(Op): | |
__props__ = ("generator",) | ||
|
||
def __init__(self, gen, default=None): | ||
warnings.warn( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lgtm |
||
"generator data is deprecated and will be removed in a future release", FutureWarning | ||
) | ||
from pymc.data import GeneratorAdapter | ||
|
||
super().__init__() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,6 @@ | |
|
||
import io | ||
import itertools as it | ||
import re | ||
|
||
from os import path | ||
|
||
|
@@ -29,7 +28,7 @@ | |
|
||
import pymc as pm | ||
|
||
from pymc.data import is_minibatch | ||
from pymc.data import MinibatchOp | ||
from pymc.pytensorf import GeneratorOp, floatX | ||
|
||
|
||
|
@@ -593,44 +592,34 @@ class TestMinibatch: | |
|
||
def test_1d(self): | ||
mb = pm.Minibatch(self.data, batch_size=20) | ||
assert is_minibatch(mb) | ||
assert mb.eval().shape == (20, 10) | ||
assert isinstance(mb.owner.op, MinibatchOp) | ||
draw1, draw2 = pm.draw(mb, draws=2) | ||
assert draw1.shape == (20, 10) | ||
assert draw2.shape == (20, 10) | ||
assert not np.all(draw1 == draw2) | ||
|
||
def test_allowed(self): | ||
mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20) | ||
assert is_minibatch(mb) | ||
assert isinstance(mb.owner.op, MinibatchOp) | ||
|
||
def test_not_allowed(self): | ||
with pytest.raises(ValueError, match="not valid for Minibatch"): | ||
mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20) | ||
pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20) | ||
|
||
def test_not_allowed2(self): | ||
with pytest.raises(ValueError, match="not valid for Minibatch"): | ||
mb = pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20) | ||
pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20) | ||
|
||
def test_assert(self): | ||
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20) | ||
with pytest.raises( | ||
AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal" | ||
): | ||
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20) | ||
d1.eval() | ||
|
||
def test_multiple_vars(self): | ||
A = np.arange(1000) | ||
B = np.arange(1000) | ||
B = -np.arange(1000) | ||
mA, mB = pm.Minibatch(A, B, batch_size=10) | ||
|
||
[draw_mA, draw_mB] = pm.draw([mA, mB]) | ||
assert draw_mA.shape == (10,) | ||
np.testing.assert_allclose(draw_mA, draw_mB) | ||
|
||
# Check invalid dims | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was already checked in the test above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
A = np.arange(1000) | ||
C = np.arange(999) | ||
mA, mC = pm.Minibatch(A, C, batch_size=10) | ||
|
||
with pytest.raises( | ||
AssertionError, | ||
match=re.escape("All variables shape[0] in Minibatch should be equal"), | ||
): | ||
pm.draw([mA, mC]) | ||
np.testing.assert_allclose(draw_mA, -draw_mB) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice trick, did not know that