
Speedup implementation of multivariate_normal and allow method of covariance decomposition #1203

Merged · 2 commits merged into pymc-devs:main from speedup_mvnormal on Feb 13, 2025

Conversation

ricardoV94 (Member) commented on Feb 12, 2025:

This PR speeds up the Python (default) implementation of multivariate_normal by at least 10x-100x (the latter in the case of batch parameters).

It also allows specifying the method of covariance decomposition, which defaults to cholesky for performance. The method is also respected in the JAX dispatch, but not in the Numba impl (because we don't have those yet, right @jessegrabowski?).
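For illustration, a quick sketch of the new option; the keyword name method is an assumption here (mirroring numpy's Generator.multivariate_normal), not confirmed by this description:

import numpy as np
import pytensor.tensor as pt

# Default decomposition is cholesky; an alternative can be requested explicitly.
rv_chol = pt.random.multivariate_normal(np.zeros(3), cov=np.eye(3))
rv_svd = pt.random.multivariate_normal(np.zeros(3), cov=np.eye(3), method="svd")  # hypothetical keyword name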

Also removed the dumb defaults, which closes #833

import pytensor
import pytensor.tensor as pt
import numpy as np

rv = pt.random.multivariate_normal([0, 0, 0], cov=np.eye(3))
rng = rv.owner.inputs[0]
next_rng = rv.owner.outputs[0]
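# rng is the generator input of the RandomVariable node; next_rng is its updated state after drawing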

fn = pytensor.function([], rv, updates={rng: next_rng})

%timeit fn()
# Before PR:
# 335 μs ± 78.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# After PR:
# 32 μs ± 3.32 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Compared to numpy:

zeros = np.zeros(3)
eye = np.eye(3)
rng = np.random.default_rng()

# Default method uses SVD, so unsurprisingly slower
%timeit rng.multivariate_normal(zeros, eye)
# 90.7 μs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%timeit rng.multivariate_normal(zeros, eye, method="cholesky")
# 19.1 μs ± 240 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

With batch mean (numpy doesn't support it):

rv = pt.random.multivariate_normal(np.random.normal(size=(100, 3)), cov=np.eye(3))
rng = rv.owner.inputs[0]
next_rng = rv.owner.outputs[0]

fn = pytensor.function([], rv, updates={rng: next_rng})

%timeit fn()
# Before PR:
# 54.1 ms ± 3.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# After PR:
# 42.5 μs ± 3.29 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

With batch covariance, albeit a trivial one (numpy doesn't support it):

rv = pt.random.multivariate_normal(np.zeros(3), cov=np.broadcast_to(np.eye(3), (100, 3, 3)))
rng = rv.owner.inputs[0]
next_rng = rv.owner.outputs[0]

fn = pytensor.function([], rv, updates={rng: next_rng})

%timeit fn()
# Before PR:
# 30.1 ms ± 4.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# After PR:
# 60.2 μs ± 763 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In the long term we should still go with a symbolic graph, as discussed in #1115, so that other rewrites can happen on top of it, such as avoiding a useless Cholesky when the covariance is built symbolically from a Cholesky factor to begin with (as in most PyMC models); see the sketch below.
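A conceptual sketch of that situation (illustration only, not code from this PR):

import pytensor.tensor as pt
from pytensor.tensor.slinalg import cholesky

L = pt.matrix("L")  # lower-triangular Cholesky factor, e.g. a model parameter
cov = L @ L.T       # covariance assembled symbolically from its factor

# With a symbolic sampling graph, a rewrite could notice that
# cholesky(L @ L.T) is just L and skip the decomposition entirely.
chol = cholesky(cov)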

However, that requires some coordination with PyMC, since the object couldn't be a RandomVariable Op anymore. Also, some nice rewrites we have now may not work with the symbolic representation.


📚 Documentation preview 📚: https://pytensor--1203.org.readthedocs.build/en/1203/

Comment on lines -899 to -902
if mean is None:
mean = np.array([0.0], dtype=dtype)
if cov is None:
cov = np.array([[1.0]], dtype=dtype)
ricardoV94 (Member, Author) commented on Feb 12, 2025:
These were dumb defaults, just removed them: #833

It made sense when both were None perhaps, but not when just one of them was. Anyway, numpy doesn't provide defaults either. In PyMC we do, because we are not trying to mimic the numpy API there.
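To be explicit, a small sketch of the behaviour after this PR (not a test from it): both parameters are now passed by the caller, matching numpy's signature.

import numpy as np
import pytensor.tensor as pt

# No implicit zero mean or identity covariance anymore; both must be supplied.
rv = pt.random.multivariate_normal(np.zeros(2), cov=np.eye(2))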

jessegrabowski (Member) commented:
Numba supports SVD via np.linalg.svd, but you're not allowed to set compute_uv=False. We have to set it to True for this application anyway, so I think it should work. Cholesky and eig are also supported.

ricardoV94 (Member, Author) replied:

> Numba supports SVD via np.linalg.svd, but you're not allowed to set compute_uv=False. We have to set it to True for this application anyway, so I think it should work. Cholesky and eig are also supported.

I thought it didn't, my bad. Numba now also supports the different modes.
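For reference, a minimal sketch of how an SVD-based draw could look under Numba's np.linalg.svd (which always computes U and V); this is illustrative only, not the dispatch code added in this PR:

import numba
import numpy as np

@numba.njit
def mvnormal_svd_draw(mean, cov):
    # Numba's np.linalg.svd overload always returns (U, s, Vh);
    # compute_uv cannot be turned off.
    U, s, Vh = np.linalg.svd(cov)
    # For a symmetric PSD cov, U @ diag(s) @ U.T == cov, so A @ A.T == cov.
    A = U * np.sqrt(s)
    z = np.random.standard_normal(mean.shape[0])
    return mean + np.dot(A, z)

draw = mvnormal_svd_draw(np.zeros(3), np.eye(3))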

ricardoV94 changed the title from "Faster python implementation of multivariate_normal" to "Speedup implementation of multivariate_normal and allow method of covariance decomposition" on Feb 12, 2025.
ricardoV94 force-pushed the speedup_mvnormal branch 2 times, most recently from e295c66 to ceecfb0, on February 12, 2025 16:58.

codecov bot commented Feb 12, 2025

Codecov Report

Attention: Patch coverage is 74.00000% with 13 lines in your changes missing coverage. Please review.

Project coverage is 82.25%. Comparing base (7411a08) to head (e2fb8d1).
Report is 3 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/link/numba/dispatch/random.py | 8.33% | 11 Missing ⚠️
pytensor/tensor/random/basic.py | 93.10% | 1 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (74.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1203      +/-   ##
==========================================
- Coverage   82.26%   82.25%   -0.02%     
==========================================
  Files         186      186              
  Lines       47962    47981      +19     
  Branches     8630     8630              
==========================================
+ Hits        39456    39465       +9     
- Misses       6347     6356       +9     
- Partials     2159     2160       +1     
Files with missing lines | Coverage Δ
pytensor/link/jax/dispatch/random.py | 93.70% <100.00%> (+0.20%) ⬆️
pytensor/tensor/random/basic.py | 98.84% <93.10%> (-0.39%) ⬇️
pytensor/link/numba/dispatch/random.py | 57.20% <8.33%> (-1.78%) ⬇️

jessegrabowski (Member) left a review comment:

Looks great, just one random docstring looked wrong.

ricardoV94 merged commit 2aecb95 into pymc-devs:main on Feb 13, 2025
63 of 64 checks passed
Merging this pull request closes #833: Default MvNormal covariance doesn't make sense.