Skip to content

Commit

Permalink
simplified obs_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jun 24, 2020
1 parent 02329bb commit 980076d
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 121 deletions.
14 changes: 2 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.2.4"
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -14,18 +15,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[compat]
Distributions = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23"
POMDPModels = ">= 0.4.3"
POMDPs = "0.9"
UnicodePlots = "1"
julia = "1"

[extras]
BeliefUpdaters = "8bb6e9a1-7d73-552c-a44a-e5dc5634aac4"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPPolicies = "182e52fb-cfd0-5e46-8c26-fd0667c990f4"
POMDPSimulators = "e0d0a172-29c6-5d4e-96d0-f262df5d01fd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "POMDPModels", "POMDPSimulators", "POMDPPolicies", "BeliefUpdaters", "Pkg"]
POMDPLinter = "0.1"
4 changes: 3 additions & 1 deletion src/POMDPModelTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ using UnicodePlots
import POMDPs: actions, actionindex
import POMDPs: states, stateindex
import POMDPs: observations, obsindex
import POMDPs: sampletype, generate_sr, initialstate, isterminal, discount
import POMDPs: initialstate, isterminal, discount
import POMDPs: implemented
import Distributions: pdf, mode, mean, support
import Random: rand, rand!
import Statistics: mean
import Base: ==

import POMDPLinter: @POMDP_require

export
render
include("visualization.jl")
Expand Down
23 changes: 0 additions & 23 deletions src/fully_observable_pomdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,12 @@ end

mdptype(::Type{FullyObservablePOMDP{M,S,A}}) where {M,S,A} = M

function POMDPs.DDNStructure(::Type{M}) where M <: FullyObservablePOMDP
MM = mdptype(M)
add_obsnode(DDNStructure(MM))
end

add_obsnode(ddn) = add_node(ddn, :o, FunctionDDNNode((m,sp)->sp), (:sp,)) # for ::DDNStructure, but this is not declared yet POMDPs in v0.7.3

POMDPs.observations(pomdp::FullyObservablePOMDP) = states(pomdp.mdp)
POMDPs.obsindex(pomdp::FullyObservablePOMDP{S, A}, o::S) where {S, A} = stateindex(pomdp.mdp, o)

POMDPs.convert_o(T::Type{V}, o, pomdp::FullyObservablePOMDP) where {V<:AbstractArray} = convert_s(T, s, pomdp.mdp)
POMDPs.convert_o(T::Type{S}, vec::V, pomdp::FullyObservablePOMDP) where {S,V<:AbstractArray} = convert_s(T, vec, pomdp.mdp)

POMDPs.gen(::DDNNode{:o}, m::FullyObservablePOMDP, sp, rng) = sp

