Track value variables when lifting `BroadcastTo` `Op`s #121
Codecov Report
@@            Coverage Diff             @@
##             main     #121      +/-   ##
==========================================
+ Coverage   94.04%   94.66%   +0.61%
==========================================
  Files          11       11
  Lines        1529     1536       +7
  Branches      217      218       +1
==========================================
+ Hits         1438     1454      +16
+ Misses         53       40      -13
- Partials       38       42       +4
    logp_map[y_vv].eval({x_vv: 0, y_vv: y_val}), st.norm(0).logpdf(y_val)
)

# Lifting should also work when `BroadcastTo`s are directly assigned value
I don't think this second example below should be valid. We are measuring the same thing twice (the pre-broadcast and post-broadcast value variables).
Besides, `joint_logprob` would have been fine with `{x_vv: 0, z_vv: [[1, 1], [1, 1]]}`, which does not match the original graph.
If only one of the pre- and post-broadcast variables is valued, then #116 already worked.
... although the post-broadcast case is still a bit iffy and depends on what we mean by joint-logprob. Strictly speaking, the original generative graph of `z_rv` would correspond to a logp of something like:
def broadcastTo_logp(value, original_shape, dist_op, dist_params):
    # e.g., undo_broadcast((3, 3, 3), (1,)) -> (3,)
    pre_broadcast_value = undo_broadcast(value, original_shape)
    return switch(
        at.eq(value, at.broadcast_to(pre_broadcast_value, value.shape)),
        logprob(dist_op, pre_broadcast_value, *dist_params),
        -inf,
    )
But that's a separate question from the comment above
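The pseudocode above leaves `undo_broadcast`, `switch`, and `logprob` undefined. As a rough illustration of the same idea, the following pure-Python sketch covers only the simplest case: a vector of independent normal draws repeated as the rows of a matrix. The helper names `normal_logpdf` and `broadcast_rows_logp` are hypothetical and are not part of the library under discussion. The independence and normality of the pre-broadcast entries are assumptions made for the example.

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of a univariate normal distribution."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def broadcast_rows_logp(value, mu=0.0, sigma=1.0):
    """Log-density of a matrix built by repeating one normal vector as rows.

    Assumption: `value` is a non-empty list of equal-length rows, and the
    pre-broadcast vector has independent Normal(mu, sigma) entries.
    """
    # "Undo" the broadcast by taking a single row as the pre-broadcast value.
    pre_broadcast_value = value[0]
    # The value is only reachable if every row equals the pre-broadcast vector.
    if any(row != pre_broadcast_value for row in value):
        return -math.inf
    # Only the pre-broadcast entries contribute to the log-density.
    return sum(normal_logpdf(v, mu, sigma) for v in pre_broadcast_value)


consistent = [[0.5, -1.0], [0.5, -1.0]]
inconsistent = [[0.5, -1.0], [0.5, 2.0]]

print(broadcast_rows_logp(consistent))    # finite: log-density of [0.5, -1.0]
print(broadcast_rows_logp(inconsistent))  # -inf: the rows disagree
```

Under these assumptions, a consistent value contributes the log-density of the pre-broadcast vector exactly once, no matter how many rows the broadcast produced.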
> If only one of the pre- and post-broadcast variables is valued, then #116 already worked.

Only if it broadcasted something that wasn't valued, but that's the limitation I've demonstrated.

> ... although the post-broadcast case is still a bit iffy and depends on what we mean by joint-logprob. Strictly speaking, the original generative graph of `z_rv` would correspond to a logp of something like:
>
>     def broadcastTo_logp(value, original_shape, dist_op, dist_params):
>         # e.g., undo_broadcast((3, 3, 3), (1,)) -> (3,)
>         pre_broadcast_value = undo_broadcast(value, original_shape)
>         return switch(
>             at.eq(value, at.broadcast_to(pre_broadcast_value, value.shape)),
>             logprob(dist_op, pre_broadcast_value, *dist_params),
>             -inf,
>         )
>
> But that's a separate question from the comment above

You're going to need to clarify this, because, right now, I don't see the relevance of this function. Can you demonstrate the underlying issue by deriving an incorrect log-probability calculation under these changes?
> I don't think this second example below should be valid. We are measuring the same thing twice (the pre-broadcast and post-broadcast value variables).

Are the log-probability values incorrect, or do you simply think people shouldn't be allowed to do this? Is there some other inconsistency implied by this that you can demonstrate?

Regardless, it's a minimal example, so there may be other cases that involve performing the same rewrite on a valued variable in order to derive a log-probability for another valued variable (e.g., mixtures). Can you guarantee that this isn't possible, so that we can justify the limitations you're proposing?

> Besides, `joint_logprob` would have been fine with `{x_vv: 0, z_vv: [[1, 1], [1, 1]]}`, which does not match the original graph.

What exactly doesn't match the original graph, and how is that relevant?
if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
    val_var = rv_map_feature.rv_values[rv_var]
    new_val_var = at.broadcast_to(val_var, tuple(bcast_shape))
    rv_map_feature.rv_values[new_bcast_out] = new_val_var
Won't this add an extra logp term? If I am parsing this correctly, calling `joint_logp` on your first new test case would now sum 3 terms corresponding to `(x_vv, new_value_var, y_vv)`, no?
> Won't this add an extra logp term? If I am parsing this correctly, calling `joint_logp` on your first new test case would now sum 3 terms corresponding to `(x_vv, new_value_var, y_vv)`, no?

Have you tried it? We can always check the sum in the tests.
The first test case might actually add an extra term; I'll try it out.
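One way to check the sum in a test is to compare the total of the per-variable terms against a reference computed by hand. The sketch below does not call the library. It assumes a hypothetical model in which a scalar `x` is standard normal, `x` is broadcast to three copies, and each `y[i]` is normal with mean `x`. The dictionaries merely stand in for whatever log-probability map the library returns, and the `new_val_var` entry imitates the extra term raised in the review.

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of a univariate normal distribution."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def total_logp(terms):
    """Sum the per-variable log-probability terms."""
    return sum(terms.values())


x_val = 0.0
y_val = [0.3, -0.7, 1.2]

# Reference joint log-probability: one term for x, one term for y given x.
expected = normal_logpdf(x_val) + sum(normal_logpdf(y, mu=x_val) for y in y_val)

# Hypothetical per-variable terms with only the two original value variables.
two_terms = {
    "x_vv": normal_logpdf(x_val),
    "y_vv": sum(normal_logpdf(y, mu=x_val) for y in y_val),
}

# The same map with an additional entry for the broadcasted value of x,
# which counts the density of x once more per broadcasted element.
three_terms = dict(two_terms)
three_terms["new_val_var"] = sum(normal_logpdf(x_val) for _ in y_val)

print(math.isclose(total_logp(two_terms), expected))    # True
print(math.isclose(total_logp(three_terms), expected))  # False
```

A test of this shape would fail as soon as a third term is summed, which makes the double-counting question directly checkable.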
We would need to do something about the way value variables are used in the Aside from that, we can have the restriction of not computing log-probability graphs for transformed valued variables that already appear in the same call to
That seems somewhat related to #119 as well btw
This PR is an alternative to #116 that enables value-variable tracking for `BroadcastTo` lifting rewrites. This approach allows one to specify values for broadcasted random variables that are also valued, whereas globally disabling the rewrite for valued terms would not. The latter is most easily demonstrated by the following example:
Closes #115
Also, this PR further demonstrates the relevance of #78, because it explicitly makes use of the fact that the lifting operation doesn't only apply to the valued variable, but also the variable's value.