Skip to content

Commit

Permalink
added weighted_particles, clarified that particles returns an iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Sep 29, 2017
1 parent e182add commit 19afcb9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/ParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ export
resample,
n_particles,
particles,
weighted_particles,
weight_sum,
weight,
particle,
weights,
obs_weight

Expand Down Expand Up @@ -85,10 +87,17 @@ function n_particles end
"""
particles(b::AbstractParticleBelief)
Return a vector of the particles.
Return an iterator over the particles.
"""
function particles end

"""
weighted_particles(b::AbstractParticleBelief)
Return an iterator over particle-weight pairs.
"""
function weighted_particles end

"""
weight_sum(b::AbstractParticleBelief)
Expand All @@ -103,6 +112,14 @@ Return the weight for particle i.
"""
function weight end


"""
particle(b::AbstractParticleBelief, i)
Return particle i.
"""
function particle end

"""
obs_weight(pomdp, sp, o)
obs_weight(pomdp, a, sp, o)
Expand Down
4 changes: 4 additions & 0 deletions src/beliefs.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@

n_particles(b::ParticleCollection) = length(b.particles)
particles(p::ParticleCollection) = p.particles
weighted_particles(p::ParticleCollection) = (p=>1.0/length(p.particles) for p in p.particles)
weight_sum(::ParticleCollection) = 1.0
weight(b::ParticleCollection, i::Int) = 1.0/length(b.particles)
particle(b::ParticleCollection, i::Int) = b.particles[i]
rand(rng::AbstractRNG, b::ParticleCollection) = b.particles[rand(rng, 1:length(b.particles))]
mean(b::ParticleCollection) = sum(b.particles)/length(b.particles)

n_particles(b::WeightedParticleBelief) = length(b.particles)
particles(p::WeightedParticleBelief) = p.particles
weighted_particles(b::WeightedParticleBelief) = (b.particles[i]=>b.weights[i] for i in 1:length(b.particles))
weight_sum(b::WeightedParticleBelief) = b.weight_sum
weight(b::WeightedParticleBelief, i::Int) = b.weights[i]
particle(b::WeightedParticleBelief, i::Int) = b.particles[i]
weights(b::WeightedParticleBelief) = b.weights

function rand(rng::AbstractRNG, b::WeightedParticleBelief)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ b = initialize_belief(filter, initial_state_distribution(p))
m = mode(b)
m = mean(b)
it = iterator(b)
weighted_particles(b)

0 comments on commit 19afcb9

Please sign in to comment.