Skip to content
This repository has been archived by the owner on Nov 8, 2024. It is now read-only.

Fix predictive checks in Turing 0.21 #65

Merged
merged 1 commit into from
Oct 16, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions _literate/04_Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ setprogress!(false) # hide
p ~ Dirichlet(6, 1)

#Each outcome of the six-sided dice has a probability p.
y ~ filldist(Categorical(p), length(y))
for i in eachindex(y)
y[i] ~ Categorical(p)
end
end;

# Here we are using the [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) which
Expand All @@ -134,10 +136,9 @@ sum(mean(Dirichlet(6, 1)))

# Also, since the outcome of a [Categorical distribution](https://en.wikipedia.org/wiki/Categorical_distribution) is an integer
# and `y` is a $N$-dimensional vector of integers we need to apply some sort of broadcasting here.
# `filldist()` is a nice Turing function which takes any univariate or multivariate distribution and returns another distribution that repeats the input distribution.
# We could also use the familiar dot `.` broadcasting operator in Julia:
# We could use the familiar dot `.` broadcasting operator in Julia:
# `y .~ Categorical(p)` to signal that all elements of `y` are distributed as a Categorical distribution.
# But doing that does not allow us to do predictive checks (more on this below). So, instead we use `filldist()`.
# But doing that does not allow us to do predictive checks (more on this below). So, instead we use a `for`-loop.

# ### Simulating Data

Expand Down Expand Up @@ -240,21 +241,15 @@ savefig(joinpath(@OUTPUT, "chain.svg")); # hide
prior_chain = sample(model, Prior(), 2_000);

# Now we can perform predictive checks using both the prior (`prior_chain`) or posterior (`chain`) distributions.
# To draw from the prior and posterior predictive distributions we instantiate a "predictive model", *i.e.* a Turing model but with the observations set to `missing`[^missing], and then calling `predict()` on the predictive model and the previously drawn samples.
# To draw from the prior and posterior predictive distributions we instantiate a "predictive model", *i.e.* a Turing model but with the observations set to `missing`, and then calling `predict()` on the predictive model and the previously drawn samples.
# First let's do the *prior* predictive check:

missing_data = Vector{Missing}(missing, length(data)) # vector of `missing`
model_missing = dice_throw(missing_data)
model_predict = DynamicPPL.Model{(:y,)}(:model_predict_missing_data,
model_missing.f,
model_missing.args,
model_missing.defaults) # instantiate the "predictive model"
prior_check = predict(model_predict, prior_chain);
missing_data = similar(data, Missing) # vector of `missing`
model_missing = dice_throw(missing_data) # instantiate the "predictive model
prior_check = predict(model_missing, prior_chain);

# Here we are creating a `missing_data` object which is a `Vector` of the same length as the `data` and populated with type `missing` as values.
# We then instantiate a new `dice_throw` model with the `missing_data` vector as the `data` argument.
# We proceed by instantiating a new Turing `DynamicPPL.Model` model with the `missing_data` vector as the `data` argument.
# The boilerplate around `DynamicPPL.Model` is the default arguments that a `DynamicPPL.Model` model needs to have.
# Finally, we call `predict()` on the predictive model and the previously drawn samples, which in our case are the samples from the prior distribution (`prior_chain`).

# Note that `predict()` returns a `Chains` object from `MCMCChains.jl`:
Expand All @@ -267,7 +262,7 @@ summarystats(prior_check[:, 1:5, :]) # just the first 5 prior samples

# We can do the same with `chain` for a *posterior* predictive check:

posterior_check = predict(model_predict, chain);
posterior_check = predict(model_missing, chain);
summarystats(posterior_check[:, 1:5, :]) # just the first 5 posterior samples

# ## Conclusion
Expand All @@ -284,7 +279,6 @@ summarystats(posterior_check[:, 1:5, :]) # just the first 5 posterior samples
# [^MCMC]: see [5. **Markov Chain Monte Carlo (MCMC)**](/pages/5_MCMC/).
# [^visualization]: we'll cover those plots and diagnostics in [5. **Markov Chain Monte Carlo (MCMC)**](/pages/5_MCMC/).
# [^workflow]: note that this workflow is a extremely simplified adaptation from the original workflow on which it was based. I suggest the reader to consult the original workflow of Gelman et al. (2020).
# [^missing]: in a real-world scenario, you'll probably want to use more than just **one** observation as a predictive check, so you should use something like `Vector{Missing}(missing, length(y))` or `fill(missing, length(y)`.

# ## References

Expand Down