Skip to content

Commit

Permalink
added srand, rng types
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Sep 29, 2017
1 parent 7133bce commit e182add
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/ParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,17 @@ A particle filter that calculates relative weights for each particle based on ob
The resample field may be a function or an object that controls resampling. If it is a function `f`, `f(b, rng)` will be called. If it is an object, `o`, `resample(o, b, rng)` will be called, where `b` is a `WeightedParticleBelief`.
"""
mutable struct SimpleParticleFilter{S,R} <: Updater
mutable struct SimpleParticleFilter{S,R,RNG<:AbstractRNG} <: Updater
model
resample::R
rng::AbstractRNG
rng::RNG
_particle_memory::Vector{S}
_weight_memory::Vector{Float64}

SimpleParticleFilter{S, R}(model, resample, rng) where {S,R} = new(model, resample, rng, state_type(model)[], Float64[])
SimpleParticleFilter{S, R, RNG}(model, resample, rng) where {S,R,RNG} = new(model, resample, rng, state_type(model)[], Float64[])
end
function SimpleParticleFilter{R}(model, resample::R, rng::AbstractRNG)
SimpleParticleFilter{state_type(model),R}(model, resample, rng)
SimpleParticleFilter{state_type(model),R,typeof(rng)}(model, resample, rng)
end
SimpleParticleFilter(model, resample; rng::AbstractRNG=Base.GLOBAL_RNG) = SimpleParticleFilter(model, resample, rng)

Expand Down Expand Up @@ -161,6 +161,11 @@ function update{S}(up::SimpleParticleFilter{S}, b::ParticleCollection, a, o)
return resample(up.resample, WeightedParticleBelief{S}(pm, wm, sum(wm), nothing), up.rng)
end

function Base.srand(f::SimpleParticleFilter, seed)
srand(f.rng, seed)
return f
end


# default for non-POMDPs
state_type(model) = Any
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ include("example.jl")

p = TigerPOMDP()
filter = SIRParticleFilter(p, 100)
srand(filter, 47)
b = initialize_belief(filter, initial_state_distribution(p))
m = mode(b)
m = mean(b)
Expand Down

0 comments on commit e182add

Please sign in to comment.