Allow Minibatch of derived RVs and deprecate generators as data #7480
Conversation
mA, mB = pm.Minibatch(A, B, batch_size=10)

[draw_mA, draw_mB] = pm.draw([mA, mB])
assert draw_mA.shape == (10,)
np.testing.assert_allclose(draw_mA, draw_mB)

# Check invalid dims
This was already checked in the test above
ok
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #7480      +/-   ##
==========================================
- Coverage   92.16%   92.15%   -0.02%
==========================================
  Files         103      103
  Lines       17214    17224      +10
==========================================
+ Hits        15866    15873       +7
- Misses       1348     1351       +3
mb_tensors = [tensor[mb_indices] for tensor in tensors]

# Wrap graph in OFG so it's easily identifiable and not rewritten accidentally
*mb_tensors, _ = MinibatchOp([*tensors, rng], [*mb_tensors, rng_update])(*tensors, rng)
nice trick, did not know that
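For context, here is a minimal standalone sketch of that trick in plain PyTensor (the graph and the names x, out, fn are illustrative, not from this PR): wrapping a subgraph in an OpFromGraph makes rewrites see a single opaque Op instead of its internals, while inline=True lets the compiler expand it again so there is no runtime overhead.

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
out = pt.exp(x) + 1
# Graph rewrites treat the wrapped subgraph as one opaque node;
# inline=True re-expands it when the function is compiled.
wrapped = OpFromGraph([x], [out], inline=True)(x)
fn = pytensor.function([x], wrapped)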
@@ -666,6 +672,9 @@ class GeneratorOp(Op):
    __props__ = ("generator",)

    def __init__(self, gen, default=None):
        warnings.warn(
lgtm
@@ -162,22 +159,7 @@ def fit_kwargs(inference, use_minibatch):


def test_fit_oo(inference, fit_kwargs, simple_model_data):
    # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
no more issues there?
Nope, I now allow the data to be extracted into the idata; it extracts the whole dataset.
-minibatch_idx = minibatch_index(0, 10, size=(9,))
-AD_mt = AD[minibatch_idx]
-TD_mt = TD[minibatch_idx]
+AD_mt, TD_mt = Minibatch(AD, TD, batch_size=9)
this is much cleaner
I've created an issue (#7496) to continue this work later and improve the scalability of minibatches.
Description
This PR fixes issues with Minibatch indexing reported in https://discourse.pymc.io/t/warning-using-minibatch-and-censored-together-rng-variable-has-shared-clients/14943 and extends MinibatchRV functionality to derived RVs.
Minibatch value variables are uniquely tricky because they are random graphs that can share an RNG with other variables in the forward/logp graph. As such, we need to make sure they are not mutated, or the default updates stop working. We tried some tricks in the past, but as the Discourse issue revealed, they were not enough. This PR solves the problem by encapsulating the random graph in an OpFromGraph, so that the inner graph is not touched by PyMC's logprob derivation routines. It is still inlined in the final compiled functions to avoid overhead.
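As a hedged illustration of the fixed case from the Discourse thread (the data, sizes, and bounds below are made up for the example, not taken from the report), a derived RV such as pm.Censored can now take a Minibatch value variable directly:

import numpy as np
import pymc as pm

data = np.clip(np.random.default_rng(0).normal(size=1_000), -1.0, 1.0)
mb = pm.Minibatch(data, batch_size=100)

with pm.Model():
    mu = pm.Normal("mu")
    # Censored is a derived RV; observing it on a minibatch previously
    # triggered the shared-RNG-clients warning from the Discourse issue.
    pm.Censored("obs", pm.Normal.dist(mu, 1.0), lower=-1.0, upper=1.0,
                observed=mb, total_size=1_000)
    approx = pm.fit(n=10_000)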
I also decided to deprecate generators as data, which came up during the refactoring. The GeneratorOp is not a true Op, since Ops should not have side effects. It is also incompatible with the non-default backends, like Numba and JAX, that we are moving toward. If needed, this logic should be handled by the sampler: consume the generator and set the values before subsequent function calls.
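As a rough sketch of that replacement pattern (the generator, names, and sizes here are illustrative), consume the generator outside the model and push each batch in with pm.set_data before the next update:

import numpy as np
import pymc as pm

def batches(rng=np.random.default_rng(0)):
    while True:
        yield rng.normal(size=100)

gen = batches()
with pm.Model() as model:
    batch = pm.Data("batch", next(gen))
    mu = pm.Normal("mu")
    pm.Normal("obs", mu, 1.0, observed=batch, total_size=10_000)

with model:
    for _ in range(5):
        # Caller-side loop: set the next batch before each function call.
        pm.set_data({"batch": next(gen)})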