-
Notifications
You must be signed in to change notification settings - Fork 121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speedup implementation of multivariate_normal and allow method
of covariance decomposition
#1203
Conversation
if mean is None: | ||
mean = np.array([0.0], dtype=dtype) | ||
if cov is None: | ||
cov = np.array([[1.0]], dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These were dumb defaults, just removed them: #833
It made sense when both were None perhaps, but not just one of them. Anyway, numpy doesn't provide defaults either. In PyMC we do, because we are not trying to mimick numpy API there.
b96608f
to
467fe81
Compare
Numba supports SVD via |
467fe81
to
964cccb
Compare
I thought it didn't my bad. Numba now also supports the different modes |
method
of covariance decomposition
e295c66
to
ceecfb0
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (74.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1203 +/- ##
==========================================
- Coverage 82.26% 82.25% -0.02%
==========================================
Files 186 186
Lines 47962 47981 +19
Branches 8630 8630
==========================================
+ Hits 39456 39465 +9
- Misses 6347 6356 +9
- Partials 2159 2160 +1
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, just one random docstring looked wrong.
Also remove bad default values
ceecfb0
to
e2fb8d1
Compare
This PR speedups the python (default) implementation of multivariate_normal, by at least 10x - 100x (the latter in the case of batch parameters).
It also allows specifying the method of decomposition, which defaults to cholesky for performance. This is used also in the JAX dispatch, but not the Numba impl (because we don't have those yet right @jessegrabowski ?)
Also removed the dumb defaults, which closes #833
Compared to numpy:
With batch mean (numpy doesn't support it):
With batch covariance - a trivial one though (numpy doesn't support it):
In the long term we should still go with a symbolic graph as discussed in #1115 so that other rewrites can happen on top of the graph, such as avoiding a useless cholesky, if the covariance is built symbolically from a cholesky to begin with (as in most PyMC models).
However that requires some coordination with PyMC as that object couldn't be a
RandomVariable
Op anymore. Also some nice rewrites we have now may not work with the symbolic representation.📚 Documentation preview 📚: https://pytensor--1203.org.readthedocs.build/en/1203/