Skip to content
This repository has been archived by the owner on Apr 26, 2023. It is now read-only.

Commit

Permalink
added obs_weight (JuliaPOMDP/POMDPs.jl#172), fixed misspelling, remov…
Browse files Browse the repository at this point in the history
…ed CircularBuffer (#71)
  • Loading branch information
zsunberg authored May 23, 2018
1 parent 453c5bc commit 4dac41f
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/POMDPToolbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import POMDPs: Simulator, simulate
import POMDPs: action, value, solve
import POMDPs: actions, action_index, state_index, obs_index, iterator, sampletype, states, n_actions, n_states, observations, n_observations, discount, isterminal
import POMDPs: generate_sr, initial_state
import POMDPs: implemented
import Base: rand, rand!, mean, ==
import DataStructures: CircularBuffer, isfull, capacity, push!, append!

using ProgressMeter
using StatsBase
Expand Down Expand Up @@ -178,6 +178,9 @@ include("model/underlying_mdp.jl")
# tools for distributions
include("distributions/distributions_jl.jl")

export obs_weight
include("model/obs_weight.jl")

export
weighted_iterator
include("distributions/weighted_iteration.jl")
Expand Down
87 changes: 87 additions & 0 deletions src/model/obs_weight.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# obs_weight is a shortcut function for getting the relative likelihood of an observation without having to construct the observation distribution. Useful for particle filtering
# maintained by @zsunberg

"""
obs_weight(pomdp, sp, o)
obs_weight(pomdp, a, sp, o)
obs_weight(pomdp, s, a, sp, o)
Return a weight proportional to the likelihood of receiving observation o from state sp (and a and s if they are present).
This is a useful shortcut for particle filtering so that the observation distribution does not have to be represented.
"""
function obs_weight end

@generated function obs_weight(p, s, a, sp, o)
ow_impl = :(obs_weight(p, a, sp, o))
o_impl = :(pdf(observation(p, s, a, sp), o))
if implemented(obs_weight, Tuple{p, a, sp, o})
return ow_impl
elseif implemented(observation, Tuple{p, s, a, sp})
return o_impl
else
return quote
try # trick to get the compiler to put the right backedges in
$ow_impl
$o_impl
catch
throw(MethodError(obs_weight, (p,s,a,sp,o)))
end
end
end
end

@generated function obs_weight(p, a, sp, o)
ow_impl = :(obs_weight(p, sp, o))
o_impl = :(pdf(observation(p, a, sp), o))
if implemented(obs_weight, Tuple{p, sp, o})
return ow_impl
elseif implemented(observation, Tuple{p, a, sp})
return o_impl
else
return quote
try # trick to get the compiler to put the right backedges in
$ow_impl
$o_impl
catch
throw(MethodError(obs_weight, (p, a, sp, o)))
end
end
end
end

@generated function obs_weight(p, sp, o)
impl = :(pdf(observation(p, sp), o))
if implemented(observation, Tuple{p, sp})
return impl
else
return quote
try # trick to get the compiler to put the right backedges in
$impl
catch
return :(throw(MethodError(obs_weight, (p, sp, o))))
end
end
end
end

function implemented(f::typeof(obs_weight), TT::Type)
m = which(f, TT)
if length(TT.parameters) == 5
P, S, A, _, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,S,A,S}) || implemented(obs_weight, Tuple{P,A,S,O})
elseif length(TT.parameters) == 4
P, A, S, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,A,S}) || implemented(obs_weight, Tuple{P,S,O})
elseif length(TT.parameters) == 3
P, S, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,S})
else
return method_exists(f, TT)
end
if m.module == POMDPToolbox && !reqs_met
return false
else
true
end
end
2 changes: 1 addition & 1 deletion src/simulators/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function run_parallel(process::Function, queue::AbstractVector;
warn("""
run_parallel(...) was started with only 1 process, so simulations will be run in serial.
To supress this warning, use run_parallel(..., proc_warn=false).
To suppress this warning, use run_parallel(..., proc_warn=false).
To use multiple processes, use addprocs() or the -p option (e.g. julia -p 4).
""")
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ include("test_info.jl")
include("test_k_previous_observations_belief.jl")
include("test_fully_observable_pomdp.jl")
include("test_underlying_mdp.jl")
include("test_obs_weight.jl")
18 changes: 18 additions & 0 deletions test/test_obs_weight.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import POMDPToolbox: obs_weight
import POMDPs: observation

struct P <: POMDP{Void, Void, Void} end

@test !@implemented obs_weight(::P, ::Void, ::Void, ::Void, ::Void)
@test !@implemented obs_weight(::P, ::Void, ::Void, ::Void)
@test !@implemented obs_weight(::P, ::Void, ::Void)

obs_weight(::P, ::Void, ::Void, ::Void) = 1.0
@test @implemented obs_weight(::P, ::Void, ::Void, ::Void)
@test @implemented obs_weight(::P, ::Void, ::Void, ::Void, ::Void)
@test !@implemented obs_weight(::P, ::Void, ::Void)

@test obs_weight(P(), nothing, nothing, nothing, nothing) == 1.0

observation(::P, ::Void) = nothing
@test @implemented obs_weight(::P, ::Void, ::Void)

0 comments on commit 4dac41f

Please sign in to comment.