-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrepresentation_lightdark.jl
65 lines (56 loc) · 2.17 KB
/
representation_lightdark.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
using BetaZero
using LightDark
using ParticleBeliefs
using ParticleFilters
using Statistics
using StatsBase
using POMDPs
USE_LIGHTDARK_5 = false
if USE_LIGHTDARK_5
# Old light dark with light region around 5: LightDark(5)
pomdp = LightDarkPOMDP(; light_loc=5, sigma = y->abs(y - 5)/sqrt(2) + 1e-2, correct_r=10, incorrect_r=-10)
@warn "Using old light dark!"
else
# LightDark(10)
pomdp = LightDarkPOMDP()
end
up = ParticleHistoryBeliefUpdater(BootstrapFilter(pomdp, 500))
function BetaZero.input_representation(b::ParticleHistoryBelief{LightDarkState};
include_std::Bool=true, # Important to capture uncertainty in belief.
use_higher_orders::Bool=false,
include_action::Bool=false,
include_obs::Bool=false)
Y = [s.y for s in ParticleFilters.particles(b)]
μ = mean(Y)
σ = std(Y)
local b̃
if use_higher_orders
zeroifnan(x) = isnan(x) ? 0 : x
s = zeroifnan(skewness(Y))
k = zeroifnan(kurtosis(Y))
b̃ = Float32[μ, σ, s, k]
else
if include_std
b̃ = Float32[μ, σ]
else
b̃ = Float32[μ]
end
end
if include_obs
o = isempty(b.observations) ? 0.f0 : b.observations[end]
b̃ = [b̃..., o]
end
if include_action
a = isempty(b.actions) ? -999 : b.actions[end]
b̃ = [b̃..., a]
end
return b̃
end
# BetaZero.optimal_return(pomdp::LightDarkPOMDP, s) = pomdp.correct_r
BetaZero.accuracy(pomdp::LightDarkPOMDP, b0, s0, states, actions, returns) = returns[end] == pomdp.correct_r
lightdark_belief_reward(pomdp, b, a, bp) = mean(reward(pomdp, s, a) for s in ParticleFilters.particles(b))
POMDPs.convert_s(::Type{A}, b::ParticleHistoryBelief{LightDarkState}, m::BeliefMDP) where A<:AbstractArray = eltype(A)[BetaZero.input_representation(b)...]
POMDPs.convert_s(::Type{ParticleHistoryBelief{LightDarkState}}, b::A, m::BeliefMDP) where A<:AbstractArray = ParticleHistoryBelief(particles=ParticleCollection(rand(LDNormalStateDist(b[1], b[2]), up.pf.n_init)))
function Statistics.mean(b::ParticleHistoryBelief{LightDarkState})
return LightDarkState(mean(s->s.status, particles(b)), mean(s->s.y, particles(b)))
end