uniform dirichlet prior with stickbreaking transform + ADVI #4733

Open
harrig12 opened this issue Jun 2, 2021 · 4 comments
Labels
question VI Variational Inference

Comments

harrig12 commented Jun 2, 2021

Description of your problem

My Dirichlet prior does not behave as I would expect when using ADVI. For a uniform prior (a=1), the posterior density of the last element is way off, and there are lots of divergences in the trace plot.

Reading through #4129, I wonder whether it has to do with Km1, because the problem is noticeably exacerbated when the size of the Dirichlet is increased (here I set it to 30 to demonstrate).

The reason I noticed this is that my "uniform" prior does not actually look uniform at all. The unexpected behavior lessens over the course of training as my model learns, but in some settings the model is greatly hampered by the very biased prior that is apparently being created. I have fit ADVI for only a single step so as to show this.

Works as expected with NUTS

import pymc3 as pm
import numpy as np
import pandas as pd

with pm.Model() as model:
    # Uniform Dirichlet prior over a 30-dimensional simplex
    decomp = pm.Dirichlet('decomp', np.ones(30), shape=30)
    trace1 = pm.sample(100, return_inferencedata=False)  # NUTS by default
    pm.plot_trace(trace1, var_names=['decomp']);

# Densities in the transformed (unconstrained) space and on the simplex
pd.DataFrame(trace1['decomp_stickbreaking__']).plot.kde(figsize=(10,4), legend=False);
pd.DataFrame(trace1['decomp']).plot.kde(figsize=(10,4), legend=False);

image

Strange result for the last Dirichlet element when using ADVI

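with pm.Model() as model:
    decomp = pm.Dirichlet('decomp', np.ones(30), shape=30)
    advi = pm.ADVI()
    advi.fit(1)  # a single step, so the approximation has barely moved from its initialization
    # Draw once and reuse the same samples for every plot
    approx_trace = advi.approx.sample(100)
    pm.plot_trace(approx_trace, var_names=['decomp']);

# Densities in the transformed (unconstrained) space and on the simplex
pd.DataFrame(approx_trace['decomp_stickbreaking__']).plot.kde(figsize=(10,4), legend=False);
pd.DataFrame(approx_trace['decomp']).plot.kde(figsize=(10,4), legend=False);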

image
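
One way to probe the Km1 hypothesis without PyMC3 is to push Gaussian noise through a simplified stand-in for the transform. The sketch below assumes that the backward map appends the negative sum of the K-1 unconstrained values as the last logit and then applies a softmax, and that the approximation starts out as independent standard normals. Both are assumptions made for illustration, not taken from the PyMC3 source, and the helper names are hypothetical. Under those assumptions the last logit has variance K-1 while each of the others has variance 1, so the last component is not exchangeable with the rest and the gap grows with K.

```python
import math
import random


def simplex_from_unconstrained(y):
    """Map K-1 unconstrained values to a point on the K-simplex.

    ASSUMPTION: the last logit is the negative sum of the others, followed
    by a softmax. This is a stand-in, not the verified PyMC3 transform.
    """
    logits = list(y) + [-sum(y)]
    top = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(v - top) for v in logits]
    total = sum(weights)
    return [w / total for w in weights]


def last_component_summary(k, draws=2000, seed=0):
    """Push independent N(0, 1) draws through the map above.

    Returns (variance of the last logit,
             share of draws in which the last simplex component is the largest).
    """
    rng = random.Random(seed)
    last_logits = []
    last_is_largest = 0
    for _ in range(draws):
        y = [rng.gauss(0.0, 1.0) for _ in range(k - 1)]
        x = simplex_from_unconstrained(y)
        last_logits.append(-sum(y))
        if x[-1] == max(x):
            last_is_largest += 1
    mean = sum(last_logits) / draws
    variance = sum((v - mean) ** 2 for v in last_logits) / draws
    return variance, last_is_largest / draws


if __name__ == "__main__":
    for k in (5, 30):
        variance, share = last_component_summary(k)
        print(f"K={k}: var(last logit)={variance:.1f}, "
              f"share(last is largest)={share:.2f}, exchangeable share={1 / k:.2f}")
```

If the real transform behaves like this stand-in, that would be consistent with the last element looking off and with the effect growing at size 30, but it does not by itself confirm a bug.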

Versions and main components

  • PyMC3 Version: 3.11.2
  • Aesara Version: n/a
  • Python Version: 3.8.8
  • Operating system: CentOS Linux 7 (Core)
  • How did you install PyMC3: conda

harrig12 commented Jul 28, 2021

I'm pretty sure this is a bug in the stickbreaking transform. I have the same problem with a low-dimensional Dirichlet (e.g. K=5).

If anyone else finds themselves here, my workaround is to draw a set of Gamma variables instead, which gives the desired behaviour. (The default transform for a Gamma is the log transform, not stickbreaking.) The simplex constraint is now enforced in the untransformed space.
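
The issue does not show the workaround's code, so the following is only a sketch of the standard construction it appears to rely on: independent Gamma(a_i, 1) draws divided by their sum are distributed as Dirichlet(a). It uses the standard library rather than PyMC3, and the function names are hypothetical. In a PyMC3 model the same idea would presumably be a vector of `pm.Gamma` variables normalized inside a `pm.Deterministic`, but that mapping is an assumption.

```python
import random


def dirichlet_via_gammas(alpha, rng):
    """One Dirichlet(alpha) draw built from independent Gamma(alpha_i, 1) draws."""
    gammas = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(gammas)
    # Normalizing puts the draw on the simplex
    return [g / total for g in gammas]


def component_means(k, draws=5000, seed=0):
    """Monte Carlo mean of each component of a uniform Dirichlet of size k."""
    rng = random.Random(seed)
    sums = [0.0] * k
    for _ in range(draws):
        x = dirichlet_via_gammas([1.0] * k, rng)
        for i, value in enumerate(x):
            sums[i] += value
    return [s / draws for s in sums]


if __name__ == "__main__":
    # Every component, including the last, should average close to 1/K
    print([round(m, 3) for m in component_means(5)])
```

Because every component is built the same way, none of them is singled out the way the last element is in the plots from the Dirichlet version.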

Dirichlet bad:

image

Gamma good:

image

@fonnesbeck
Member

I wonder if the problem would also go away if you used fullrank_advi instead of the default diagonal ADVI?

@harrig12
Author

Same issue, with a slightly different shape in the posterior... interesting!

image

@harrig12
Author

Similarly observed with the svgd and asvgd methods.
