Skip to content

Commit

Permalink
Merge pull request #7 from JuliaPOMDP/unweighted
Browse files Browse the repository at this point in the history
added unweighted
  • Loading branch information
zsunberg authored Oct 8, 2017
2 parents 19afcb9 + 398188f commit 21385e9
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/ParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ export
SimpleParticleFilter,
ImportanceResampler,
LowVarianceResampler,
SIRParticleFilter
SIRParticleFilter,
UnweightedParticleFilter

export
resample,
Expand Down Expand Up @@ -49,7 +50,7 @@ abstract type AbstractParticleBelief{T} end
# DEPRECATED: remove in future release
Base.eltype{T}(::Type{AbstractParticleBelief{T}}) = T

sampletype(::Type{AbstractParticleBelief{T}}) where T = T
sampletype(::Type{B}) where B<:AbstractParticleBelief{T} where T = T

### Belief types ###

Expand Down Expand Up @@ -215,6 +216,7 @@ function SIRParticleFilter(model, n::Int; rng::AbstractRNG=Base.GLOBAL_RNG)
return SimpleParticleFilter(model, LowVarianceResampler(n), rng)
end

include("unweighted.jl")
include("beliefs.jl")
include("updater.jl")
include("resamplers.jl")
Expand Down
1 change: 0 additions & 1 deletion src/beliefs.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

n_particles(b::ParticleCollection) = length(b.particles)
particles(p::ParticleCollection) = p.particles
weighted_particles(p::ParticleCollection) = (p=>1.0/length(p.particles) for p in p.particles)
Expand Down
44 changes: 44 additions & 0 deletions src/unweighted.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
UnweightedParticleFilter
A particle filter that does not use any reweighting, but only keeps particles if the observation matches the true observation exactly. This does not require obs_weight, but it will not work well in real-world situations.
"""
struct UnweightedParticleFilter{M, RNG<:AbstractRNG} <: Updater
model::M
n::Int
rng::RNG
end

function UnweightedParticleFilter(model, n::Integer; rng=Base.GLOBAL_RNG)
return UnweightedParticleFilter(model, n, rng)
end

function update(up::UnweightedParticleFilter, b::ParticleCollection, a, o)
new = sampletype(b)[]
i = 1
while i <= up.n
s = particle(b, mod1(i, n_particles(b)))
sp, o_gen = generate_so(up.model, s, a, up.rng)
if o_gen == o
push!(new, sp)
end
i += 1
end
if isempty(new)
warn("""
Particle Depletion!
The UnweightedParticleFilter generated no particles consistent with observation $o. Consider upgrading to a SIRParticleFilter or a SimpleParticleFilter or creating your own domain-specific updater.
"""
)
end
return ParticleCollection(new)
end

function update(up::UnweightedParticleFilter, b, a, o)
return update(up, initialize_belief(up, b), a, o)
end

function initialize_belief(up::UnweightedParticleFilter, b)
return ParticleCollection(collect(rand(up.rng, b) for i in 1:up.n))
end
7 changes: 7 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,10 @@ m = mode(b)
m = mean(b)
it = iterator(b)
weighted_particles(b)

rng = MersenneTwister(47)
uf = UnweightedParticleFilter(p, 1000, rng)
ps = initialize_belief(uf, initial_state_distribution(p))
a = rand(rng, actions(p))
sp, o = generate_so(p, rand(rng, ps), a, rng)
bp = update(uf, ps, a, o)

0 comments on commit 21385e9

Please sign in to comment.