API of aeppl.factorized_joint_logprob #85

Closed
aseyboldt opened this issue Nov 9, 2021 · 5 comments

Comments

@aseyboldt
Contributor

During discussion of pymc-devs/pymc#5155 we (@ricardoV94 and I) also talked a bit about the API for aeppl.factorized_joint_logprob and how it deals with random variables that are not specified at all.

I think it might be better to change the API a bit so that all random variables in the graph have to be specified somehow. That would make it more future-proof (e.g. for implementing marginalization using a closed-form solution or some kind of numerical integration), bring it closer to the math notation, and make it harder for users to shoot themselves in the foot.

Let's say we have a graph like this:

import aesara
import aesara.tensor as at
import aeppl

a = at.random.normal()       # a ~ Normal(0, 1)
b = at.random.normal(loc=a)  # b | a ~ Normal(a, 1)

a_val = at.dscalar()  # value variable for a
b_val = at.dscalar()  # value variable for b

We could now be interested in these quantities (in math notation):

# joint logp
P(a = a_val, b = b_val)

# some conditional logps
P(a = a_val | b = b_val) = P(a = a_val)
P(b = b_val | a = a_val)

# marginal logps
P(a = a_val) = ∫ P(a = a_val | b = b_val) P(b = b_val) db_val  # integral doesn't do anything here...
P(b = b_val) = ∫ P(b = b_val | a = a_val) P(a = a_val) da_val  # this one does

# Random variables whose expectation is the marginal probability (but a Monte Carlo estimate based on them might converge *really* slowly, or not at all, in many cases)
P_rv(a = a_val | b = B), where B ~ b  # not sure actually how to write this down properly
P_rv(b = b_val | a = A), where A ~ a

If we call aeppl.factorized_joint_logprob(all_vars), we get all the conditional logps: {key_var: P(key_var = key_val | all_remaining) for (key_var, key_val) in all_vars}, which seems fine to me (I might only be interested in some of those in the first place, but I think that is fine anyway).
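For concreteness, a minimal sketch of that first case on the graph above (it only assumes that the returned dict contains one log-probability term per queried variable; how the dict is keyed doesn't matter here):

# Query both variables: every random variable in the graph gets a value variable.
logp_terms = aeppl.factorized_joint_logprob({a: a_val, b: b_val})

# Summing the factors gives the joint logp, log P(a = a_val) + log P(b = b_val | a = a_val).
joint_logp = sum(term.sum() for term in logp_terms.values())

joint_logp_fn = aesara.function([a_val, b_val], joint_logp)
print(joint_logp_fn(0.1, 0.2))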

If we call aeppl.factorized_joint_logprob({b: b_val}), however, we actually get the stochastic quantity (with a left as a random variable in the logp graph) instead of the marginal logp, even though the call looks almost like the math notation for the marginal logp.

We could make this choice explicit by switching this use case to an API like this:

aeppl.factorized_joint_logprob({b: b_val}, sample=[a])

We could later extend this to marginalization using something like this:

aeppl.factorized_joint_logprob({b: b_val}, marginalize=[a])  # closed form solution only
aeppl.factorized_joint_logprob({b: b_val}, marginalize=[a], integration_options={a: "gauss_hermite(21)"}) # some numerical integration
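For this toy graph the closed-form answer is easy to write down: with a ~ Normal(0, 1) and b | a ~ Normal(a, 1), integrating out a gives b ~ Normal(0, sqrt(2)). A hypothetical marginalize=[a] result could therefore be checked against a plain scipy reference:

import numpy as np
from scipy import stats

# Closed-form marginal of b after integrating out a:
# a Normal(0, 1) shifted by an independent Normal(0, 1) draw is Normal(0, sqrt(2)).
marginal_logp_b = stats.norm(loc=0.0, scale=np.sqrt(2.0)).logpdf(0.5)
print(marginal_logp_b)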
@brandonwillard
Member

We had some exceptions and—later—warnings to this effect, but those approaches inherently restricted one's ability to intentionally include random sampling in the resulting log-probability graphs, and we don't want to trade those kinds of fundamental limitations for high-level reporting conveniences.

Simply put, factorized_joint_logprob is meant to be as simple as possible and to only perform the function of applying log-probability-specific rewrites and calling _logprob on the resulting MeasurableVariables. If anything, we need to consider a better separation of those two steps, so that they can be customized more easily.

Regarding the general design of this library, a keyword like marginalize seems to go in the opposite direction. We're aiming for composability, and a distinct marginalization function/capability is more in line with that. Otherwise, marginalization can be included through additions to the rewrites that factorized_joint_logprob performs.

Nevertheless, there's nothing preventing one from creating an interface to factorized_joint_logprob that works the way you describe, but, from a design and implementation point of view, we can't make such an interface an explicit target for the project.

It might make sense to allow additional Features to be added to the FunctionGraph, and these Features could do whatever high-level analysis of the graphs that one might want.
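As a purely illustrative sketch (the class name and how it would be attached to the FunctionGraph that aeppl builds internally are assumptions, not an existing hook), such a Feature could use the standard aesara.graph.features.Feature interface:

from aesara.graph.features import Feature
from aesara.tensor.random.op import RandomVariable

class UnvaluedRVRecorder(Feature):
    """Record RandomVariable nodes imported into the graph."""

    def on_attach(self, fgraph):
        self.rv_nodes = []

    def on_import(self, fgraph, node, reason):
        if isinstance(node.op, RandomVariable):
            self.rv_nodes.append(node)

A higher-level wrapper could then attach such a feature before the log-probability rewrites run and report on self.rv_nodes however it likes.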

Really, we want things to be as configurable/malleable as possible, but we don't want to invent and maintain new interfaces and APIs if/when we don't need to. Instead, we need to use and improve the existing ones. This also helps the project focus on its core functionality and avoid over-extending itself into marginal use-case territories.

@ricardoV94
Contributor

ricardoV94 commented Nov 9, 2021

I think the main point is what would be the most useful thing aeppl could (try to) do when you call:

aeppl.joint_logprob({b: b_val})

Right now what we do (and document) is that a is left as an input to the logprob of b, so depending on whether a has auto-updates enabled or not, you either get a logprob for b that is stochastic or one that is fixed at whatever the initial seed produced for a. I would agree with @aseyboldt that either case is very unlikely to be what users want or need from a PPL (although sometimes they will want that, and I don't think we should prevent it either).
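A minimal sketch of that behavior on the graph from the first comment; whether the two calls below agree depends entirely on the RNG/auto-update setup, which is exactly the point:

logp_terms = aeppl.factorized_joint_logprob({b: b_val})
logp_b = sum(term.sum() for term in logp_terms.values())

# a is still a RandomVariable in this graph, so the compiled function either
# re-draws it on every call (auto-updates enabled) or keeps the draw from the
# initial seed (auto-updates disabled).
logp_b_fn = aesara.function([b_val], logp_b)
print(logp_b_fn(0.0), logp_b_fn(0.0))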

Having the library try to marginalize the variable a would be a much more useful feature for a PPL. We don't have any functionality like this, but it would be an interesting avenue to explore. In that case we would need a way to specify the cases where we do want a to be left as a plain RandomVariable, as we do now, but I don't see why that would be the most useful or desired default.

@ricardoV94
Contributor

ricardoV94 commented Nov 10, 2021

Also worth noting that we don't interpret

x_rv = at.random.normal()
y_rv = at.random.normal()

z_rv = x_rv + y_rv
z_vv = z_rv.clone()

joint_logprob({z_rv: z_vv})

As being a stochastic graph of both x_rv and y_rv (or one centered on the other), even though z_rv follows exactly the same distribution as in the graph below:

x_rv = at.random.normal()
z_rv = at.random.normal(x_rv)

z_vv = z_rv.clone()
joint_logprob({z_rv: z_vv})
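A quick pure-numpy check (outside aesara entirely) that z_rv has the same marginal distribution, Normal(0, sqrt(2)), in both graphs:

import numpy as np

rng = np.random.default_rng(123)
n = 200_000

# First graph: z = x + y with x, y ~ Normal(0, 1).
z_sum = rng.normal(size=n) + rng.normal(size=n)

# Second graph: z ~ Normal(x, 1) with x ~ Normal(0, 1).
z_hier = rng.normal(loc=rng.normal(size=n), scale=1.0)

# Both empirical standard deviations should be close to sqrt(2).
print(z_sum.std(), z_hier.std(), np.sqrt(2))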

@brandonwillard
Member

brandonwillard commented Nov 10, 2021

> Also worth noting that we don't interpret
>
> x_rv = at.random.normal()
> y_rv = at.random.normal()
>
> z_rv = x_rv + y_rv
> z_vv = z_rv.clone()
>
> joint_logprob({z_rv: z_vv})
>
> As being a stochastic graph of both x_rv and y_rv (or one centered on the other), even though z_rv follows exactly the same distribution as in the graph below:
>
> x_rv = at.random.normal()
> z_rv = at.random.normal(x_rv)
>
> z_vv = z_rv.clone()
> joint_logprob({z_rv: z_vv})

I don't follow; are we assuming that the convolution z_rv is being performed by AePPL in the first example?

@ricardoV94
Contributor

ricardoV94 commented Nov 10, 2021

> I don't follow; are we assuming that the convolution z_rv is being performed by AePPL in the first example?

Yeah, I am assuming that's what aeppl would try to do if we had already implemented the convolution rewrite (which I think was discussed as a future feature).

@aesara-devs aesara-devs locked and limited conversation to collaborators Nov 1, 2022
@rlouf rlouf converted this issue into discussion #198 Nov 1, 2022

