Nested logprob rewrites proof of concept #102
Conversation
Codecov Report

```
@@           Coverage Diff            @@
##             main     #102      +/- ##
==========================================
+ Coverage   93.89%   93.92%   +0.02%
==========================================
  Files          10       10
  Lines        1426     1432       +6
  Branches      210      212       +2
==========================================
+ Hits         1339     1345       +6
  Misses         50       50
  Partials       37       37
```

Continue to review the full report at Codecov.
It's better to think of this process as first producing a "measurable" intermediate representation (IR) of the original model graph. The fact that the value variable checking/manipulation is also done at that point is mostly a matter of convenience. This distinction is important, because there may be a use/need for the identification of measurable terms independent of value variables, which seems to be what you're ultimately alluding to. It's possible that the entire scenario you're describing could be handled by simply loosening the overly strict value variable requirements, instead of adding a new placeholder element and its associated logic. In other words, why work around/within existing restrictions when they might be unnecessary in the first place? If we don't consider these kinds of things first and foremost, we'll likely end up building something that becomes more and more difficult to develop.
#78 is an answer (of sorts) to this. The idea is to put value variables in the graph, then we don't need a
That was my other alternative: just eagerly convert/flag all nodes that can be made measurable as such. But the more I thought about it, the more it seemed like this would be a bit insane/expensive.
What made you think that?
In my mind it meant we would end up replacing almost all nodes in a graph with equivalent "measurable" ones, especially as we increase the "coverage" of the basic Aesara Ops.
Yes, how else would we do these kinds of things? In other words, if a mixture needs to consist of

The value variables are used to determine the subgraphs to which we apply IR rewrites in the first place (i.e. they determine the conditional log-probabilities), and those subgraphs will be the outputs of the

One could make a distinction between taking a bottom-up and a top-down approach, but I don't think there's a meaningful difference between the two in this scenario. We already know that the "bottom"/outputs of the

Even so, we always have control over how the rewrites traverse the graphs, so there's nothing preventing us from applying them more intelligently at a high level and avoiding unnecessary applications.
I don't quite see how we distinguish between the following two cases (if we have to):

```python
x = at.random.normal()
y = at.clip(x, 0, 1)
z = at.random.normal(y)

x_vv = x.clone()
y_vv = y.clone()
z_vv = z.clone()

# not measuring the clip node
joint_logprob({x: x_vv, z: z_vv})

# measuring the clip node
joint_logprob({y: y_vv, z: z_vv})
```

In both cases

Currently we distinguish between the two based on whether it has a value variable or not, but if I understand you correctly, we would now always rewrite a clip so that it would work in nested expressions like the mixture example above.

Alternatively, there is already some information that we don't want to convert it in the first case, because the measurement input from which the logprob derives,

```python
# random logprob based on x and a deterministic clipping
joint_logprob({z: z_vv})
```
Yes, we're currently performing (IR) normalizing/canonicalizing rewrites on the entire graph, even if there are some "upstream" terms that may be irrelevant. This is the case in the former example, where the measurability of the

That's done because it's (generally) cheap and it lends a simplicity to our rewrites and the way in which they're applied. It's basically a "greedy" approach.

As I mentioned above, we can always control the graph traversal and the application of rewrites so that certain rewrites are only applied upstream from a valued variable when they're needed; however, complications arise in this scenario due to an added dependency between the traversal logic and the rewrite rules themselves.

For instance, starting at the "bottom"-most/valued nodes, in the former case, we would start at the expression graph for

In the latter case, we would start at

Taking the above literally and creating a routine based simply on

There may be a way to do this using "preconditioning" passes that truncate graphs via their naive "base" measurability. For example, we could walk up graphs and determine the relevant subgraphs using

```python
I = at.random.bernoulli(0.5)
C = at.random.normal(0, 1)
A = at.clip(C, -1, 1)
B = at.random.normal(0, 1)
R = at.ifelse(I, A, B)
X = at.random.normal(R)
Y = at.clip(X, 0, 1)

# Traversal/subgraph should stop at `Y -> X`
joint_logprob({Y: y_vv})

# Traversal/subgraph should stop at `X`
joint_logprob({X: x_vv})

# Traversal/subgraph should stop at `R -> [I, A -> C, B]` (i.e. the immediately
# measurable `I`, `C`, `B`)
joint_logprob({R: r_vv})
```

We could also make it possible to determine the non-measurable

Anyway, the point is that this approach produces subgraphs for which broad application of the rewrites is generally necessary (i.e. any rewrite that can be applied needs to be applied), but no changes to the rewrites or the way they're applied are needed.
An approach similar to this could/should be implemented using a custom `Rewriter` class. The changes in #133 should make that easier to do, now that the relevant rewrites are a little more centralized.

The idea is that we could put the nested/composite-needing IR rewrites into their own DB that uses the special `Rewriter`. We would still need a way to register specific composite IR `Op`s and which of their inputs need to be assigned temporary value variables and/or traversed, but it could all be generalized in this context.
Regardless, we can just as easily do this after implementing a quicker solution like #129, so we can close this for now and focus on that.
This PR is a proof of concept for how we could handle composite RVs that require nested aeppl rewrites. The goal is to figure out if this approach is too limited / cumbersome and, if so, what could work better.
Roughly, most of our rewrites come in two parts:

`_logprob` can be safely dispatched

In step 2, we often need inputs that are themselves measurable variables, but which don't have value variables assigned to them:
aeppl/aeppl/mixture.py, lines 283 to 287 in 05d0c68
If these inputs are `RandomVariable`s, things are fine, as those are measurable by default. Otherwise, we might still be dealing with things that could be measurable... but we will never know, because they don't (and shouldn't) have value variables, and their corresponding rewrites won't be triggered due to step 1.
The illustrating example in this PR is a scalar mixture that has a clipped variable as one of its components:
I have also added a test example with nested scalar mixtures.
Suggestion
The suggestion is that any rewrite that depends on other inputs being measurable can assign a temporary value variable `UPSTREAM_VALUE` to those inputs. This value variable is automatically discarded by the `PreserveRVMappings` when such variables are converted to measurable ones by their own rewrites.

Limitations
This would work for some current and future rewrites like #26, but not for everything.
For example, nested rewrites that depend on the direct manipulation / creation of value variables, such as inc_subtensor and scans, would not work. This would also not work when the `_logprob` dispatched function depends on having `RandomVariable`s as inputs, as happens with non-scalar mixtures:

aeppl/aeppl/mixture.py, line 354 in 05d0c68
Also, depending on the order in which rewrites are called, this may not always work: the attribution of the `UPSTREAM_VALUE` is not a visible graph change, and I guess the Equilibrium optimization will stop if, after passing through all nodes, none is changed.
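To illustrate that last concern with a toy loop (this is an invented sketch, not Aesara's actual equilibrium rewriter): an equilibrium-style pass keeps iterating only while some rewrite reports a visible graph change, so a rewrite that merely records assignments in a side mapping never counts as progress.

```python
# Toy equilibrium-style rewriter loop.  All names are invented for
# illustration.  A rewrite that only records `UPSTREAM_VALUE` assignments
# out-of-band changes nothing visible, so the loop reaches "equilibrium"
# after one pass, before any later pass could use that mapping.

def equilibrium(graph, rewrites, max_passes=10):
    passes = 0
    for _ in range(max_passes):
        passes += 1
        changed = False
        for rewrite in rewrites:
            graph, did_change = rewrite(graph)
            changed = changed or did_change
        if not changed:  # no visible change -> assume a fixed point
            break
    return graph, passes

upstream_values = {}

def assign_upstream_value(graph):
    # Record temporary value variables in a side mapping: the graph itself
    # is untouched, so this rewrite must report "no change".
    for node in graph:
        upstream_values.setdefault(node, f"{node}_upstream_vv")
    return graph, False

graph, passes = equilibrium(["mix", "clip"], [assign_upstream_value])
print(passes)           # 1: the loop stopped after a single pass
print(upstream_values)  # the side mapping was populated anyway
```

Making the placeholder assignment a visible graph change (or ordering the rewrites so the dependent ones run in the same pass) would avoid this early stop.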