Support logp derivation of power exponent #6896
---

@ricardoV94 I've got the first case working locally now, but I don't think it's exactly what you asked for, as I don't understand what you mean by …

Could you provide an example?
---

When you have a graph like:

```python
base_rv = pm.Poisson.dist()
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)
x_rv.name = "x"
base_rv.name = "base"

base_vv = base_rv.clone()
x_vv = x_rv.clone()
conditional_logp({base_rv: base_vv, x_rv: x_vv})
```

In that case …
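As a side note (my own summary, not from the thread): for a positive base, `base ** exponent` builds the same graph as `exp(log(base) * exponent)`, and `exp`, `log`, and multiplication already have measurable implementations in `pymc.logprob`, so rewriting `pow` that way lets the existing machinery derive the exponent's logp by change of variables.

```python
# Sketch of the identity the proposed rewrite exploits (positive base assumed),
# reusing base_rv and x_raw_rv from the snippet above:
#   base_rv ** x_raw_rv  is equivalent to  exp(log(base_rv) * x_raw_rv)
x_rv_equiv = pt.exp(pt.log(base_rv) * x_raw_rv)
```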
---

Am I using this correctly?

```python
@node_rewriter([pow, CheckParameterValue])
def measurable_power_expotent_to_exp(fgraph, node):
    exponent, inp = node.inputs
    # check whether inp is discrete
    if inp.type.dtype.startswith("int"):
        return None
    return [pt.exp(pt.log(exponent) * inp)]
```

Because while this works fine:

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from pymc.logprob.basic import conditional_logp

base_rv = pm.Poisson.dist([1, 1])
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)
x_rv.name = "x"
base_rv.name = "base"

base_vv = base_rv.clone()
x_vv = x_rv.clone()
res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
res_combined = pt.sum([factor for factor in res.values()])
logp_vals_fn = pytensor.function([base_vv, x_vv], res_combined)
logp_vals_fn(np.array([2, 2]), np.array([2, 2]))
>>> array(-6.87743995)
```

If I understand you properly, I would have thought this shouldn't work, as base_rv can take on negative values, but the logp does evaluate:

```python
from pymc.logprob.basic import conditional_logp
base_rv = pm.Normal.dist([1,1])
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)
x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()
res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
res_combined = pt.sum([factor for factor in res.values()])
logp_vals_fn = pytensor.function([base_vv, x_vv], res_combined)
logp_vals_fn(np.array([2, 2]), np.array([2,2]))
>>> array(-6.32902265)
```
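As a sanity check on those two numbers, here is my own reconstruction of the change of variables the rewrite implies (not code from this thread): given a base value b and an observed x, the exponent is e = log(x)/log(b) with log-Jacobian -log(x * log(b)), and the base contributes its own logp term.

```python
import numpy as np
from scipy import stats

b = x = 2.0
e = np.log(x) / np.log(b)              # recovered exponent value: 1.0
jac = -np.log(x * np.log(b))           # log-Jacobian of e = log(x) / log(b)
exp_term = stats.norm.logpdf(e) + jac  # logp factor for the exponent RV

print(2 * (stats.poisson.logpmf(2, 1) + exp_term))  # -6.87743995 (Poisson base)
print(2 * (stats.norm.logpdf(2, 1) + exp_term))     # -6.32902265 (Normal base)
```

Both match, and the second also hints at the answer below: the Normal-base graph only "worked" because the value supplied for the base was positive.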
---

The rewrite should only be used when the exponent (not the base) is being measured. You can get this info from somewhere in … Otherwise you would be breaking logp inference for stuff like …
---

It's fine as long as it does something sensible when negative values are passed (hence my suggestion of wrapping the exponent in a CheckParameterValue). What does it evaluate to now when you pass a negative base? I guess you get a nan. You don't need to sum the two factors, just check out the one for the exponent RV.
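A quick way to see the nan anticipated here, reusing `logp_vals_fn` from the Normal-base snippet above (my guess at the failure mode: `log` of a negative base value propagates nan through the exponent factor):

```python
# Hypothetical check of the unguarded graph with a negative base value.
logp_vals_fn(np.array([-2.0, 2.0]), np.array([2.0, 2.0]))
# log(-2) is nan, so the summed logp should evaluate to nan
```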
---

My changes haven't broken `power(x, const)`, as I have also made a change to …
---

Ah, so I was being thick and using CheckParameterValue wrong. I have changed it to:

```python
@node_rewriter([pow])
def measurable_power_expotent_to_exp(fgraph, node):
    base, inp_exponent = node.inputs
    base = CheckParameterValue("base > 0")(base, pt.all(pt.ge(base, 0.0)))
    # check whether inp_exponent is discrete
    if inp_exponent.type.dtype.startswith("int"):
        return None
    return [pt.exp(pt.log(base) * inp_exponent)]
```

When:

```python
x_rv = pt.pow(-1, pt.random.normal())
x_vv = x_rv.clone()
x_logp_fn = pytensor.function([x_vv], pt.sum(logp(x_rv, x_vv)))
x_logp_fn(0.1)
```

When:

```python
base_rv = pm.Normal.dist([2])
x_raw_rv = pm.Normal.dist()
x_rv = pt.power(base_rv, x_raw_rv)
x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()
res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
factors = [factor for factor in res.values()]
logp_vals_fn = pytensor.function([base_vv, x_vv], factors[1])
logp_vals_fn(np.array([2]), np.array([2]))
>>> array([-1.74557279])
```

In this case it evaluates fine. Is this what you would expect, or does the CheckParameterValue logic need to change?
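That single factor also matches the change-of-variables reconstruction from earlier, which lines up with the suggestion to look only at the exponent RV's factor:

```python
import numpy as np
from scipy import stats

e = np.log(2.0) / np.log(2.0)  # recovered exponent value: 1.0
print(stats.norm.logpdf(e) - np.log(2.0 * np.log(2.0)))  # ≈ -1.74557279
```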
---

That's what I would expect, because you provided a positive value to the …
---

Great, I'll take a look when you open the PR
---

Yep, when I do that it fails as expected! I'll open a PR
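For completeness, a sketch of the failing call being described (my reconstruction: with the CheckParameterValue in the rewrite, a negative base value should raise instead of returning a logp):

```python
# Reusing logp_vals_fn from the pm.Normal.dist([2]) snippet above.
logp_vals_fn(np.array([-2.0]), np.array([2.0]))
# expected to fail via CheckParameterValue("base > 0")
```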
---

> For the first case we don't have to constrain ourselves to actual "constants"; we can add a symbolic assert that `const > 0`.
>
> The second requires us to implement transforms for discrete variables, which would probably need #6360 first, so we can focus on the first case, which is also probably more useful anyway.
>
> We just have to make sure not to accidentally rewrite stuff like `power(x, const)`, as those are implemented via our `PowerTransform`. This can be done by checking which of the inputs has a path to unvalued random variables.

Originally posted by @ricardoV94 in #6826 (comment)
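A rough sketch of that last check, using only public PyTensor graph utilities; `depends_on_rv` is a hypothetical helper name, and a real implementation inside the logprob rewrites would need to learn which RVs are valued from the fgraph's RV-to-value mappings rather than take them as an argument:

```python
from pytensor.graph.basic import ancestors
from pytensor.tensor.random.op import RandomVariable


def depends_on_rv(var, valued_rvs=()):
    """Return True if `var` has an unvalued RandomVariable among its ancestors."""
    return any(
        a.owner is not None
        and isinstance(a.owner.op, RandomVariable)
        and a not in valued_rvs
        for a in ancestors([var])
    )


# The pow rewrite would then fire only when the exponent is the measurable input:
#   depends_on_rv(exponent) and not depends_on_rv(base)
```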