From 18d6ee0275b7b7d3b67f7b17b38a453eccea76bc Mon Sep 17 00:00:00 2001
From: Jose Storopoli
Date: Sun, 16 Oct 2022 05:36:54 -0300
Subject: [PATCH] Fix predictive checks in Turing 0.21 (#65)

Closes #64.
---
 _literate/04_Turing.jl | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/_literate/04_Turing.jl b/_literate/04_Turing.jl
index f33e62c1..65319d61 100644
--- a/_literate/04_Turing.jl
+++ b/_literate/04_Turing.jl
@@ -113,7 +113,9 @@ setprogress!(false) # hide
     p ~ Dirichlet(6, 1)

     #Each outcome of the six-sided dice has a probability p.
-    y ~ filldist(Categorical(p), length(y))
+    for i in eachindex(y)
+        y[i] ~ Categorical(p)
+    end
 end;

 # Here we are using the [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) which
@@ -134,10 +136,9 @@ sum(mean(Dirichlet(6, 1)))

 # Also, since the outcome of a [Categorical distribution](https://en.wikipedia.org/wiki/Categorical_distribution) is an integer
 # and `y` is a $N$-dimensional vector of integers we need to apply some sort of broadcasting here.
-# `filldist()` is a nice Turing function which takes any univariate or multivariate distribution and returns another distribution that repeats the input distribution.
-# We could also use the familiar dot `.` broadcasting operator in Julia:
+# We could use the familiar dot `.` broadcasting operator in Julia:
 # `y .~ Categorical(p)` to signal that all elements of `y` are distributed as a Categorical distribution.
-# But doing that does not allow us to do predictive checks (more on this below). So, instead we use `filldist()`.
+# But doing that does not allow us to do predictive checks (more on this below). So, instead we use a `for`-loop.

 # ### Simulating Data
@@ -240,21 +241,15 @@ savefig(joinpath(@OUTPUT, "chain.svg")); # hide
 prior_chain = sample(model, Prior(), 2_000);

 # Now we can perform predictive checks using both the prior (`prior_chain`) or posterior (`chain`) distributions.
-# To draw from the prior and posterior predictive distributions we instantiate a "predictive model", *i.e.* a Turing model but with the observations set to `missing`[^missing], and then calling `predict()` on the predictive model and the previously drawn samples.
+# To draw from the prior and posterior predictive distributions we instantiate a "predictive model", *i.e.* a Turing model but with the observations set to `missing`, and then call `predict()` on the predictive model and the previously drawn samples.

 # First let's do the *prior* predictive check:

-missing_data = Vector{Missing}(missing, length(data)) # vector of `missing`
-model_missing = dice_throw(missing_data)
-model_predict = DynamicPPL.Model{(:y,)}(:model_predict_missing_data,
-    model_missing.f,
-    model_missing.args,
-    model_missing.defaults) # instantiate the "predictive model"
-prior_check = predict(model_predict, prior_chain);
+missing_data = similar(data, Missing) # vector of `missing`
+model_missing = dice_throw(missing_data) # instantiate the "predictive model"
+prior_check = predict(model_missing, prior_chain);

 # Here we are creating a `missing_data` object which is a `Vector` of the same length as the `data` and populated with type `missing` as values.
 # We then instantiate a new `dice_throw` model with the `missing_data` vector as the `data` argument.
-# We proceed by instantiating a new Turing `DynamicPPL.Model` model with the `missing_data` vector as the `data` argument.
-# The boilerplate around `DynamicPPL.Model` is the default arguments that a `DynamicPPL.Model` model needs to have.
 # Finally, we call `predict()` on the predictive model and the previously drawn samples, which in our case are the samples from the prior distribution (`prior_chain`).
 # Note that `predict()` returns a `Chains` object from `MCMCChains.jl`:
@@ -267,7 +262,7 @@ summarystats(prior_check[:, 1:5, :]) # just the first 5 prior samples

 # We can do the same with `chain` for a *posterior* predictive check:

-posterior_check = predict(model_predict, chain);
+posterior_check = predict(model_missing, chain);
 summarystats(posterior_check[:, 1:5, :]) # just the first 5 posterior samples

 # ## Conclusion
@@ -284,7 +279,6 @@ summarystats(posterior_check[:, 1:5, :]) # just the first 5 posterior samples
 # [^MCMC]: see [5. **Markov Chain Monte Carlo (MCMC)**](/pages/5_MCMC/).
 # [^visualization]: we'll cover those plots and diagnostics in [5. **Markov Chain Monte Carlo (MCMC)**](/pages/5_MCMC/).
 # [^workflow]: note that this workflow is a extremely simplified adaptation from the original workflow on which it was based. I suggest the reader to consult the original workflow of Gelman et al. (2020).
-# [^missing]: in a real-world scenario, you'll probably want to use more than just **one** observation as a predictive check, so you should use something like `Vector{Missing}(missing, length(y))` or `fill(missing, length(y)`.

 # ## References
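
For reference, below is a minimal, self-contained sketch of the predictive-check workflow as it looks after this patch, assuming Turing 0.21 or later. It mirrors the patched tutorial code; the `data = rand(1:6, 1_000)` stand-in and the `NUTS()` sampler call are illustrative assumptions, not taken verbatim from `04_Turing.jl`.

using Turing

@model function dice_throw(y)
    # Prior over the probabilities of the six faces (uniform on the simplex).
    p ~ Dirichlet(6, 1)
    # Observe each throw in a loop so that `predict()` can fill in `missing` values.
    for i in eachindex(y)
        y[i] ~ Categorical(p)
    end
end;

data = rand(1:6, 1_000)                     # illustrative stand-in for the simulated throws
model = dice_throw(data)
chain = sample(model, NUTS(), 1_000)        # posterior samples
prior_chain = sample(model, Prior(), 2_000) # prior samples

missing_data = similar(data, Missing)       # vector of `missing` with the same length as `data`
model_missing = dice_throw(missing_data)    # the "predictive model"
prior_check = predict(model_missing, prior_chain)     # prior predictive draws
posterior_check = predict(model_missing, chain)       # posterior predictive draws
summarystats(posterior_check[:, 1:5, :])    # just the first 5 predicted observations

Both `predict()` calls return a `Chains` object, so the same `summarystats` inspection works for the prior and the posterior predictive draws.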