-
Notifications
You must be signed in to change notification settings - Fork 223
/
Copy pathipmcmc.jl
138 lines (113 loc) · 4.45 KB
/
ipmcmc.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
    IPMCMC(n_particles::Int, n_iters::Int)
    IPMCMC(n_particles::Int, n_iters::Int, n_nodes::Int)
    IPMCMC(n_particles::Int, n_iters::Int, n_nodes::Int, n_csmc_nodes::Int, space...)

Interacting particle Markov chain Monte Carlo (iPMCMC) sampler.

Each iteration runs `n_nodes` particle filters with `n_particles` particles each:
the first `n_csmc_nodes` are conditional SMC (CSMC) nodes and the remaining
`n_nodes - n_csmc_nodes` are unconditional SMC nodes. `n_iters` iterations are run.

Defaults: with two arguments, `n_nodes = 32` and `n_csmc_nodes = 16`; with three
arguments, `n_csmc_nodes = cld(n_nodes, 2)` (half of the nodes, rounded up).
Optional trailing `space` symbols restrict the sampling space; an empty space
means all variables.

Usage:
```julia
IPMCMC(100, 100, 4, 2)
```

Example:
```julia
# Define a simple Normal model with unknown mean and variance.
@model gdemo(x) = begin
    s ~ InverseGamma(2,3)
    m ~ Normal(0,sqrt(s))
    x[1] ~ Normal(m, sqrt(s))
    x[2] ~ Normal(m, sqrt(s))
    return s, m
end

sample(gdemo([1.5, 2]), IPMCMC(100, 100, 4, 2))
```
"""
mutable struct IPMCMC <: InferenceAlgorithm
    n_particles  :: Int       # number of particles used by each node
    n_iters      :: Int       # number of iterations
    n_nodes      :: Int       # total number of nodes running SMC and CSMC
    n_csmc_nodes :: Int       # number of CSMC (conditional) nodes
    resampler    :: Function  # function to resample
    space        :: Set       # sampling space, empty means all
    gid          :: Int       # group ID

    IPMCMC(n1::Int, n2::Int) = new(n1, n2, 32, 16, resampleSystematic, Set(), 0)
    # By default half of the nodes (rounded up) are conditional; `cld` is exact
    # integer ceiling division, avoiding a round-trip through floating point.
    IPMCMC(n1::Int, n2::Int, n3::Int) =
        new(n1, n2, n3, cld(n3, 2), resampleSystematic, Set(), 0)
    IPMCMC(n1::Int, n2::Int, n3::Int, n4::Int) =
        new(n1, n2, n3, n4, resampleSystematic, Set(), 0)
    # Varargs always arrive as a Tuple, so they convert to a Set directly.
    IPMCMC(n1::Int, n2::Int, n3::Int, n4::Int, space...) =
        new(n1, n2, n3, n4, resampleSystematic, Set(space), 0)
    # Copy `alg` with a different group ID.
    IPMCMC(alg::IPMCMC, new_gid::Int) = new(
        alg.n_particles, alg.n_iters, alg.n_nodes, alg.n_csmc_nodes,
        alg.resampler, alg.space, new_gid,
    )
end
"""
    Sampler(alg::IPMCMC)

Build the IPMCMC sampler. Nodes `1:alg.n_csmc_nodes` each get a CSMC sub-sampler.
Nodes `(alg.n_csmc_nodes + 1):alg.n_nodes` each get an SMC sub-sampler. Each
sub-algorithm is a copy of a shared prototype, parameterised by its node index
(presumably used as its group ID). The sub-samplers are stored under
`info[:samplers]`.
"""
function Sampler(alg::IPMCMC)
    node_samplers = Vector{Sampler}(undef, alg.n_nodes)

    # Prototype algorithms that every node copies.
    csmc_proto = CSMC(alg.n_particles, 1, alg.resampler, alg.space, 0)
    # SMC uses resampler_threshold = 1.0: adaptive resampling is invalid in this setting.
    smc_proto = SMC(alg.n_particles, alg.resampler, 1.0, false, alg.space, 0)

    for node in 1:alg.n_csmc_nodes
        node_samplers[node] = Sampler(CSMC(csmc_proto, node))
    end
    for node in (alg.n_csmc_nodes + 1):alg.n_nodes
        node_samplers[node] = Sampler(SMC(smc_proto, node))
    end

    return Sampler(alg, Dict{Symbol, Any}(:samplers => node_samplers))
end
"""
    step(model::Function, spl::Sampler{IPMCMC}, VarInfos::Array{VarInfo}, is_first::Bool)

Perform one IPMCMC iteration.

Each node's sub-sampler is run once on its `VarInfo`. The node's log marginal
likelihood estimate is the last entry of its `info[:logevidence]`, and is
recorded. Then, for each conditional node `j` in turn, a new conditional index
is drawn from two candidates groups: the currently unconditional nodes, and
node `j`'s own current index. The draw is proportional to the estimated
marginal likelihoods.

Returns `VarInfos` permuted so that the first `n_csmc_nodes` entries belong to
the selected conditional nodes, followed by the unconditional nodes.
`is_first` is not used by this method.
"""
function step(model::Function, spl::Sampler{IPMCMC}, VarInfos::Array{VarInfo}, is_first::Bool)
    n_nodes = spl.alg.n_nodes
    n_csmc = spl.alg.n_csmc_nodes
    node_samplers = spl.info[:samplers]

    # Log marginal likelihood estimate of every node.
    log_zs = zeros(n_nodes)

    # Run SMC & CSMC nodes.
    for j in 1:n_nodes
        VarInfos[j].num_produce = 0
        VarInfos[j] = step(model, node_samplers[j], VarInfos[j])
        log_zs[j] = node_samplers[j].info[:logevidence][end]
    end

    # Resample the indices of the conditional nodes.
    conditional = collect(1:n_csmc)
    unconditional = collect((n_csmc + 1):n_nodes)
    for j in 1:n_csmc
        # Candidates: all currently unconditional nodes, then node j's own index (last).
        log_ksi = vcat(log_zs[unconditional], log_zs[conditional[j]])
        max_log_ksi = maximum(log_ksi)
        # All candidates have zero estimated likelihood: the weights would be NaN
        # (-Inf - -Inf), so keep the current conditional node.
        max_log_ksi == -Inf && continue
        # Shift by the maximum before exponentiating for numerical stability.
        # Broadcasting `.-` is required: Vector-minus-scalar `-` is an error in Julia >= 1.0.
        ksi = exp.(log_ksi .- max_log_ksi)
        c_j = wsample(ksi)  # categorical draw with unnormalised weights
        if c_j < length(log_ksi)
            # Node j picked an unconditional node: that node becomes conditional and
            # node j's previous index rejoins the unconditional pool.
            conditional[j], unconditional[c_j] = unconditional[c_j], conditional[j]
        end
    end

    return VarInfos[vcat(conditional, unconditional)]
end
"""
    sample(model::Function, alg::IPMCMC)

Run `alg.n_iters` IPMCMC iterations on `model` and return a `Chain` of
`alg.n_iters * alg.n_csmc_nodes` equally weighted samples. After every
iteration, the retained path of each conditional (CSMC) node is recorded.
The accumulated time spent in `step` is printed at the end.
"""
function sample(model::Function, alg::IPMCMC)
    spl = Sampler(alg)
    n_csmc = alg.n_csmc_nodes

    # Equally weighted sample slots; their values are filled in per iteration.
    sample_n = alg.n_iters * n_csmc
    weight = 1 / sample_n
    samples = Sample[Sample(weight, Dict{Symbol, Any}()) for _ in 1:sample_n]

    # A fresh VarInfo for every node.
    VarInfos = VarInfo[VarInfo() for _ in 1:spl.alg.n_nodes]

    n = spl.alg.n_iters
    time_total = 0.0
    if PROGRESS
        spl.info[:progress] = ProgressMeter.Progress(n, 1, "[IPMCMC] Sampling...", 0)
    end

    for i in 1:n
        @debug "IPMCMC stepping..."
        dt = @elapsed VarInfos = step(model, spl, VarInfos, i == 1)

        # Store the retained path of each CSMC node.
        offset = (i - 1) * n_csmc
        for j in 1:n_csmc
            samples[offset + j].value = Sample(VarInfos[j], spl).value
        end
        time_total += dt

        if PROGRESS
            haskey(spl.info, :progress) &&
                ProgressMeter.update!(spl.info[:progress], spl.info[:progress].counter + 1)
        end
    end

    println("[IPMCMC] Finished with")
    println(" Running time = $time_total;")
    return Chain(0, samples)  # wrap the result by Chain
end