function POMDPs.observation(pomdp::FullyObservablePOMDP, a, sp)
return Deterministic(sp)
end
Expand All @@ -52,18 +43,4 @@ POMDPs.convert_s(T::Type{S}, vec::V, pomdp::FullyObservablePOMDP) where {S,V<:Ab
POMDPs.convert_a(T::Type{V}, a, pomdp::FullyObservablePOMDP) where V<:AbstractArray = convert_a(T, a, pomdp.mdp)
POMDPs.convert_a(T::Type{A}, vec::V, pomdp::FullyObservablePOMDP) where {A,V<:AbstractArray} = convert_a(T, vec, pomdp.mdp)

POMDPs.gen(d::DDNNode, m::FullyObservablePOMDP, args...) = gen(d, m.mdp, args...)
POMDPs.gen(m::FullyObservablePOMDP, s, a, rng) = gen(m.mdp, s, a, rng)
POMDPs.reward(pomdp::FullyObservablePOMDP, s, a) = reward(pomdp.mdp, s, a)

# deprecated in POMDPs v0.8
add_obsnode(ddn::POMDPs.DDNStructureV7{(:s,:a,:sp,:r)}) = POMDPs.DDNStructureV7{(:s,:a,:sp,:o,:r)}()
add_obsnode(ddn::POMDPs.DDNStructureV7) = error("FullyObservablePOMDP only supports MDPs with the standard DDN Structure (DDNStructureV7{(:s,:a,:sp,:r)}) with POMDPs v0.7.")

POMDPs.generate_s(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_s(pomdp.mdp, s, a, rng)
POMDPs.generate_sr(pomdp::FullyObservablePOMDP, s, a, rng::AbstractRNG) = generate_sr(pomdp.mdp, s, a, rng)
POMDPs.n_actions(pomdp::FullyObservablePOMDP) = n_actions(pomdp.mdp)
POMDPs.n_states(pomdp::FullyObservablePOMDP) = n_states(pomdp.mdp)
function POMDPs.generate_o(pomdp::FullyObservablePOMDP, s, rng::AbstractRNG)
return s
end
78 changes: 1 addition & 77 deletions src/obs_weight.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,86 +2,10 @@
# maintained by @zsunberg

"""
obs_weight(pomdp, sp, o)
obs_weight(pomdp, a, sp, o)
obs_weight(pomdp, s, a, sp, o)
Return a weight proportional to the likelihood of receiving observation o from state sp (and a and s if they are present).
This is a useful shortcut for particle filtering so that the observation distribution does not have to be represented.
"""
function obs_weight end

@generated function obs_weight(p, s, a, sp, o)
ow_impl = :(obs_weight(p, a, sp, o))
o_impl = :(pdf(observation(p, s, a, sp), o))
if implemented(obs_weight, Tuple{p, a, sp, o})
return ow_impl
elseif implemented(observation, Tuple{p, s, a, sp})
return o_impl
else
return quote
try # trick to get the compiler to put the right backedges in
$ow_impl
$o_impl
catch
throw(MethodError(obs_weight, (p,s,a,sp,o)))
end
end
end
end

@generated function obs_weight(p, a, sp, o)
ow_impl = :(obs_weight(p, sp, o))
o_impl = :(pdf(observation(p, a, sp), o))
if implemented(obs_weight, Tuple{p, sp, o})
return ow_impl
elseif implemented(observation, Tuple{p, a, sp})
return o_impl
else
return quote
try # trick to get the compiler to put the right backedges in
$ow_impl
$o_impl
catch
throw(MethodError(obs_weight, (p, a, sp, o)))
end
end
end
end

@generated function obs_weight(p, sp, o)
impl = :(pdf(observation(p, sp), o))
if implemented(observation, Tuple{p, sp})
return impl
else
return quote
try # trick to get the compiler to put the right backedges in
$impl
catch
return :(throw(MethodError(obs_weight, (p, sp, o))))
end
end
end
end

function implemented(f::typeof(obs_weight), TT::Type)
m = which(f, TT)
if length(TT.parameters) == 5
P, S, A, _, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,S,A,S}) || implemented(obs_weight, Tuple{P,A,S,O})
elseif length(TT.parameters) == 4
P, A, S, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,A,S}) || implemented(obs_weight, Tuple{P,S,O})
elseif length(TT.parameters) == 3
P, S, O = TT.parameters
reqs_met = implemented(observation, Tuple{P,S})
else
return method_exists(f, TT)
end
if m.module == POMDPModelTools && !reqs_met
return false
else
true
end
end
obs_weight(p, s, a, sp, o) = pdf(observation(p, s, a, sp), o)
8 changes: 0 additions & 8 deletions src/underlying_mdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,3 @@ POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Int}, a::Int) where {P,S} = actionind
POMDPs.actionindex(mdp::UnderlyingMDP{P,S, Bool}, a::Bool) where {P,S} = actionindex(mdp.pomdp, a)

POMDPs.gen(mdp::UnderlyingMDP, s, a, rng) = gen(d, mdp.pomdp, s, a, rng)
POMDPs.gen(mdp::UnderlyingMDP, s, a, rng) = gen(d, mdp.pomdp, s, a, rng)
POMDPs.gen(mdp::UnderlyingMDP, s, a, rng) = gen(m.pomdp, s, a, rng)

# deprecated in POMDPs v0.8
POMDPs.n_actions(mdp::UnderlyingMDP) = n_actions(mdp.pomdp)
POMDPs.n_states(mdp::UnderlyingMDP) = n_states(mdp.pomdp)
POMDPs.generate_s(mdp::UnderlyingMDP, s, a, rng::AbstractRNG) = generate_s(mdp.pomdp, s, a, rng)
POMDPs.generate_sr(mdp::UnderlyingMDP, s, a, rng::AbstractRNG) = generate_sr(mdp.pomdp, s, a, rng)

0 comments on commit 980076d

Please sign in to comment.