
Problem with random switch #1013

Closed
DomagojGalic opened this issue Dec 7, 2019 · 10 comments

Comments

@DomagojGalic

I'm implementing a model from Cameron Davidson-Pilon's book Bayesian Methods for Hackers (link to the notebook).
The model I'm implementing is the one from the example on inferring behaviour from text-message data.

I can't seem to get the correct result.
Here is my code:

@model smscounter(y) = begin
    n = length(y)
    α = 1.0 / mean(y)
    λ1 ~ Exponential(α)
    λ2 ~ Exponential(α)
    τ ~ DiscreteUniform(1, n)
    
    for k in 1:τ
        y[k] ~ Poisson(λ1)
    end
    
    for k in (τ + 1):n
        y[k] ~ Poisson(λ2)
    end
end;

chain = sample(smscounter(data), MH(), 10_000);

Any comment on what I might be doing wrong would be much appreciated.
I'm using Julia 1.1.1 and Turing 0.7.1.

@cpfiffer
Member

cpfiffer commented Dec 7, 2019

Can you tell us what you are getting and what you are expecting to get?

@yebai
Member

yebai commented Dec 7, 2019

τ is a discrete parameter, and MH isn't expected to work with discrete parameters. Maybe try a Gibbs sampler, e.g. PG+MH or PG+HMC.
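For example, something along these lines (a sketch using the compositional-sampler syntax that appears later in this thread; PG updates the discrete τ, MH the continuous rates):

chain = sample(smscounter(data), Gibbs(PG(100, :τ), MH(:λ1, :λ2)), 10_000);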

@DomagojGalic
Author

Can you tell us what you are getting and what you are expecting to get?

I'm getting this:
Summary Statistics
Omitted printing of 1 columns

│ Row │ parameters │ mean    │ std         │ naive_se    │ mcse │ ess     │
│ 1   │ λ1         │ 3.81492 │ 4.88523e-15 │ 4.88523e-17 │ 0.0  │ 40.1606 │
│ 2   │ λ2         │ 4.89517 │ 7.10578e-15 │ 7.10578e-17 │ 0.0  │ 40.1606 │
│ 3   │ τ          │ 40.0    │ 0.0         │ 0.0         │ 0.0  │ NaN     │

Quantiles

│ Row │ parameters │ 2.5%    │ 25.0%   │ 50.0%   │ 75.0%   │ 97.5%   │
│ 1   │ λ1         │ 3.81492 │ 3.81492 │ 3.81492 │ 3.81492 │ 3.81492 │
│ 2   │ λ2         │ 4.89517 │ 4.89517 │ 4.89517 │ 4.89517 │ 4.89517 │
│ 3   │ τ          │ 40.0    │ 40.0    │ 40.0    │ 40.0    │ 40.0    │

This is what I get when implementing the same model in PyMC3:
post.pdf
You may have to squint a little to see the green bars that represent the distribution of τ.

As you can see, the Turing model returns τ as a deterministic value (its standard deviation is zero), and λ1 and λ2 are also essentially deterministic, with means of 3.8 and 4.9 respectively.
The PyMC3 model returns variables with far greater variability, and there λ1 and λ2 have means of about 17 and 23 respectively.

Here is the Python code, just in case:

import numpy as np
import pymc3 as pm

# count_data and n_count_data are defined earlier in the notebook
with pm.Model() as model:
    alpha = 1.0 / count_data.mean() 
    lambda_1 = pm.Exponential("lambda_1", alpha)
    lambda_2 = pm.Exponential("lambda_2", alpha)

    tau = pm.DiscreteUniform("tau", lower=0, upper=n_count_data - 1)

    idx = np.arange(n_count_data)
    lambda_ = pm.math.switch(tau >= idx, lambda_1, lambda_2)
    
with model:
    observation = pm.Poisson("obs", lambda_, observed=count_data)
    step = pm.Metropolis()
    trace = pm.sample(3000, tune=1000, step=step)

@DomagojGalic
Author

τ is a discrete parameter, and MH isn't expected to work with discrete parameters. Maybe try a Gibbs sampler, e.g. PG+MH or PG+HMC.

I tried

chain = sample(smscounter(y), Gibbs(PG(100, :τ), MH(:λ1, :λ2)), 10_000);

and also

chain = sample(smscounter(y), Gibbs(PG(100, :τ), HMC(0.05, 10, :λ1, :λ2)), 10_000);

In both cases I got
Stacktrace in the failed task:

KeyError

and then a mile-long stacktrace.

@cpfiffer
Member

cpfiffer commented Dec 8, 2019

This would be much easier to debug if we had a fully reproducible example (something we can copy-paste) that runs immediately. Can you provide a MWE (minimum working example) that we can work from? We don't have your dataset or any of your setup code prior to the model declaration.

@DomagojGalic
Author

This would be much easier to debug if we had a fully reproducible example (something we can copy-paste) that runs immediately. Can you provide a MWE (minimum working example) that we can work from? We don't have your dataset or any of your setup code prior to the model declaration.

Sure (the dataset is provided with the notebook I linked, but I'll just copy it here, since that's easier than downloading and loading it):

data = TArray(Int64, 74);
y = [13 24 8 24 7 35 14 11 15 11 22 22 11 57 11 19 29 6 19 12 22 12 18 72 32 9 7 13 19 23 27 20 6 17 13 10 14 6 16 15 7 2 15 15 19 70 49 7 53 22 21 31 19 11 18 20 12 35 17 23 17 4 2 31 30 13 27 0 39 37 5 14 13 22];

for i in 1:74
    data[i] = y[i]
end;

@model smscounter(y) = begin
    n = length(y)
    α = 1.0 / mean(y)
    λ1 ~ Exponential(α)
    λ2 ~ Exponential(α)
    τ ~ DiscreteUniform(1, n)
    
    for k in 1:τ
        y[k] ~ Poisson(λ1)
    end
    
    for k in (τ + 1):n
        y[k] ~ Poisson(λ2)
    end
end;

chain1 = sample(smscounter(data), MH(), 10_000);
# chain2 = sample(smscounter(data), Gibbs(PG(100, :τ), MH(:λ1, :λ2)), 10_000);
# chain3 = sample(smscounter(data), Gibbs(PG(100, :τ), HMC(0.05, 10, :λ1, :λ2)), 10_000);

@cpfiffer
Member

cpfiffer commented Dec 9, 2019

The problem here appears to be the use of TArray to wrap your data. TArray is really only for arrays of parameters, not for data. If you delete the TArray stuff and just use the plain vector y with one of the Gibbs samplers, you should be fine.

Modified code with a shortened sample size for illustrative purposes:

using Turing

y = [13 24 8 24 7 35 14 11 15 11 22 22 11 57 11 19 29 6 19 12 22 12 18 72 32 9 7 13 19 23 27 20 6 17 13 10 14 6 16 15 7 2 15 15 19 70 49 7 53 22 21 31 19 11 18 20 12 35 17 23 17 4 2 31 30 13 27 0 39 37 5 14 13 22];

@model smscounter(y) = begin
    n = length(y)
    α = 1.0 / mean(y)
    λ1 ~ Exponential(α)
    λ2 ~ Exponential(α)
    τ ~ DiscreteUniform(1, n)
    
    for k in 1:τ
        y[k] ~ Poisson(λ1)
    end
    
    for k in (τ + 1):n
        y[k] ~ Poisson(λ2)
    end
end;

chain = sample(smscounter(y), Gibbs(PG(100, :τ), MH(:λ1, :λ2)), 1_000);

Output:

Object of type Chains, with data of type 1000×4×1 Array{Union{Missing, Real},3}

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
internals         = lp
parameters        = λ1, λ2, τ

2-element Array{ChainDataFrame,1}

Summary Statistics
Omitted printing of 1 columns
│ Row │ parameters │ mean      │ std       │ naive_se    │ mcse       │ ess     │
│     │ Symbol     │ Float64   │ Float64   │ Float64     │ Float64    │ Any     │
├─────┼────────────┼───────────┼───────────┼─────────────┼────────────┼─────────┤
│ 1   │ λ1         │ 0.0177012 │ 0.0084283 │ 0.000266526 │ 0.00199435 │ 6.21103 │
│ 2   │ λ2         │ 0.324128  │ 0.0586637 │ 0.00185511  │ 0.0185999  │ 6.14019 │
│ 3   │ τ          │ 1.034     │ 0.759883  │ 0.0240296   │ 0.034      │ 6.49671 │

Quantiles

│ Row │ parameters │ 2.5%      │ 25.0%     │ 50.0%     │ 75.0%     │ 97.5%     │
│     │ Symbol     │ Float64   │ Float64   │ Float64   │ Float64   │ Float64   │
├─────┼────────────┼───────────┼───────────┼───────────┼───────────┼───────────┤
│ 1   │ λ1         │ 0.0157069 │ 0.0157069 │ 0.0157069 │ 0.0157069 │ 0.0593485 │
│ 2   │ λ2         │ 0.122249  │ 0.342728  │ 0.342728  │ 0.342728  │ 0.342728  │
│ 3   │ τ          │ 1.0       │ 1.0       │ 1.0       │ 1.0       │ 1.0       │

@DomagojGalic
Author

There's still the problem of wrong results; I tried your solution and got

│ Row │ parameters │ mean    │ std         │ naive_se    │ mcse   │ ess     │ r_hat  │
│ 1   │ λ1         │ 5.49626 │ 3.55289e-15 │ 3.55289e-17 │ 0.0    │ 40.1606 │ 0.9999 │
│ 2   │ λ2         │ 2.97791 │ 3.997e-15   │ 3.997e-17   │ 0.0    │ 40.1606 │ 0.9999 │
│ 3   │ τ          │ 73.9998 │ 0.0141414   │ 0.000141414 │ 0.0002 │ 41.2179 │ 1.0001 │

Quantiles

│ Row │ parameters │ 2.5%    │ 25.0%   │ 50.0%   │ 75.0%   │ 97.5%   │
│ 1   │ λ1         │ 5.49626 │ 5.49626 │ 5.49626 │ 5.49626 │ 5.49626 │
│ 2   │ λ2         │ 2.97791 │ 2.97791 │ 2.97791 │ 2.97791 │ 2.97791 │
│ 3   │ τ          │ 74.0    │ 74.0    │ 74.0    │ 74.0    │ 74.0    │

for

sample(smscounter(y), Gibbs(PG(100, :τ), MH(:λ1, :λ2)), 10_000)

and

Summary Statistics

│ Row │ parameters │ mean    │ std       │ naive_se    │ mcse      │ ess     │ r_hat   │
│     │ Symbol     │ Float64 │ Float64   │ Float64     │ Float64   │ Any     │ Any     │
│ 1   │ λ1         │ 14.7493 │ 2.76869   │ 0.0276869   │ 0.271849  │ 40.1606 │ 1.09066 │
│ 2   │ λ2         │ 0.16302 │ 0.384425  │ 0.00384425  │ 0.0382327 │ 40.1606 │ 1.08598 │
│ 3   │ τ          │ 73.9997 │ 0.0173188 │ 0.000173188 │ 0.0003    │ 41.2179 │ 1.0002  │

Quantiles

│ Row │ parameters │ 2.5%       │ 25.0%     │ 50.0%     │ 75.0%     │ 97.5%   │
│ 1   │ λ1         │ 5.61827    │ 15.1975   │ 15.5239   │ 15.818    │ 16.3676 │
│ 2   │ λ2         │ 0.00163632 │ 0.0159676 │ 0.0395861 │ 0.0838861 │ 1.45111 │
│ 3   │ τ          │ 74.0       │ 74.0      │ 74.0      │ 74.0      │ 74.0    │

for

sample(smscounter(y), Gibbs(PG(100, :τ), HMC(0.05, 10, :λ1, :λ2)), 10_000);

This is quite far from what I'm supposed to get.

@cpfiffer
Member

I think this is because of a difference in how PyMC3 and Distributions.jl parameterize the exponential distribution: PyMC3's Exponential takes a rate, while Distributions.jl's takes a scale (the mean). If you replace alpha with just the mean (and not the inverse mean), you should get something very similar to the output in the book. You can now also use MH if you'd like, which is quick-and-dirty.
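A quick check of the parameterization (a minimal sketch; in Distributions.jl the Exponential argument is the scale, i.e. the mean):

using Distributions

mean(Exponential(2.0))  # 2.0: the argument is the scale (mean), not the rate

The corrected model and samplers: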

using Turing

y = [13 24 8 24 7 35 14 11 15 11 22 22 11 57 11 19 29 6 19 12 22 12 18 72 32 9 7 13 19 23 27 20 6 17 13 10 14 6 16 15 7 2 15 15 19 70 49 7 53 22 21 31 19 11 18 20 12 35 17 23 17 4 2 31 30 13 27 0 39 37 5 14 13 22];

@model smscounter(y) = begin
    n = length(y)
    α = mean(y)
    λ1 ~ Exponential(α)
    λ2 ~ Exponential(α)
    τ ~ DiscreteUniform(1, n)
    
    for k in 1:τ
        y[k] ~ Poisson(λ1)
    end
    
    for k in (τ + 1):n
        y[k] ~ Poisson(λ2)
    end
end;

chain1 = sample(smscounter(y), MH(), 100_000);
chain2 = sample(smscounter(y), Gibbs(PG(100, :τ), HMC(0.05, 10, :λ1, :λ2)), 1_000);

Here's the output using MH:

Object of type Chains, with data of type 100000×4×1 Array{Union{Missing, Real},3}

Iterations        = 1:100000
Thinning interval = 1
Chains            = 1
Samples per chain = 100000
internals         = lp
parameters        = λ1, λ2, τ

2-element Array{ChainDataFrame,1}

Summary Statistics

│ Row │ parameters │ mean    │ std      │ naive_se   │ mcse      │ ess     │ r_hat   │
│     │ Symbol     │ Float64 │ Float64  │ Float64    │ Float64   │ Any     │ Any     │
├─────┼────────────┼─────────┼──────────┼────────────┼───────────┼─────────┼─────────┤
│ 1   │ λ1         │ 18.0277 │ 0.761006 │ 0.00240651 │ 0.023779  │ 401.606 │ 1.13277 │
│ 2   │ λ2         │ 22.6626 │ 2.31952  │ 0.00733498 │ 0.0729886 │ 401.606 │ 1.01702 │
│ 3   │ τ          │ 45.3502 │ 6.71208  │ 0.0212255  │ 0.205265  │ 401.606 │ 1.0116  │

Quantiles

│ Row │ parameters │ 2.5%    │ 25.0%   │ 50.0%   │ 75.0%   │ 97.5%   │
│     │ Symbol     │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
├─────┼────────────┼─────────┼─────────┼─────────┼─────────┼─────────┤
│ 1   │ λ1         │ 16.9919 │ 17.1646 │ 17.9925 │ 18.4843 │ 19.9866 │
│ 2   │ λ2         │ 15.2514 │ 22.1434 │ 22.6499 │ 23.7294 │ 26.2475 │
│ 3   │ τ          │ 43.0    │ 45.0    │ 45.0    │ 45.0    │ 70.0    │

Using PG and HMC together:

Object of type Chains, with data of type 1000×4×1 Array{Union{Missing, Real},3}

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
internals         = lp
parameters        = λ1, λ2, τ

2-element Array{ChainDataFrame,1}

Summary Statistics

│ Row │ parameters │ mean    │ std      │ naive_se  │ mcse      │ ess     │ r_hat    │
│     │ Symbol     │ Float64 │ Float64  │ Float64   │ Float64   │ Any     │ Any      │
├─────┼────────────┼─────────┼──────────┼───────────┼───────────┼─────────┼──────────┤
│ 1   │ λ1         │ 17.761  │ 0.812477 │ 0.0256928 │ 0.0159142 │ 1057.14 │ 0.999353 │
│ 2   │ λ2         │ 22.6707 │ 1.16725  │ 0.0369118 │ 0.130854  │ 12.4083 │ 1.02775  │
│ 3   │ τ          │ 43.514  │ 5.0817   │ 0.160697  │ 0.954879  │ 6.13911 │ 1.04248  │

Quantiles

│ Row │ parameters │ 2.5%    │ 25.0%   │ 50.0%   │ 75.0%   │ 97.5%   │
│     │ Symbol     │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
├─────┼────────────┼─────────┼─────────┼─────────┼─────────┼─────────┤
│ 1   │ λ1         │ 16.2957 │ 17.2178 │ 17.7442 │ 18.2866 │ 19.3092 │
│ 2   │ λ2         │ 20.6154 │ 22.0664 │ 22.7017 │ 23.3752 │ 24.5441 │
│ 3   │ τ          │ 37.95   │ 44.0    │ 44.0    │ 45.0    │ 45.0    │

@DomagojGalic
Author

Thank you, it never crossed my mind that this could be the problem.
