Nested logprob rewrites proof of concept #102

Closed
wants to merge 2 commits into from

ricardoV94
Contributor

@ricardoV94 ricardoV94 commented Jan 3, 2022

This PR is a proof of concept for how we could handle composite RVs that require nested aeppl rewrites. The goal is to figure out if this approach is too limited / cumbersome and, if so, what could work better.

Roughly, most of our rewrites come in two parts:

  1. Check if an initially non-measurable variable was assigned to a value variable
  2. If so, try to convert this variable into a temporary measurable variable for which _logprob can be safely dispatched
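
The two parts above can be sketched with a toy graph model. This is only an illustration under stated assumptions: `Var`, `local_measurable_clip`, and the `"measurable_clip"` op name are hypothetical stand-ins, not aeppl classes or rewrites.

```python
# Toy stand-ins, NOT the aeppl API: every name here is hypothetical.
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so Vars can be dict keys
class Var:
    op: str                    # e.g. "normal" or "clip"
    inputs: tuple = ()
    measurable: bool = False   # RandomVariable-like nodes start out measurable


def local_measurable_clip(var, rv_values):
    """Return a measurable replacement for `var`, or None if the rewrite does not apply."""
    # Part 1: only act on a non-measurable clip that was assigned a value variable.
    if var.op != "clip" or var.measurable or var not in rv_values:
        return None
    # Part 2: the conversion is only safe if the clipped input is itself measurable.
    base = var.inputs[0]
    if not base.measurable:
        return None
    return Var("measurable_clip", var.inputs, measurable=True)
```

With `x = Var("normal", measurable=True)` and `y = Var("clip", (x,))`, the rewrite returns `None` for an empty value mapping and a measurable replacement once `y` has a value variable.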

In step 2, we often need inputs that are themselves measurable variables, but which don't have value variables assigned to them:

aeppl/aeppl/mixture.py

Lines 283 to 287 in 05d0c68

for i, component_rv in enumerate(mixture_res):
    if component_rv in rv_map_feature.rv_values:
        raise ValueError(
            f"A value variable was specified for a mixture component: {component_rv}"
        )

If these inputs are RandomVariables, things are fine, as those are measurable by default. Otherwise, we might still be dealing with variables that could be made measurable, but we will never know, because they don't (and shouldn't) have value variables, so their corresponding rewrites are never triggered due to step 1.

The illustrating example in this PR is a scalar mixture that has a clipped variable as one of its components:

    import aesara.tensor as at
    from aeppl import joint_logprob

    x = at.clip(at.random.normal(loc=1), 0.5, 1.5); x.name = "x"
    y = at.random.beta(1, 2, size=None, name="y")

    comps = at.stack(x, y); comps.name = "comps"
    idxs = at.random.bernoulli(0.4, size=None, name="idxs")
    mix = comps[idxs]; mix.name = "mix"

    mix_vv = mix.clone()
    idxs_vv = idxs.clone()
    logp = joint_logprob({idxs: idxs_vv, mix: mix_vv})

I have also added a test example with nested scalar mixtures.
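
For reference, here is a plain-Python sketch of the value the joint log-probability of the clipped-normal / beta mixture should take. It assumes that clipping is treated as censoring (point masses at the two bounds) and that index 0 selects the clipped normal and index 1 the beta component. All function names are illustrative and none of this is aeppl code.

```python
import math


def normal_logpdf(v, loc, scale):
    z = (v - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def normal_cdf(v, loc, scale):
    return 0.5 * (1.0 + math.erf((v - loc) / (scale * math.sqrt(2))))


def clipped_normal_logp(v, loc=1.0, scale=1.0, lower=0.5, upper=1.5):
    """Log-probability of clip(Normal(loc, scale), lower, upper) at v (censoring assumption)."""
    if v == lower:
        return math.log(normal_cdf(lower, loc, scale))        # mass collected at the lower bound
    if v == upper:
        return math.log(1.0 - normal_cdf(upper, loc, scale))  # mass collected at the upper bound
    if lower < v < upper:
        return normal_logpdf(v, loc, scale)
    return -math.inf


def beta_1_2_logp(v):
    """Log-density of Beta(1, 2), whose density is 2 * (1 - v) on [0, 1]."""
    return math.log(2.0 * (1.0 - v)) if 0.0 <= v < 1.0 else -math.inf


def mixture_joint_logp(idx, v, p=0.4):
    """log p(idxs) + log p(mix | idxs) for the two-component scalar mixture."""
    idx_logp = math.log(p) if idx == 1 else math.log(1.0 - p)
    comp_logp = clipped_normal_logp(v) if idx == 0 else beta_1_2_logp(v)
    return idx_logp + comp_logp
```

The point of the example is that the second term requires a log-probability for the clipped component even though that component never receives a value variable of its own.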

Suggestion

The suggestion is that any rewrite that depends on other inputs being measurable can assign a temporary value variable, UPSTREAM_VALUE, to those inputs. This placeholder is automatically discarded by PreserveRVMappings when such variables are converted to measurable ones by their own rewrites.
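
A minimal sketch of that bookkeeping, assuming a dictionary-backed mapping. Only the name `UPSTREAM_VALUE` comes from this PR; `ToyRVMappings` and its methods are hypothetical and do not reproduce the real `PreserveRVMappings` feature.

```python
# Hypothetical toy model of the idea; variables are represented by plain strings.
UPSTREAM_VALUE = object()  # sentinel: "valued only so that a downstream rewrite can proceed"


class ToyRVMappings:
    """Minimal stand-in for a value-variable mapping (not aeppl's PreserveRVMappings)."""

    def __init__(self, rv_values):
        self.rv_values = dict(rv_values)

    def request_measurable(self, var):
        # A downstream rewrite flags an unvalued input so that its own rewrite can fire.
        # An existing (real) value variable is never overwritten.
        self.rv_values.setdefault(var, UPSTREAM_VALUE)

    def on_replace(self, old, new):
        value = self.rv_values.pop(old, None)
        if value is None or value is UPSTREAM_VALUE:
            return  # the placeholder is dropped once the input has become measurable
        self.rv_values[new] = value  # real value variables follow the replacement
```

In this sketch a mixture rewrite would call `request_measurable` on each component, and the component's own rewrite would trigger `on_replace`, which discards the placeholder.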

Limitations

This would work for some current and future rewrites like #26, but not for everything.

For example, nested rewrites that depend on the direct manipulation / creation of value variables, such as inc_subtensor and scans, would not work. This would also not work when the _logprob dispatched function depends on having RandomVariables as inputs, as happens with non-scalar mixtures:

rv_m = rv_pull_down(rv[m_indices])

Also, depending on the order in which rewrites are called, this may not always work, because the assignment of the UPSTREAM_VALUE is not a visible graph change, and I guess the Equilibrium optimization will stop if, after passing through all nodes, none has changed.
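
The ordering concern can be reproduced with a toy fixed-point loop. This is a simplified sketch, not Aesara's equilibrium optimizer: nodes are plain strings, and the `flagged` set plays the role of the invisible UPSTREAM_VALUE assignment.

```python
def run_to_equilibrium(nodes, rewrites, max_passes=10):
    """Apply rewrites until a full pass changes no node; return the number of passes run."""
    for n_pass in range(1, max_passes + 1):
        changed = False
        for i, node in enumerate(nodes):
            for rewrite in rewrites:
                new_node = rewrite(node)
                if new_node is not None:  # only a returned replacement counts as a change
                    nodes[i] = node = new_node
                    changed = True
        if not changed:
            return n_pass
    return max_passes


flagged = set()  # side channel that the loop cannot see


def mixture_rewrite(node):
    if node == "mixture":
        flagged.add("clip")  # side effect only: no replacement is returned
    return None


def clip_rewrite(node):
    if node == "clip" and "clip" in flagged:
        return "measurable_clip"
    return None


# The clip node is visited before the mixture flags it, and the pass reports no change.
nodes = ["clip", "mixture"]
passes = run_to_equilibrium(nodes, [mixture_rewrite, clip_rewrite])
```

Here the loop stops after a single pass with the clip node still unconverted, whereas visiting the mixture first would have converted it.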

@ricardoV94 ricardoV94 added the enhancement and request discussion labels Jan 3, 2022
@codecov

codecov bot commented Jan 3, 2022

Codecov Report

Merging #102 (82c8bff) into main (73522bd) will increase coverage by 0.02%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #102      +/-   ##
==========================================
+ Coverage   93.89%   93.92%   +0.02%     
==========================================
  Files          10       10              
  Lines        1426     1432       +6     
  Branches      210      212       +2     
==========================================
+ Hits         1339     1345       +6     
  Misses         50       50              
  Partials       37       37              
Impacted Files      Coverage Δ
aeppl/mixture.py    97.35% <100.00%> (+0.05%) ⬆️
aeppl/opt.py        77.01% <100.00%> (+0.82%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@ricardoV94 ricardoV94 force-pushed the composite_logprob_poc branch from 05d0c68 to 0595a76 Compare January 3, 2022 11:46
@brandonwillard
Member

Roughly, most of our rewrites can be divided in two parts:

  1. Check if an initially non-measurable variable was assigned to a value variable
  2. If so, try to convert this variable into a temporary measurable variable for which _logprob can be safely dispatched

It's better to think of this process as first producing a "measurable" intermediate representation (IR) of the original model graph. The fact that the value variable checking/manipulation is also done at that point is mostly a matter of convenience.

This distinction is important, because there may be use/need for the identification of measurable terms independent of value variables, to which it seems you're ultimately alluding.

It's possible that the entire scenario you're describing could be handled by simply loosening the overly strict value variable requirements instead of adding a new placeholder element and associated logic.

In other words, why work around/within existing restrictions when they might be unnecessary in the first place? If we don't consider these kinds of things first and foremost, we'll likely end up building something that becomes more and more difficult to develop.

@brandonwillard
Member

brandonwillard commented Jan 3, 2022

Also, depending on the order in which rewrites are called, this may not always work, because the assignment of the UPSTREAM_VALUE is not a visible graph change, and I guess the Equilibrium optimization will stop if, after passing through all nodes, none has changed.

#78 is an answer (of sorts) to this. The idea is to put value variables in the graph; then we don't need a Feature, and the rewrites we're using make a lot more sense at a higher level, because they're operating directly upon a (measurable variable, observed value) pair.

@ricardoV94 ricardoV94 changed the title Composite logprob proof of concept Nested logprob proof of concept Jan 3, 2022
@brandonwillard brandonwillard marked this pull request as draft January 3, 2022 18:33
@ricardoV94
Contributor Author

It's possible that the entire scenario you're describing could be handled by simply loosening the overly strict value variable requirements instead of adding a new placeholder element and associated logic.

In other words, why work around/within existing restrictions when they might be unnecessary in the first place? If we don't consider these kinds of things first and foremost, we'll likely end up building something that becomes more and more difficult to develop.

That was my other alternative, just eagerly convert / flag all nodes that can be made measurable as such, but it seemed like this would be a bit insane/expensive the more I thought about it.

@ricardoV94 ricardoV94 changed the title Nested logprob proof of concept Nested logprob rewrites proof of concept Jan 3, 2022
@brandonwillard
Member

That was my other alternative, just eagerly convert / flag all nodes that can be made measurable as such, but it seemed like this would be a bit insane/expensive the more I thought about it.

What made you think that?

@ricardoV94
Contributor Author

In my mind, it meant we would end up replacing almost all nodes in a graph with equivalent "measurable" ones, especially as we increase the "coverage" for the basic Aesara Ops.

@brandonwillard
Member

In my mind, it meant we would end up replacing almost all nodes in a graph with equivalent "measurable" ones, especially as we increase the "coverage" for the basic Aesara Ops.

Yes, how else would we do these kinds of things? In other words, if a mixture needs to consist of Joined MeasurableVariables, how are we going to avoid determining whether or not the mixture components are MeasurableVariables in the first place? Since the latter requires that the IR rewrites be performed first, there's really no avoiding a broad application of them throughout the sub-graphs within the "valued" mixture node.

The value variables are used to determine the subgraphs to which we apply IR rewrites in the first place (i.e. they determine the conditional log-probabilities), and those subgraphs will be the outputs of the FunctionGraph being operated upon, so we already have a means of applying rewrites to only the relevant sub-graphs.

One could make a distinction between taking a bottom-up and top-down approach, but I don't think there's a meaningful difference between the two in this scenario. We already know that the "bottom"/outputs of the FunctionGraphs need to be measurable—otherwise they wouldn't be there—and, in all "composite" node cases (e.g. Mixture and Scan) I can imagine right now, we still need the rewrites to make specific sub-graphs measurable.

Even so, we always have control over how the rewrites traverse the graphs, so there's nothing preventing us from applying them more intelligently at a high level and avoiding unnecessary applications.

@ricardoV94
Contributor Author

ricardoV94 commented Jan 4, 2022

The value variables are used to determine the subgraphs to which we apply IR rewrites in the first place (i.e. they determine the conditional log-probabilities), and those subgraphs will be the outputs of the FunctionGraph being operated upon, so we already have a means of applying rewrites to only the relevant sub-graphs.

I don't quite see how we distinguish between the following two cases (if we have to):

x = at.random.normal()
y = at.clip(x, 0, 1)
z = at.random.normal(y)

x_vv = x.clone()
y_vv = y.clone()
z_vv = z.clone()

# not measuring clip node
joint_logprob({x: x_vv, z: z_vv})

# measuring the clip node
joint_logprob({y: y_vv, z: z_vv})

In both cases y is part of the FunctionGraph but only in the second do we need the rewrite.

Currently we distinguish between the two based on whether it has a value variable or not, but if I understand you correctly, we would now always rewrite a clip so that it would work in nested expressions like the mixture example above.


Alternatively, there is already some information telling us that we don't want to convert it in the first case: the measurement input from which the logprob derives, x, is already valued / accounted for. But this seems like a similar type of local information passing, just in the other direction. Also, this currently conflicts with a third case (related to #85):

# random logprob based on x and a deterministic clipping
joint_logprob({z: z_vv})

@brandonwillard
Member

brandonwillard commented Jan 5, 2022

Currently we distinguish between the two based on whether it has a value variable or not, but if I understand you correctly, we would now always rewrite a clip so that it would work in nested expressions like the mixture example above.

Yes, we're currently applying (IR) normalizing/canonicalizing rewrites to the entire graph, even if there are some "upstream" terms that may be irrelevant. This is the case in the former example, where the measurability of the y term isn't relevant and the rewrites for clip aren't needed.

That's done because it's (generally) cheap and it provides a simplicity to our rewrites and the way in which they're applied. It's basically a "greedy" approach.

As I mentioned above, we can always control the graph traversal and application of rewrites so that certain rewrites are only applied upstream from a valued variable when they're needed; however, complications arise in this scenario due to an added dependency between the traversal logic and the rewrite rules themselves.

For instance, starting at the "bottom"-most/valued nodes, in the former case, we would start at the expression graph for z and immediately determine that it's measurable and that no further traversal + rewrites are necessary. That's the easy case.

In the latter case, we would start at y and determine that it's only measurable if x is measurable, so we need to walk up the inputs of y and potentially rewrite those so that they're measurable. In this example, no rewrites are necessary, but the general logic still requires that the rewrites be applied by the same logic that determines whether or not something is measurable. This is the dependency that conditional rewrite applications introduce.

Taking the above literally and creating a routine based simply on walking the graph and manually applying rewrites would result in a very convoluted and error-prone interface/development context. If we want to take a more efficient approach, we need to carefully consider how best to handle this dependency without foregoing a lot of the existing graph and rewrite tooling.

There may be a way to do this using "preconditioning" passes that truncate graphs via their naive "base" measurability. For example, we could walk up graphs and determine the relevant subgraphs using isinstance(..., MeasurableVariable) (i.e. the atomic/"base" measurable nodes):

I = at.random.bernoulli(0.5)
C = at.random.normal(0, 1)
A = at.clip(C, -1, 1)
B = at.random.normal(0, 1)
R = at.ifelse(I, A, B)
X = at.random.normal(R)
Y = at.clip(X, 0, 1)

# Traversal/subgraph should stop at `Y -> X`
joint_logprob({Y: y_vv})

# Traversal/subgraph should stop at `X`
joint_logprob({X: x_vv})

# Traversal/subgraph should stop at `R -> [I, A -> C, B]` (i.e. the immediately
# measurable `I`, `C`, `B`)
joint_logprob({R: r_vv})

We could also make it possible to determine the non-measurable Ops for which we have rewrites and use that to better trim the subgraphs.

Anyway, the point is that this approach produces subgraphs for which broad application of the rewrites is generally necessary (i.e. any rewrite that can be applied needs to be applied), but no changes to the rewrites or the way they're applied are needed.
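
The truncation described above can be sketched on a toy dictionary graph that mirrors the `I`, `C`, `A`, `B`, `R`, `X`, `Y` example. The membership test in `BASE_MEASURABLE_OPS` is an assumed stand-in for the `isinstance(..., MeasurableVariable)` check, and `relevant_subgraph` is a hypothetical helper, not an existing aeppl function.

```python
# Toy graph mirroring the example above: each variable maps to (op name, input names).
GRAPH = {
    "I": ("bernoulli", []),
    "C": ("normal", []),
    "A": ("clip", ["C"]),
    "B": ("normal", []),
    "R": ("ifelse", ["I", "A", "B"]),
    "X": ("normal", ["R"]),
    "Y": ("clip", ["X"]),
}
# Stand-in for an isinstance(..., MeasurableVariable) check on the node's Op.
BASE_MEASURABLE_OPS = {"bernoulli", "normal"}


def relevant_subgraph(valued):
    """Walk up from a valued variable, stopping at the first base-measurable nodes."""
    seen = set()
    stack = [valued]
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        op, inputs = GRAPH[name]
        if op not in BASE_MEASURABLE_OPS:
            stack.extend(inputs)  # keep walking only through non-measurable nodes
    return seen
```

Rewrites would then be applied broadly, but only inside the returned set of nodes, so the rewrites themselves would not need to change.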

@brandonwillard brandonwillard linked an issue Mar 19, 2022 that may be closed by this pull request
Member

@brandonwillard brandonwillard left a comment

An approach similar to this could/should be implemented using a custom Rewriter class. The changes in #133 should make that easier to do, now that the relevant rewrites are a little more centralized.

The idea is that we could put the nested/composite-needing IR rewrites into their own DB that uses the special Rewriter. We would still need a way to register specific composite IR Ops and which of their inputs need to be assigned temporary value variables and/or traversed, but it could all be generalized in this context.
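
One possible shape for such a registration mechanism is a plain dictionary of selector functions. This is purely a sketch: the registry, the decorator, and the assumed input layout of the mixture (index variable first, components after) are all hypothetical and not part of aeppl or Aesara.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical registry: composite op name -> function selecting the inputs that
# must become measurable (and would therefore receive temporary value variables).
COMPOSITE_INPUT_SELECTORS: Dict[str, Callable[[Sequence[str]], List[str]]] = {}


def register_composite(op_name):
    """Decorator that registers an input selector for a composite op."""
    def decorator(selector):
        COMPOSITE_INPUT_SELECTORS[op_name] = selector
        return selector
    return decorator


@register_composite("mixture")
def mixture_components(inputs):
    # Assumed layout: the first input is the index variable, the rest are components.
    return list(inputs[1:])


def inputs_needing_values(op_name, inputs):
    """Return the inputs a special rewriter would need to traverse for this op."""
    selector = COMPOSITE_INPUT_SELECTORS.get(op_name)
    return selector(inputs) if selector is not None else []
```

A custom rewriter could consult `inputs_needing_values` for each composite node it visits and leave every unregistered op untouched.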

Regardless, we can just as easily do this after implementing a quicker solution like #129, so we can close this for now and focus on that.

Labels
enhancement, request discussion
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow for nested logprob rewrites
2 participants