Skip to content

Commit

Permalink
Rename SIRParticleFilter to BootstrapFilter (#44)
Browse files Browse the repository at this point in the history
* did renaming

* fixed notebooks
  • Loading branch information
zsunberg authored Oct 28, 2020
1 parent f6a3156 commit b4f3ff5
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 209 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParticleFilters"
uuid = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
repo = "https://github.com/JuliaPOMDP/ParticleFilters.jl"
version = "0.5.1"
version = "0.5.2"

[deps]
POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dynamics(x, u, rng) = x + u + randn(rng)
y_likelihood(x_previous, u, x, y) = pdf(Normal(), y - x)

model = ParticleFilterModel{Float64}(dynamics, y_likelihood)
pf = SIRParticleFilter(model, 10)
pf = BootstrapFilter(model, 10)
```
Then the `update` function can be used to perform a particle filter update.
```julia
Expand Down
4 changes: 2 additions & 2 deletions docs/src/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The basic particle filtering step in ParticleFilters.jl is implemented in the [`
2. Reweighting - an explicit measurement (observation) model is used to calculate a new weight
3. Resampling - a new collection of state particles is generated with particle frequencies proportional to the new weights

This is an example of [sequential importance resampling](https://en.wikipedia.org/wiki/Particle_filter#Sequential_Importance_Resampling_(SIR)), and the [`SIRParticleFilter`](@ref) constructor can be used to construct such a filter with a `model` that controls the prediction and reweighting steps, and a number of particles to create in the resampling phase.
This is an example of [sequential importance resampling](https://en.wikipedia.org/wiki/Particle_filter#Sequential_Importance_Resampling_(SIR)) using the state transition distribution as the proposal distribution, and the [`BootstrapFilter`](@ref) constructor can be used to construct such a filter with a `model` that controls the prediction and reweighting steps, and a number of particles to create in the resampling phase.

A more flexible structure for building a particle filter is the [`BasicParticleFilter`](@ref). It contains three models, one for each step:

Expand All @@ -23,7 +23,7 @@ To carry out the steps individually without the need for pre-allocating memory o
## Docstrings

```@docs
SIRParticleFilter
BootstrapFilter
BasicParticleFilter
update
predict!
Expand Down
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Additionally, an important requirement for a particle filter is efficient resamp

Dynamics and measurement models for the filters can be specified as a [`ParticleFilterModel`](@ref) or a [`POMDP`](https://github.com/JuliaPOMDP/POMDPs.jl) or a custom user-defined type.

The simplest sequential-importance-resampling Particle filter can be constructed with [`SIRParticleFilter`](@ref). [`BasicParticleFilter`](@ref) provides a more flexible structure.
The simplest Bootstrap Particle filter can be constructed with [`BootstrapFilter`](@ref). [`BasicParticleFilter`](@ref) provides a more flexible structure.

Basic setup of a model is as follows:
```julia
Expand All @@ -17,7 +17,7 @@ using ParticleFilters, Distributions
dynamics(x, u, rng) = x + u + randn(rng)
y_likelihood(x_previous, u, x, y) = pdf(Normal(), y - x)
model = ParticleFilterModel{Float64}(dynamics, y_likelihood)
pf = SIRParticleFilter(model, 10)
pf = BootstrapFilter(model, 10)
```
Then the [`update`](@ref) function can be used to perform a particle filter update.
```julia
Expand Down
263 changes: 113 additions & 150 deletions notebooks/Filtering-a-Trajectory-or-Data-Series.ipynb

Large diffs are not rendered by default.

39 changes: 14 additions & 25 deletions notebooks/Using-a-Particle-Filter-for-Feedback-Control.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,12 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"N = 1000\n",
"filter = SIRParticleFilter(model, N);"
"filter = BootstrapFilter(model, N);"
]
},
{
Expand All @@ -183,7 +183,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -193,28 +193,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Recompiling stale cache file /home/zach/.julia/compiled/v1.0/Plots/ld3vC.ji for Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]\n",
"└ @ Base loading.jl:1184\n",
"WARNING: StaticArrays.FixedSizeArrays is deprecated. Use StaticArrays directly.\n",
" likely near /home/zach/.julia/packages/Plots/Ufx0i/src/Plots.jl:8\n"
]
}
],
"outputs": [],
"source": [
"using Plots\n",
"rng = Random.GLOBAL_RNG;"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -253,7 +242,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -262,7 +251,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 17,
"metadata": {},
"outputs": [
{
Expand All @@ -278,7 +267,7 @@
"\"output.gif\""
]
},
"execution_count": 10,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -294,7 +283,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand All @@ -306,7 +295,7 @@
"HTML{String}(\"<img src=\\\"output.gif\\\"/>\")"
]
},
"execution_count": 11,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -325,15 +314,15 @@
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.0.0",
"display_name": "Julia 1.5.1",
"language": "julia",
"name": "julia-1.0"
"name": "julia-1.5"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.0.0"
"version": "1.5.1"
}
},
"nbformat": 4,
Expand Down
20 changes: 10 additions & 10 deletions notebooks/Using-a-Particle-Filter-with-POMDPs-jl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -35,14 +35,14 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"pomdp = LightDark1D()\n",
"N=1000\n",
"\n",
"up = SIRParticleFilter(pomdp, N)\n",
"up = BootstrapFilter(pomdp, N)\n",
"\n",
"policy = FunctionPolicy(b->1)\n",
"b0 = POMDPModels.LDNormalStateDist(-15.0, 5.0)\n",
Expand All @@ -63,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -73,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All @@ -82,7 +82,7 @@
"\"hist.gif\""
]
},
"execution_count": 65,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -106,7 +106,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -118,7 +118,7 @@
"HTML{String}(\"<img src=\\\"hist.gif\\\"/>\")"
]
},
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -137,15 +137,15 @@
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.5.0",
"display_name": "Julia 1.5.1",
"language": "julia",
"name": "julia-1.5"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.5.0"
"version": "1.5.1"
}
},
"nbformat": 4,
Expand Down
Binary file modified notebooks/hist.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/output.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions src/ParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ export
BasicParticleFilter,
ImportanceResampler,
LowVarianceResampler,
SIRParticleFilter,
UnweightedParticleFilter,
ParticleFilterModel,
PredictModel,
BootstrapFilter,
ReweightModel

export
Expand Down Expand Up @@ -57,16 +57,16 @@ export
support,
initialize_belief

# deprecated
export
SimpleParticleFilter

SIRParticleFilter

include("beliefs.jl")
include("basic.jl")
include("resamplers.jl")
include("sir.jl")
include("unweighted.jl")
include("models.jl")
include("bootstrap.jl")
include("pomdps.jl")
include("policies.jl")
include("runfilter.jl")
Expand Down
12 changes: 5 additions & 7 deletions src/sir.jl → src/bootstrap.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
SIRParticleFilter(model, n, [rng])
BootstrapFilter(model, n, [rng])
Construct a sequential importance resampling particle filter.
Construct a standard bootstrap particle filter.
The Bootstrap filter was first described in Gordon, N. J., Salmond, D. J., & Smith, A. F. M. "Novel approach to nonlinear / non-Gaussian Bayesian state estimation", with the added robustness of the LowVarianceResampler.
# Arguments
- `model`: a model for the prediction dynamics and likelihood reweighing, for example a `POMDP` or `ParticleFilterModel`
Expand All @@ -10,10 +12,6 @@ Construct a sequential importance resampling particle filter.
For a more flexible particle filter structure see [`BasicParticleFilter`](@ref).
"""
function SIRParticleFilter(model, n::Int, rng::AbstractRNG)
return BasicParticleFilter(model, LowVarianceResampler(n), n, rng)
end

function SIRParticleFilter(model, n::Int; rng::AbstractRNG=Random.GLOBAL_RNG)
function BootstrapFilter(model, n::Int, rng::AbstractRNG=Random.GLOBAL_RNG)
return BasicParticleFilter(model, LowVarianceResampler(n), n, rng)
end
3 changes: 2 additions & 1 deletion src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
@deprecate SimpleParticleFilter BasicParticleFilter
@deprecate SIRParticleFilter(model, n::Int, rng::AbstractRNG) BootstrapFilter(model, n, rng)
@deprecate SIRParticleFilter(model, n::Int; rng::AbstractRNG=Random.GLOBAL_RNG) BootstrapFilter(model, n, rng)
2 changes: 1 addition & 1 deletion src/unweighted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function update(up::UnweightedParticleFilter, b::ParticleCollection, a, o)
warn("""
Particle Depletion!
The UnweightedParticleFilter generated no particles consistent with observation $o. Consider upgrading to a SIRParticleFilter or a BasicParticleFilter or creating your own domain-specific updater.
The UnweightedParticleFilter generated no particles consistent with observation $o. Consider upgrading to a BootstrapFilter or a BasicParticleFilter or creating your own domain-specific updater.
"""
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ h(x, rng) = rand(rng, MvNormal(x[1:2], V))

@testset "example" begin
N = 1000
filter = SIRParticleFilter(model, N)
filter = BootstrapFilter(model, N)
Random.seed!(1)
rng = Random.GLOBAL_RNG
b = ParticleCollection([4.0*rand(4).-2.0 for i in 1:N])
Expand Down
2 changes: 1 addition & 1 deletion test/lightdark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pomdp = LightDark2D(init_dist=SymmetricNormal2([2.0, 2.0], 5.0))

policy = FunctionPolicy(s->-0.1*mean(s))

fnew = SIRParticleFilter(pomdp, 100, rng=MersenneTwister(42))
fnew = BootstrapFilter(pomdp, 100, rng=MersenneTwister(42))
ro = RolloutSimulator(rng=MersenneTwister(1), max_steps=10)
simulate(ro, pomdp, policy, fnew)

Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ include("domain_specific_resampler.jl")
struct ContinuousPOMDP <: POMDP{Float64, Float64, Float64} end
@testset "infer" begin
p = TigerPOMDP()
filter = SIRParticleFilter(p, 10000)
filter = BootstrapFilter(p, 10000)
Random.seed!(filter, 47)
b = @inferred initialize_belief(filter, initialstate(p))
@testset "sir" begin
Expand Down Expand Up @@ -64,7 +64,7 @@ struct ContinuousPOMDP <: POMDP{Float64, Float64, Float64} end
end

@testset "normal" begin
pf = SIRParticleFilter(ContinuousPOMDP(), 100)
pf = BootstrapFilter(ContinuousPOMDP(), 100)
ps = @inferred initialize_belief(pf, Normal())
end
end
Expand All @@ -75,7 +75,7 @@ POMDPs.observation(::TerminalPOMDP, a, sp) = Normal(sp)
POMDPs.transition(::TerminalPOMDP, s, a) = Deterministic(s+a)
@testset "pomdp terminal" begin
pomdp = TerminalPOMDP()
pf = SIRParticleFilter(pomdp, 100)
pf = BootstrapFilter(pomdp, 100)
bp = update(pf, initialize_belief(pf, Categorical([0.5, 0.5])), -1, 1.0)
@test all(particles(bp) .== 1)
end
Expand Down

0 comments on commit b4f3ff5

Please sign in to comment.