
Track value variables when lifting BroadcastTo Ops #121

Closed

Conversation

brandonwillard (Member) commented Feb 8, 2022

This PR is an alternative to #116 that enables value-variable tracking for `BroadcastTo` lifting rewrites. This approach allows one to specify a value for the broadcast of a random variable that is itself valued, whereas globally disabling the rewrite for valued terms would not.

The latter is most easily demonstrated by the following example:

```python
import aesara.tensor as at

from aeppl import factorized_joint_logprob


X_rv = at.random.normal(name="X")
Z_rv = at.broadcast_to(X_rv, (2, 2))

x_vv = X_rv.clone()
z_vv = Z_rv.clone()

logp_map = factorized_joint_logprob({X_rv: x_vv, Z_rv: z_vv})
```
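As a quick numerical reference point (a sketch computed directly with NumPy rather than through aeppl), the log-probability contributed by `X_rv` at a value of 0 should match the standard normal log-density:

```python
import numpy as np

# Standard normal log-density, written out by hand; aeppl's logprob for
# at.random.normal should agree with this at any given value.
def normal_logpdf(x, mu=0.0, sigma=1.0):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

print(normal_logpdf(0.0))  # ≈ -0.9189385, i.e., -0.5 * log(2 * pi)
```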

Closes #115

Also, this PR further demonstrates the relevance of #78, because it explicitly makes use of the fact that the lifting operation doesn't only apply to the valued variable, but also the variable's value.

@brandonwillard brandonwillard self-assigned this Feb 8, 2022
@brandonwillard brandonwillard added the labels "bug", "enhancement", "important", and "graph rewriting" on Feb 9, 2022
@brandonwillard brandonwillard force-pushed the naive_bcast_lift_bug branch 2 times, most recently from 459dea1 to 715b9dd Compare February 9, 2022 00:26
codecov bot commented Feb 9, 2022

Codecov Report

Merging #121 (8721050) into main (2019114) will increase coverage by 0.61%.
The diff coverage is 87.50%.


```
@@            Coverage Diff             @@
##             main     #121      +/-   ##
==========================================
+ Coverage   94.04%   94.66%   +0.61%
==========================================
  Files          11       11
  Lines        1529     1536       +7
  Branches      217      218       +1
==========================================
+ Hits         1438     1454      +16
+ Misses         53       40      -13
- Partials       38       42       +4
```

| Impacted Files | Coverage Δ |
| --- | --- |
| aeppl/opt.py | 87.17% <87.50%> (+8.99%) ⬆️ |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Review comment on the test diff:

```python
    logp_map[y_vv].eval({x_vv: 0, y_vv: y_val}), st.norm(0).logpdf(y_val)
)

# Lifting should also work when `BroadcastTo`s are directly assigned value
```
ricardoV94 (Contributor) commented Feb 9, 2022

I don't think this second example below should be valid. We are measuring the same thing twice (the pre-broadcast and the post-broadcast value variables).

Besides, the joint_logprob would have been fine with `{x_vv: 0, z_vv: [[1, 1], [1, 1]]}`, which does not match the original graph.
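The mismatch being described can be made concrete with a small NumPy sketch (illustrative only; `np.broadcast_to` stands in for the symbolic `at.broadcast_to`):

```python
import numpy as np

# Under the generative graph, Z is exactly X broadcast to (2, 2), so any
# z value that is not a broadcast of the x value is impossible under that
# graph, even though it can still be supplied to the factorized logprob.
x_val = np.array(0.0)
z_val = np.ones((2, 2))

consistent = np.array_equal(z_val, np.broadcast_to(x_val, (2, 2)))
print(consistent)  # False
```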

ricardoV94 (Contributor) commented Feb 9, 2022

If only one of the pre- and post-broadcast variables is valued, then #116 already worked.

... although the post-broadcast case is still a bit iffy and depends on what we mean by joint-logprob. Strictly speaking, the original generative graph of `z_rv` would correspond to a logp of something like:

```python
def broadcastTo_logp(value, original_shape, dist_op, dist_params):
    # e.g., undo_broadcast((3, 3, 3), (1,)) -> (3,)
    pre_broadcast_value = undo_broadcast(value, original_shape)

    return switch(
        at.eq(value, at.broadcast_to(pre_broadcast_value, value.shape)),
        logprob(dist_op, pre_broadcast_value, *dist_params),
        -inf,
    )
```

But that's a separate question from the comment above
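For what it's worth, the hypothetical `undo_broadcast` helper in the pseudocode above could look something like this NumPy sketch (the name and signature are only illustrative, not part of aeppl):

```python
import numpy as np

# A minimal sketch of "undoing" a broadcast: dimensions that broadcasting
# added can be dropped, and dimensions it expanded from size 1 can be
# collapsed back to a single slice.
def undo_broadcast(value, original_shape):
    # Drop leading dimensions added by broadcasting...
    value = value[(0,) * (value.ndim - len(original_shape))]
    # ...then collapse dimensions that were expanded from size 1.
    idx = tuple(
        slice(None) if s_orig == s_val else slice(0, 1)
        for s_orig, s_val in zip(original_shape, value.shape)
    )
    return value[idx]

x = np.array([1.0, 2.0, 3.0])
z = np.broadcast_to(x, (2, 3))
print(undo_broadcast(z, (3,)))  # [1. 2. 3.]
```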

brandonwillard (Member, author) commented

> If only one of the pre- and post-broadcast variables is valued, then #116 already worked.

Only if it broadcasted something that wasn't valued, but that's the limitation I've demonstrated.

> ... although the post-broadcast case is still a bit iffy and depends on what we mean by joint-logprob. Strictly speaking, the original generative graph of `z_rv` would correspond to a logp of something like:
>
> ```python
> def broadcastTo_logp(value, original_shape, dist_op, dist_params):
>     # e.g., undo_broadcast((3, 3, 3), (1,)) -> (3,)
>     pre_broadcast_value = undo_broadcast(value, original_shape)
>
>     return switch(
>         at.eq(value, at.broadcast_to(pre_broadcast_value, value.shape)),
>         logprob(dist_op, pre_broadcast_value, *dist_params),
>         -inf,
>     )
> ```
>
> But that's a separate question from the comment above

You're going to need to clarify this, because, right now, I don't see the relevance of this function. Can you demonstrate the underlying issue by deriving an incorrect log-probability calculation under these changes?

brandonwillard (Member, author) commented Feb 9, 2022

> I don't think this second example below should be valid. We are measuring the same thing twice (the pre-broadcast and the post-broadcast value variables).

Are the log-probability values incorrect, or do you simply think people shouldn't be allowed to do this? Is there some other inconsistency implied by this that you can demonstrate?

Regardless, it's a minimal example, so there may be other cases that involve performing the same rewrite on a valued variable in order to derive a log-probability for another valued variable (e.g., mixtures). Can you guarantee that this isn't possible, so that we can justify the limitations you're proposing?

> Besides, the joint_logprob would have been fine with saying `{x_vv: 0, z_vv: [[1, 1], [1, 1]]}`, which does not match the original graph

What exactly doesn't match the original graph, and how is that relevant?

```python
if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
    val_var = rv_map_feature.rv_values[rv_var]
    new_val_var = at.broadcast_to(val_var, tuple(bcast_shape))
    rv_map_feature.rv_values[new_bcast_out] = new_val_var
```
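For intuition, here is roughly what this does to the value variable, sketched with NumPy in place of the symbolic `at.broadcast_to` (the names are illustrative, not aeppl's):

```python
import numpy as np

# The rewrite broadcasts the pre-lift value variable to the shape of the
# new broadcast output, so the broadcast RV's value stays tied to the
# original value rather than becoming an unconstrained (2, 2) array.
val = np.array(0.5)  # value of the scalar RV
bcast_shape = (2, 2)

new_val = np.broadcast_to(val, bcast_shape)
print(new_val.shape)          # (2, 2)
print(np.all(new_val == val))  # True: every element is the original value
```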
ricardoV94 (Contributor) commented Feb 9, 2022

Won't this add an extra logp term? If I am parsing this correctly, calling joint_logp on your first new test case would now sum 3 terms corresponding to (x_vv, new_value_var, y_vv), no?

brandonwillard (Member, author) commented

> Won't this add an extra logp term? If I am parsing this correctly, calling joint_logp on your first new test case would now sum 3 terms corresponding to (x_vv, new_value_var, y_vv), no?

Have you tried it? We can always check the sum in the tests.

brandonwillard (Member, author) commented

The first test case might actually add an extra term; I'll try it out.

brandonwillard (Member, author) commented Feb 9, 2022

We would need to do something about the way value variables are used in the factorized_joint_logprob logic in order to remove the extra term @ricardoV94 mentioned, but this is essentially what #78 is trying to do.

Aside from that, we can have the restriction of not computing log-probability graphs for transformed valued variables that already appear in the same call to *joint_logprob, but we should do something about how it currently handles that scenario (i.e. simply ignoring the transformed graph isn't great).

ricardoV94 (Contributor) commented

That seems somewhat related to #119 as well, btw.

@brandonwillard brandonwillard deleted the naive_bcast_lift_bug branch February 9, 2022 18:53
Successfully merging this pull request may close these issues.

Broadcasted variable not replaced in downstream terms
2 participants