Update script.jl #79

Merged
merged 3 commits into from
Sep 17, 2023
Changes from 2 commits
54 changes: 43 additions & 11 deletions examples/particle-gibbs/script.jl
@@ -1,4 +1,3 @@
# # Particle Gibbs for non-linear models
using AdvancedPS
using Random
using Distributions
@@ -13,34 +12,33 @@ using Libtask
# x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, r^2)
# ```
# ```math
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad v_{t} \sim \mathcal{N}(0, 1)
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad e_t \sim \mathcal{N}(0, 1)
# ```
#
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
# We can reformulate the above in terms of transition and observation densities:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2)
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, r^2)
# ```
# ```math
# y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(0, \exp(\frac{1}{2}x_t)^2)
# ```
# with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$.
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
Parameters = @NamedTuple begin
a::Float64
q::Float64
T::Int
end

mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
X::Array
θ::Parameters
NonLinearTimeSeries(θ::Parameters) = new(zeros(Float64, θ.T), θ)
end

f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q)
g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state)^2)
g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state))
f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q)
#md nothing #hide
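
The change to `g` in this hunk follows from the observation equation: if $y_t = e_t \exp(\frac{1}{2}x_t)$ with $e_t \sim \mathcal{N}(0, 1)$, then the standard deviation of $y_t$ given $x_t$ is $\exp(\frac{1}{2}x_t)$, and `Normal` in Distributions.jl takes a standard deviation as its second argument, not a variance. A minimal sketch that checks this by simulation, using only the standard library; the seed, the value of `x_t`, and the sample size are arbitrary assumptions and are not part of the PR:

```julia
using Random
using Statistics

rng = MersenneTwister(1)   # arbitrary seed, chosen only for this illustration
x_t = 1.2                  # hypothetical latent state value
σ = exp(0.5 * x_t)         # standard deviation implied by y_t = e_t * exp(x_t / 2)

# Simulate y_t = e_t * exp(x_t / 2) with e_t ~ N(0, 1)
y = randn(rng, 200_000) .* σ

# The empirical standard deviation should be close to σ, not to σ^2
empirical_std = std(y)
```

With these values `σ` is roughly 1.8 while `σ^2` is roughly 3.3, so squaring the scale, as the removed line did, would noticeably overstate the observation noise.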

# Let's simulate some data
a = 0.9 # State Variance
@@ -88,8 +86,8 @@ end

# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pgas = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pgas, Nₛ; progress=false);
@FredericWantiez (Member), Sep 4, 2023:

Minor comment, but why remove the rng argument? It makes the documentation reproducible (to some extent, since RNG implementations can change between Julia versions).

Contributor Author:

Sorry, I removed that while trying to get PGAS to work. I have edited the line so that it again takes rng as the first argument and also passes progress=false, as in the original example.
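
The reproducibility point in this thread can be illustrated with a small standard-library sketch. It assumes a `MersenneTwister` generator and an arbitrary seed; how `rng` is actually constructed in the script is not visible in this diff.

```julia
using Random

# Two generators created with the same (arbitrary) seed
rng_a = MersenneTwister(2342)
rng_b = MersenneTwister(2342)

draws_a = randn(rng_a, 5)
draws_b = randn(rng_b, 5)

# Same seed on the same Julia version: the two streams are identical
same_seed_match = draws_a == draws_b

# Draws from the global RNG are not tied to the seed above,
# which is what happens when no rng is passed explicitly
global_draws = randn(5)
```

Passing an explicitly seeded generator as the first argument to `sample` plays the same role: reruns of the documentation produce the same chain as long as the generator implementation does not change.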

pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide

# The trajectories are not stored during the sampling and we need to regenerate the history of each
@@ -116,8 +114,7 @@ mean_trajectory = mean(particles; dims=2)
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help
# with the degeneracy problem.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
@@ -137,3 +134,38 @@ plot(
ylabel="Update rate",
)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")

# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# We can now sample from the model using the PGAS sampler and collect the trajectories.
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(model, pg, Nₛ);
trajectories = map(chains) do sample
replay(sample.trajectory)
end
particles = hcat([trajectory.model.f.X for trajectory in trajectories]...)
mean_trajectory = mean(particles; dims=2)

# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

# The update rate is now much higher throughout time.
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
plot(
update_rate;
label=false,
ylim=[0, 1],
legend=:bottomleft,
xlabel="Iteration",
ylabel="Update rate",
)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")
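
To make the update-rate expression used in both plots concrete, here is a minimal sketch on a toy matrix. The values and the names `toy_particles`, `n_samples`, and `toy_update_rate` are made up for illustration; the layout (one row per time step, one column per Gibbs iteration) is assumed to match the `particles` matrix built with `hcat` in the script.

```julia
# Toy matrix: 2 time steps (rows) by 4 Gibbs iterations (columns); values are made up
toy_particles = [1.0 1.0 2.0 2.0;
                 0.5 0.7 0.9 1.1]
n_samples = size(toy_particles, 2)

# For each time step, mark the iterations where the state differs from the previous iteration
changes = abs.(diff(toy_particles; dims=2)) .> 0

# Fraction of iterations in which the state at each time step was updated
toy_update_rate = sum(changes; dims=2) / n_samples
```

In this toy case the first time step changes once and the second changes three times across four iterations, giving rates of 0.25 and 0.75. Because `diff` yields one fewer column than the input while the sum is divided by the full number of iterations, the rate computed this way cannot exceed `(n_samples - 1) / n_samples`.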