From ad1ab3eb89286fbfc19106e4fb99ccd2bbb38407 Mon Sep 17 00:00:00 2001 From: Mattias Villani Date: Mon, 4 Sep 2023 23:02:58 +0200 Subject: [PATCH 1/3] Update script.jl Fixed some typos in the model and one in the code (variance was used instead of standard deviation in measurement model). Added PGAS sampling at the end to show that it solves the degeneracy problem, which should close the issue https://github.com/TuringLang/AdvancedPS.jl/issues/77 --- examples/particle-gibbs/script.jl | 54 ++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index 9b084808..5bb61f4c 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -1,4 +1,3 @@ -# # Particle Gibbs for non-linear models using AdvancedPS using Random using Distributions @@ -13,34 +12,33 @@ using Libtask # x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, r^2) # ``` # ```math -# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad v_{t} \sim \mathcal{N}(0, 1) +# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad e_t \sim \mathcal{N}(0, 1) # ``` # -# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$. # We can reformulate the above in terms of transition and observation densities: # ```math -# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2) +# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, r^2) # ``` # ```math # y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(0, \exp(\frac{1}{2}x_t)^2) # ``` # with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$. +# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$. Parameters = @NamedTuple begin a::Float64 q::Float64 T::Int end -mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel +mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel X::Array θ::Parameters NonLinearTimeSeries(θ::Parameters) = new(zeros(Float64, θ.T), θ) end f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q) -g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state)^2) +g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state)) f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q) -#md nothing #hide # Let's simulate some data a = 0.9 # State Variance @@ -88,8 +86,8 @@ end # Here we use the particle gibbs kernel without adaptive resampling. model = NonLinearTimeSeries(θ₀) -pgas = AdvancedPS.PG(Nₚ, 1.0) -chains = sample(rng, model, pgas, Nₛ; progress=false); +pg = AdvancedPS.PG(Nₚ, 1.0) +chains = sample(model, pg, Nₛ); #md nothing #hide # The trajectories are not stored during the sampling and we need to regenerate the history of each @@ -116,8 +114,7 @@ mean_trajectory = mean(particles; dims=2) #md nothing #hide # We can now plot all the generated traces. -# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help -# with the degeneracy problem. +# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below. scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state") plot!(x; color=:darkorange, label="Original Trajectory") plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) @@ -137,3 +134,38 @@ plot( ylabel="Update rate", ) hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)") + +# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS. +# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way: +AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model) +AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step) +function AdvancedPS.observation(model::NonLinearTimeSeries, state, step) + return logpdf(g(model, state, step), y[step]) +end +AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ + +# We can now sample from the model using the PGAS sampler and collect the trajectories. +pg = AdvancedPS.PGAS(Nₚ) +chains = sample(model, pg, Nₛ); +trajectories = map(chains) do sample + replay(sample.trajectory) +end +particles = hcat([trajectory.model.f.X for trajectory in trajectories]...) +mean_trajectory = mean(particles; dims=2) + +# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods. +scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state") +plot!(x; color=:darkorange, label="Original Trajectory") +plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9) + +# The update rate is now much higher throughout time. +update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ +plot( + update_rate; + label=false, + ylim=[0, 1], + legend=:bottomleft, + xlabel="Iteration", + ylabel="Update rate", +) +hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)") From 2f9b66d1ba15a4602caee89636abc0a67d9cfa9a Mon Sep 17 00:00:00 2001 From: Mattias Villani Date: Tue, 5 Sep 2023 10:48:50 +0200 Subject: [PATCH 2/3] Update script.jl added back the rng argument when sampling with PG. --- examples/particle-gibbs/script.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index 5bb61f4c..e72e1fc7 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -87,7 +87,7 @@ end # Here we use the particle gibbs kernel without adaptive resampling. model = NonLinearTimeSeries(θ₀) pg = AdvancedPS.PG(Nₚ, 1.0) -chains = sample(model, pg, Nₛ); +chains = sample(rng, model, pg, Nₛ; progress=false); #md nothing #hide # The trajectories are not stored during the sampling and we need to regenerate the history of each From fb2203b49affa664afcc5a41d92d9c69d05d2dc0 Mon Sep 17 00:00:00 2001 From: Mattias Villani Date: Fri, 15 Sep 2023 11:13:05 +0200 Subject: [PATCH 3/3] Update examples/particle-gibbs/script.jl Co-authored-by: FredericWantiez --- examples/particle-gibbs/script.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index e72e1fc7..de955fc8 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -147,9 +147,6 @@ AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ # We can now sample from the model using the PGAS sampler and collect the trajectories. pg = AdvancedPS.PGAS(Nₚ) chains = sample(model, pg, Nₛ); -trajectories = map(chains) do sample - replay(sample.trajectory) -end particles = hcat([trajectory.model.f.X for trajectory in trajectories]...) mean_trajectory = mean(particles; dims=2)