Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Scans in CustomDist #6696

Merged
merged 1 commit into from
May 23, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Apr 28, 2023

Here is how one can now write an AR with StudentT innovations:

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates
from pytensor.scan import scan

def ar_t(nu, sigma, rho, size):
    
    def step(xtm1, nu, sigma, rho):
        x = xtm1 * rho + pm.StudentT.dist(nu=nu, sigma=sigma)
        return x, collect_default_updates([x])
    
    xs, _ = scan(
        fn=step,
        outputs_info=[pt.zeros(())],
        non_sequences=[nu, sigma, rho],
        n_steps=size[0],
    )
    
    return xs

with pm.Model() as m:
    nu = 4
    sigma = pm.HalfNormal("sigma")
    rho = pm.Uniform("rho")
    steps=100
    
    pm.CustomDist("ar_t", nu, sigma, rho, dist=ar_t, observed=np.random.randn(steps))
    
    prior = pm.sample_prior_predictive()
    posterior = pm.sample()    

The part that was not supported was that we would not be able to collect the update for the scan (used in prior/posterior predictive), and the user had no way of specifying it manually either.


📚 Documentation preview 📚: https://pymc--6696.org.readthedocs.build/en/6696/

Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When reviewing I slowed down at find_default_update inside collect_default_update. This entire local function can be extracted which would make it a lot easier to review & test.

Overall it looks good, but I'd recommend to add some comments to make it comprehensible without knowing the context of this PR.

@ricardoV94
Copy link
Member Author

This entire local function can be extracted which would make it a lot easier to review & test.

It's the inner recursive function, I don't think it makes sense to be outside of the initial caller function.

@ricardoV94 ricardoV94 force-pushed the custom_dist_timeseries branch from 6fa3a31 to 3710b69 Compare May 10, 2023 12:31
@ricardoV94 ricardoV94 requested a review from michaelosthege May 10, 2023 12:32
@codecov
Copy link

codecov bot commented May 10, 2023

Codecov Report

Merging #6696 (6fa3a31) into main (c57769c) will decrease coverage by 10.99%.
The diff coverage is 100.00%.

❗ Current head 6fa3a31 differs from pull request most recent head 5532cb2. Consider uploading reports for the commit 5532cb2 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #6696       +/-   ##
===========================================
- Coverage   92.01%   81.02%   -10.99%     
===========================================
  Files          95       95               
  Lines       16180    16176        -4     
===========================================
- Hits        14888    13107     -1781     
- Misses       1292     3069     +1777     
Impacted Files Coverage Δ
pymc/distributions/distribution.py 95.23% <100.00%> (-1.41%) ⬇️
pymc/pytensorf.py 90.67% <100.00%> (-1.93%) ⬇️

... and 44 files with indirect coverage changes

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I understand what's going on here. The key change is just to add a check for Scan Ops and, if so, reach into the outer_out_from_outer_inp mapping and grab the rng object, raising an error if it's not found.

I also found the function inside recursive function a bit unusual. This SO thread claims there is some overhead associated with doing things this way, so I'm wondering if that would matter in a worst-case recursion (very big graph? multiple scans?).

Otherwise looks good to me. I would second a bit of documentation in the functions -- a general complaint I have about the Theano codebase is a lack of useful docstrings.

@ricardoV94
Copy link
Member Author

ricardoV94 commented May 10, 2023

I think I understand what's going on here. The key change is just to add a check for Scan Ops and, if so, reach into the outer_out_from_outer_inp mapping and grab the rng object, raising an error if it's not found.

Yup

I also found the function inside recursive function a bit unusual. This SO thread claims there is some overhead associated with doing things this way, so I'm wondering if that would matter in a worst-case recursion (very big graph? multiple scans?).

I think that's only meaningful when you are calling the outer function many times, in this case the outer function is called once and then the inner function is called multiple times. Also this is not performance critical code at all since it's done once before compiling a pytensor function. The compilation is orders of magnitude more expensive than this one time graph recursion. And this compilation is itself order of magnitude faster than the presumed heavy use of the final compiled function.

The comments in that thread seem to agree with this. One user also mentions the uncluttering of the module namespace which is the main motivation I had.

Otherwise looks good to me. I would second a bit of documentation in the functions -- a general complaint I have about the Theano codebase is a lack of useful docstrings.

Sure

@ricardoV94 ricardoV94 force-pushed the custom_dist_timeseries branch from 3710b69 to 4e6510f Compare May 17, 2023 14:09
@ricardoV94

This comment was marked as outdated.

@ricardoV94 ricardoV94 force-pushed the custom_dist_timeseries branch from 4e6510f to 2977ba8 Compare May 19, 2023 11:42
This allows proper seeding in CustomDists with Scans
@ricardoV94 ricardoV94 force-pushed the custom_dist_timeseries branch from 2977ba8 to 5532cb2 Compare May 19, 2023 11:43
@ricardoV94
Copy link
Member Author

Updated docstrings and added one code example

@ricardoV94 ricardoV94 requested a review from jessegrabowski May 19, 2023 11:44
@ricardoV94 ricardoV94 merged commit bfbc8cc into pymc-devs:main May 23, 2023
@ricardoV94 ricardoV94 deleted the custom_dist_timeseries branch June 5, 2023 16:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants