diff --git a/src/ParticleFilters.jl b/src/ParticleFilters.jl index 725578f..ec5f60a 100644 --- a/src/ParticleFilters.jl +++ b/src/ParticleFilters.jl @@ -25,8 +25,10 @@ export resample, n_particles, particles, + weighted_particles, weight_sum, weight, + particle, weights, obs_weight @@ -85,10 +87,17 @@ function n_particles end """ particles(b::AbstractParticleBelief) -Return a vector of the particles. +Return an iterator over the particles. """ function particles end +""" + weighted_particles(b::AbstractParticleBelief) + +Return an iterator over particle-weight pairs. +""" +function weighted_particles end + """ weight_sum(b::AbstractParticleBelief) @@ -103,6 +112,14 @@ Return the weight for particle i. """ function weight end + +""" + particle(b::AbstractParticleBelief, i) + +Return particle i. +""" +function particle end + """ obs_weight(pomdp, sp, o) obs_weight(pomdp, a, sp, o) diff --git a/src/beliefs.jl b/src/beliefs.jl index 006de31..6ddf012 100644 --- a/src/beliefs.jl +++ b/src/beliefs.jl @@ -1,15 +1,19 @@ n_particles(b::ParticleCollection) = length(b.particles) particles(p::ParticleCollection) = p.particles +weighted_particles(p::ParticleCollection) = (p=>1.0/length(p.particles) for p in p.particles) weight_sum(::ParticleCollection) = 1.0 weight(b::ParticleCollection, i::Int) = 1.0/length(b.particles) +particle(b::ParticleCollection, i::Int) = b.particles[i] rand(rng::AbstractRNG, b::ParticleCollection) = b.particles[rand(rng, 1:length(b.particles))] mean(b::ParticleCollection) = sum(b.particles)/length(b.particles) n_particles(b::WeightedParticleBelief) = length(b.particles) particles(p::WeightedParticleBelief) = p.particles +weighted_particles(b::WeightedParticleBelief) = (b.particles[i]=>b.weights[i] for i in 1:length(b.particles)) weight_sum(b::WeightedParticleBelief) = b.weight_sum weight(b::WeightedParticleBelief, i::Int) = b.weights[i] +particle(b::WeightedParticleBelief, i::Int) = b.particles[i] weights(b::WeightedParticleBelief) = b.weights function rand(rng::AbstractRNG, b::WeightedParticleBelief) diff --git a/test/runtests.jl b/test/runtests.jl index 012708b..c2cb042 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,3 +31,4 @@ b = initialize_belief(filter, initial_state_distribution(p)) m = mode(b) m = mean(b) it = iterator(b) +weighted_particles(b)