From 435d745ec073b9fdd089282b6efd8588a97bd7af Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Tue, 13 Nov 2018 11:12:05 -0800 Subject: [PATCH] fixed error in BasicPOMCP minimal example --- src/pomdps.jl | 41 +++++++++++++++++++++++++++++------------ test/runtests.jl | 13 +++++++++---- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/pomdps.jl b/src/pomdps.jl index ada8e79..d6712ed 100644 --- a/src/pomdps.jl +++ b/src/pomdps.jl @@ -52,18 +52,35 @@ end function initialize_belief(up::BasicParticleFilter, d::D) where D # using weighted iterator here is more likely to be order n than just calling rand() repeatedly - # but, this implementation may change in the future - if @implemented(support(::D)) && @implemented(pdf(::D, ::typeof(first(support(d))))) - # if @implemented(weighted_iterator(::D)) - S = typeof(first(support(d))) - particles = S[] - weights = Float64[] - for (s, w) in weighted_iterator(d) - push!(particles, s) - push!(weights, w) + # but, this implementation is problematic and may change in the future + try + if @implemented(support(::D)) && + @implemented(iterate(::typeof(support(d)))) && + @implemented(pdf(::D, ::typeof(first(support(d))))) + S = typeof(first(support(d))) + particles = S[] + weights = Float64[] + for (s, w) in weighted_iterator(d) + push!(particles, s) + push!(weights, w) + end + return resample(ImportanceResampler(up.n_init), WeightedParticleBelief(particles, weights), up.rng) + end + catch ex + if ex isa MethodError + @warn(""" + Suppressing MethodError in initialize_belief in ParticleFilters.jl. Please file an issue here: + + https://github.com/JuliaPOMDP/ParticleFilters.jl/issues/new + + The error was + + $(sprint(showerror, ex)) + """, maxlog=1) + else + rethrow(ex) end - return resample(ImportanceResampler(up.n_init), WeightedParticleBelief(particles, weights), up.rng) - else - return ParticleCollection(collect(rand(up.rng, d) for i in 1:up.n_init)) end + + return ParticleCollection(collect(rand(up.rng, d) for i in 1:up.n_init)) end diff --git a/test/runtests.jl b/test/runtests.jl index fd8609e..d4c0a4c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,8 +5,7 @@ using Test using POMDPPolicies using POMDPSimulators using Random -import ParticleFilters: obs_weight -import POMDPs: observation +using Distributions using NBInclude struct P <: POMDP{Nothing, Nothing, Nothing} end @@ -15,7 +14,7 @@ struct P <: POMDP{Nothing, Nothing, Nothing} end @test !@implemented obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) @test !@implemented obs_weight(::P, ::Nothing, ::Nothing) end -obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) = 1.0 +ParticleFilters.obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) = 1.0 @testset "implemented" begin @test @implemented obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) @@ -24,12 +23,13 @@ obs_weight(::P, ::Nothing, ::Nothing, ::Nothing) = 1.0 @test obs_weight(P(), nothing, nothing, nothing, nothing) == 1.0 end -observation(::P, ::Nothing) = nothing +POMDPs.observation(::P, ::Nothing) = nothing @test @implemented obs_weight(::P, ::Nothing, ::Nothing) include("example.jl") include("domain_specific_resampler.jl") +struct ContinuousPOMDP <: POMDP{Float64, Float64, Float64} end @testset "infer" begin p = TigerPOMDP() filter = SIRParticleFilter(p, 10000) @@ -70,6 +70,11 @@ include("domain_specific_resampler.jl") wp2 = @inferred collect(weighted_particles(WeightedParticleBelief([1,2], [0.5, 0.5]))) @test wp1 == wp2 end + + @testset "normal" begin + pf = SIRParticleFilter(ContinuousPOMDP(), 100) + ps = @inferred initialize_belief(pf, Normal()) + end end @testset "alpha" begin