Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration of aeppl with PyMC #4887

Merged
merged 1 commit into from
Oct 23, 2021
Merged

Conversation

kc611
Copy link
Contributor

@kc611 kc611 commented Jul 28, 2021

This PR refactors PyMC's _logp framework to be able to use aeppl's joint_logprob framework.

The changes this PR aims to make:

  • Registering the Log-probability functions of PyMC's distributions on _logprob instead of _logp so that joint_logprob can use them internally.
  • Separating log-cdf logic handling from logp logic (currently both are being handled in the same logpt function)
  • Removing redundant logp from Distribution classes (since _logprob registration present in aeppl, for e.g. NormalRV`)
  • Calling factorized_joint_logprob inside logpt function. and then return the appropriate RVs (since joint_logprob tends to return all of the RV's logp present in the graph)
  • Switching from pymc's transforms to aeppl's transforms.
    • Refactor transforms.py and remove transforms already present in aeppl
    • Updating transform related logic to match aeppl's.
    • Passing transforms to joint_logprob correctly
    • Refactoring test_transforms.py
  • Investigate test failures.
    • test_distributions failures.
    • test_transforms.py test failures.
    • test_logprob failures.

@ricardoV94
Copy link
Member

ricardoV94 commented Jul 28, 2021

I would guess the "Mixture" distribution in PyMC3 would only have to be concerned with returnig the right random graph, from which the logprob can be derived with the aeppl.

I am not sure how feasible it would be to plug the aeppl.joint_logprob for individual variables, as opposed to using it for the entire model.

@kc611
Copy link
Contributor Author

kc611 commented Jul 28, 2021

One thing we can do over here is :

https://github.com/pymc-devs/pymc3/blob/819f045ad36d1d8b18651528384972dc2bea8213/pymc3/distributions/distribution.py#L97

Registering the logps to _logprob from aeppl instead of _logp. This will allow us to make Mixture's of RV's not present in aeppl but are present in PyMC (Which is something we'd obviously want to support. ) The problem however comes with the RV's present in both such as Normal and Uniform. They'll get registered twice.

@codecov
Copy link

codecov bot commented Jul 28, 2021

Codecov Report

Merging #4887 (a854515) into main (07a95fa) will increase coverage by 1.15%.
The diff coverage is 93.56%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #4887      +/-   ##
==========================================
+ Coverage   75.31%   76.46%   +1.15%     
==========================================
  Files          87       87              
  Lines       14106    14027      -79     
==========================================
+ Hits        10624    10726     +102     
+ Misses       3482     3301     -181     
Impacted Files Coverage Δ
pymc/distributions/__init__.py 100.00% <ø> (ø)
pymc/sampling_jax.py 0.00% <0.00%> (ø)
pymc/distributions/transforms.py 91.95% <89.28%> (-0.44%) ⬇️
pymc/model.py 83.79% <95.83%> (+0.60%) ⬆️
pymc/distributions/logprob.py 94.93% <98.55%> (-0.96%) ⬇️
pymc/aesaraf.py 90.10% <100.00%> (-1.38%) ⬇️
pymc/bart/bart.py 97.01% <100.00%> (+0.40%) ⬆️
pymc/distributions/bound.py 100.00% <100.00%> (ø)
pymc/distributions/continuous.py 95.59% <100.00%> (-0.25%) ⬇️
pymc/distributions/discrete.py 98.02% <100.00%> (+<0.01%) ⬆️
... and 18 more

@kc611 kc611 force-pushed the add_mixture branch 2 times, most recently from bf606eb to 5024154 Compare August 12, 2021 10:22
@kc611 kc611 changed the title [WIP] Refactoring Mixture distributions using aeppl [WIP] Integration of aeppl with PyMC Aug 12, 2021
@kc611 kc611 force-pushed the add_mixture branch 2 times, most recently from 1c79a25 to f9e7d55 Compare August 30, 2021 16:48
@kc611 kc611 force-pushed the add_mixture branch 2 times, most recently from 3949c80 to f452a50 Compare September 4, 2021 18:10
@twiecki
Copy link
Member

twiecki commented Sep 23, 2021

Once the aeppl integration is done, refactor Mixture to return a graph in a form expected by joint_logprob i.e. using at.stack

I think we can push this to a future PR, as Mixture hasn't been refactored to begin with.

@kc611
Copy link
Contributor Author

kc611 commented Oct 15, 2021

Note that most of the test failures in Ubuntu's float32 systems are actually optimization errors and not code failures. Everything that fails with a AttributeError: 'numpy.bool_' object has no attribute 'type' actually runs but fails a certain optimization along the way. (And pytest detects it as a code failure.)

There's already an issue for that: aesara-devs/aesara#616

The only concerning thing that remains is pymc/tests/test_posteriors.py::TestSliceUniform on float32 systems which actually fails with AssertionErrors. So far I haven't pinned down the exact issue with that test.

@twiecki
Copy link
Member

twiecki commented Oct 19, 2021

@kc611 now that aesara-devs/aesara#616 is fixed, is this done? Maybe we mark that slice test with xfail and move on?

@kc611
Copy link
Contributor Author

kc611 commented Oct 19, 2021

Yeah sure.

@kc611
Copy link
Contributor Author

kc611 commented Oct 19, 2021

Looks like one of the subclass's tests of pymc/tests/test_posteriors.py::TestSliceUniform passes while others don't. I can't put the xfail decorator in subclass's test since there are other class's tests being derived from it. Is there anyway we can temporarily relax the strict xpass condition for that particular test ?

Edit: Got around it by temporarily failing the test for that particular class, the failure can be removed once the entire test is fixed. A bit crude method butsince the logic is actually passing on other systems such as float64 linux and windows systems any logical errors arising from other changes/other PRs will be caught there so I don't see any issues with failing it like this temporarily for a specific subsystem.

Any other solutions are also welcome.

@twiecki
Copy link
Member

twiecki commented Oct 20, 2021

@kc611 Tests are passing 🥳 can we merge?

@@ -83,6 +84,18 @@ def perform(self, node, inputs, outputs):
raise NotImplementedError()


@jax_funcify.register(Assert)
def jax_funcify_Assert(op, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker, but shouldn't this go into Aesara?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is supposed to be just a temporary pass-through for Assert Op since Jax doesn't allow asserts in it's code. Though you're right we should be adding this functionality (maybe with a good workaround) in Aesara, I'll open an issue for this.

@ricardoV94
Copy link
Member

Just a reminder this PR should include a test for #5007

I can push something later if you are too busy now @kc611

@ricardoV94 ricardoV94 force-pushed the add_mixture branch 6 times, most recently from 5ec3e0b to 3b828a7 Compare October 21, 2021 16:37
@ricardoV94
Copy link
Member

Pushed a test that works and rebased from main. Apologies if I messed up anything @kc611!

@@ -4,6 +4,7 @@ channels:
- conda-forge
- defaults
dependencies:
- aeppl>=0.0.13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Disclaimer: I skimmed the discussion and have not reviewed the code

What does this dependency mean for the users? Will using aeppl functions/methods be necessary? Or will it be handled under the hood and advanced users will be able to use it if they want?

This is a minor-medium concern for the beta release, but aeppl seems to be very little documented. I was unable to find the link to it's docs anywhere. I also did check and saw there was a gh-pages branch so I went to https://aesara-devs.github.io/aeppl/ which has some api docs. We need to work on aeppl docs if we expect pymc users to work with it directly before releasing stable 4.0

Copy link
Member

@ricardoV94 ricardoV94 Oct 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users don't need to know about aeppl, it will be used by us developers to create more complex distributions.

Indeed we need documentation building on aeppl, we have a PR opened for that, but we welcome help. All our methods and functions are quite well documented on the other hand

@twiecki
Copy link
Member

twiecki commented Oct 23, 2021

Can we merge?

@ricardoV94 ricardoV94 merged commit 0a172c8 into pymc-devs:main Oct 23, 2021
@ricardoV94
Copy link
Member

ricardoV94 commented Oct 23, 2021

Pulled the trigger, 🤞

Great work @kc611

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants