
Fix predictive checks in Turing 0.21 (#65)
Closes #64.
storopoli authored Oct 16, 2022
1 parent 9fdb662 commit 18d6ee0
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions _literate/04_Turing.jl
@@ -113,7 +113,9 @@ setprogress!(false) # hide
    p ~ Dirichlet(6, 1)

    # Each outcome of the six-sided dice has a probability p.
-    y ~ filldist(Categorical(p), length(y))
+    for i in eachindex(y)
+        y[i] ~ Categorical(p)
+    end
end;

# Here we are using the [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) which
@@ -134,10 +136,9 @@ sum(mean(Dirichlet(6, 1)))

# Also, since the outcome of a [Categorical distribution](https://en.wikipedia.org/wiki/Categorical_distribution) is an integer
# and `y` is an $N$-dimensional vector of integers, we need to apply some sort of broadcasting here.
-# `filldist()` is a nice Turing function which takes any univariate or multivariate distribution and returns another distribution that repeats the input distribution.
-# We could also use the familiar dot `.` broadcasting operator in Julia:
+# We could use the familiar dot `.` broadcasting operator in Julia:
# `y .~ Categorical(p)` to signal that all elements of `y` are distributed as a Categorical distribution.
-# But doing that does not allow us to do predictive checks (more on this below). So, instead we use `filldist()`.
+# But doing that does not allow us to do predictive checks (more on this below). So, instead we use a `for`-loop.
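
# Putting the pieces together, the updated model reads roughly as follows
# (a sketch: the `@model function dice_throw(y)` signature is assumed from the
# part of the file collapsed in this diff):

@model function dice_throw(y)
    # A uniform Dirichlet prior over the six face probabilities (they sum to 1).
    p ~ Dirichlet(6, 1)

    # Each outcome of the six-sided dice has a probability p,
    # written as a loop so that predictive checks work in Turing 0.21.
    for i in eachindex(y)
        y[i] ~ Categorical(p)
    end
end;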

# ### Simulating Data

@@ -240,21 +241,15 @@ savefig(joinpath(@OUTPUT, "chain.svg")); # hide
prior_chain = sample(model, Prior(), 2_000);

# Now we can perform predictive checks using either the prior (`prior_chain`) or the posterior (`chain`) distribution.
-# To draw from the prior and posterior predictive distributions we instantiate a "predictive model", *i.e.* a Turing model but with the observations set to `missing`[^missing], and then calling `predict()` on the predictive model and the previously drawn samples.
+# To draw from the prior and posterior predictive distributions, we instantiate a "predictive model", *i.e.* a Turing model but with the observations set to `missing`, and then call `predict()` on the predictive model and the previously drawn samples.
# First let's do the *prior* predictive check:

-missing_data = Vector{Missing}(missing, length(data)) # vector of `missing`
-model_missing = dice_throw(missing_data)
-model_predict = DynamicPPL.Model{(:y,)}(:model_predict_missing_data,
-                                        model_missing.f,
-                                        model_missing.args,
-                                        model_missing.defaults) # instantiate the "predictive model"
-prior_check = predict(model_predict, prior_chain);
+missing_data = similar(data, Missing) # vector of `missing`
+model_missing = dice_throw(missing_data) # instantiate the "predictive model"
+prior_check = predict(model_missing, prior_chain);

# Here we are creating a `missing_data` object, which is a `Vector` of the same length as `data` whose elements are all `missing`.
# We then instantiate a new `dice_throw` model with the `missing_data` vector as the `data` argument.
-# We proceed by instantiating a new Turing `DynamicPPL.Model` model with the `missing_data` vector as the `data` argument.
-# The boilerplate around `DynamicPPL.Model` is the default arguments that a `DynamicPPL.Model` model needs to have.
# Finally, we call `predict()` on the predictive model and the previously drawn samples, which in our case are the samples from the prior distribution (`prior_chain`).
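
# Note that `similar(data, Missing)` is just one way to build such a vector;
# alternatives like `Vector{Missing}(missing, length(data))` or `fill(missing, length(data))`
# produce the same all-`missing` vector:

isequal(missing_data, Vector{Missing}(missing, length(data))) # true
isequal(missing_data, fill(missing, length(data))) # true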

# Note that `predict()` returns a `Chains` object from `MCMCChains.jl`:
@@ -267,7 +262,7 @@ summarystats(prior_check[:, 1:5, :]) # just the first 5 prior samples

# We can do the same with `chain` for a *posterior* predictive check:

-posterior_check = predict(model_predict, chain);
+posterior_check = predict(model_missing, chain);
summarystats(posterior_check[:, 1:5, :]) # just the first 5 posterior samples
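
# As a quick sanity check (a sketch, assuming `data` holds the observed throws as above),
# we can compare the observed frequency of each face with the frequency implied by the
# posterior predictive draws:

predicted = Array(posterior_check)                      # draws × observations matrix
observed_freqs = [mean(data .== k) for k in 1:6]        # observed proportion of each face
predicted_freqs = [mean(predicted .== k) for k in 1:6]  # posterior predictive proportion of each face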

# ## Conclusion
@@ -284,7 +279,6 @@ summarystats(posterior_check[:, 1:5, :]) # just the first 5 posterior samples
# [^MCMC]: see [5. **Markov Chain Monte Carlo (MCMC)**](/pages/5_MCMC/).
# [^visualization]: we'll cover those plots and diagnostics in [5. **Markov Chain Monte Carlo (MCMC)**](/pages/5_MCMC/).
# [^workflow]: note that this workflow is an extremely simplified adaptation of the original workflow on which it was based. I suggest the reader consult the original workflow of Gelman et al. (2020).
-# [^missing]: in a real-world scenario, you'll probably want to use more than just **one** observation as a predictive check, so you should use something like `Vector{Missing}(missing, length(y))` or `fill(missing, length(y))`.

# ## References

