-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Derive logprob for exp2, log2, log10, log1p, expm1, log1mexp, log1pexp (softplus), and sigmoid transformations #6826
Derive logprob for exp2, log2, log10, log1p, expm1, log1mexp, log1pexp (softplus), and sigmoid transformations #6826
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6826 +/- ##
==========================================
+ Coverage 92.05% 92.16% +0.10%
==========================================
Files 96 100 +4
Lines 16448 16877 +429
==========================================
+ Hits 15142 15554 +412
- Misses 1306 1323 +17
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for picking this up. My suggestions are all about being a bit lazier (less code to maintain and test), but the idea is totally right!
You don't need to worry about jacobian when rewriting into any equivalent forms, not just those. For instance for stuff like |
pymc/logprob/transforms.py
Outdated
[inp] = node.inputs | ||
|
||
if isinstance(node.op.scalar_op, Exp2): | ||
return [pt.power(2, inp)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we support this one, only powers with fixed exponent and variable base
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can convert to exp(ln(2)*x)
instead, which PyMC will know how to handle
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should actually use that for any power(const, x) -> exp(log(const) * x)
which we currently don't support. But maybe that's better left for another PR?
It requires checking we are interested in x and not const.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yeah I had a feeling that one may be an issue, I'll make the change. I'll add a new function to generalise this power(const, x) -> exp(log(const) * x) functionality.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I've just pushed this new functionality. Do we care that it won't work for const <= 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's open a separate issue / PR for this. I thought it over a bit and I think we can definitely support a couple of cases.
power(const, x), for any const > 0 and any x
power(const, x), for any const and discrete x (we can play with `log(abs(neg_const))` and x's parity)
The first case we don't have to constrain ourselves to actual "constants", we can add a symbolic assert that const > 0
.
The second requires us to implement transforms for discrete variables, which would probably need #6360 first, so we can focus on the first case, which is also probably more useful anyway.
We just have to make sure not to rewrite stuff like power(x, const)
accidentally as those are implemented via our PowerTransform
. This can be done by checking which of the inputs has a path to unvalued random variables.
Rebasing from main should unstuck the tests |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
I imagine you pull --merge instead of pull --rebase? This makes github show unrelated commits in the PR :/ |
Woops, sorry about that I will fix |
0f50a55
to
0cfd1f2
Compare
It seems that the code coverage test is failing because if the new node rewriting functions. Should we be testing them? Or are they fine because they rely on existing functionality thats already tested? |
We should test them. We can test that the output def test():
base_rv = pt.random.normal(name="base_rv")
vv = pt.scalar("vv")
logp_test = logp(pt.log1p(base_rv), vv)
logp_ref = logp(pt.log(1 + base), vv)
assert equal_computations([logp_test], [logp_ref]) |
@LukeLB shall we push the PR to the finish line? Let me know if you don't have the time right now |
Hey really sorry about the lack of communication, I've been on holiday. I'm going to have a look at this this week. |
No worries, I hope you had a good holiday! |
Ahhh done the same thing as before and somehow merged unrelated commits! Will fix on my side |
While I fix that I am having a problems with two of the tests. For the test TRANSFORMATIONS = {
"log1p": (pt.log1p, lambda x: pt.log(1 + x)),
"softplus": (pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
"log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(pt.neg(x)))),
"log2": (pt.log2, lambda x: pt.log(x) / pt.log(2)),
"log10": (pt.log10, lambda x: pt.log(x) / pt.log(10)),
"exp2": (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
"expm1": (pt.expm1, lambda x: pt.exp(x) - 1),
"sigmoid": (pt.sigmoid, lambda x: 1 / (1 + pt.exp(-x))),
}
@pytest.mark.parametrize("transform", TRANSFORMATIONS.keys())
def test_special_log_exp_transforms(transform):
base_rv = pt.random.normal(name="base_rv")
vv = pt.scalar("vv")
transform_func, ref_func = TRANSFORMATIONS[transform]
transformed_rv = transform_func(base_rv)
ref_transformed_rv = ref_func(base_rv)
logp_test = logp(transformed_rv, vv)
logp_ref = logp(ref_transformed_rv, vv)
assert equal_computations([logp_test], [logp_ref]) when transform is log2 or log10 then test fails for equal computation, I'm not sure what is causing that... |
Try to look at |
@ricardoV94 OK I think I've got to the bottom of what is causing the failure. Looks like a floating point in precision in one of the nodes of the graph between the test case and the reference case. |
Hmm. Let's do a logp evaluation for that one (separate test) and check for output closeness? |
Co-authored-by: Ricardo Vieira <[email protected]>
…ion in the graph)
17123ac
to
a912b20
Compare
File changes is showing some accidental overwriting of previous changes? https://github.com/pymc-devs/pymc/pull/6826/files |
Awesome work @LukeLB. Do you want to pursue the power exponent next? |
Thanks @ricardoV94, I feel like I learnt a lot on this one!
Yep I'll start looking into that, if I run into any trouble I'll drop you a mesage on slack. |
Congrats @LukeLB! This is a major and non-trivial contribution. |
This builds upon the previous pull requests, #6664 and #6775, and completes the work of #6631.
I have attempted to rewrite the logp graph only for log1p, expm1, and log1pexp (softplus). My reasoning is that in these cases, the inputs are transformed directly without affecting the backward or log_jac_det transforms. For all other cases, I have made changes to the existing transform classes.
Although some tests are still failing (see below), I'm uncertain about the reasons for the first two failures. However, it seems that the final failure is due to a floating point error, and I'm unsure how to resolve it.
Currently, there are no specific tests written for the transforms log1p, log1mexp, log1pexp (softplus), and sigmoid. I would appreciate assistance in developing tests for these transforms.
...
Checklist
Major / Breaking Changes
New features
📚 Documentation preview 📚: https://pymc--6826.org.readthedocs.build/en/6826/