# readme_example.jl — README example: train a BetaZero policy on the LightDark POMDP.
using BetaZero
using LightDark
# Problem definition: the LightDark POMDP, built with its default constructor arguments.
pomdp = LightDarkPOMDP()
# Belief updater that is later handed to the solver (`updater=up`). The second argument
# (500) is presumably the number of particles — confirm against the ParticleFilters.jl
# `BootstrapFilter` documentation.
up = BootstrapFilter(pomdp, 500)
# Belief representation used as input to the neural network: the particle belief `b` is
# summarized by the mean and standard deviation of the `y` field of its particles and
# returned as a 2-element `Float32` vector `[mean, std]`.
function BetaZero.input_representation(b::ParticleCollection{LightDarkState})
    # Lazily walk the particles; no intermediate array is materialized.
    ys = (particle.y for particle in particles(b))
    μ_y, σ_y = mean_and_std(ys)
    return Float32[μ_y, σ_y]
end
# Accuracy of the agent's final decision: `true` exactly when the last entry of `returns`
# equals the POMDP's `correct_r` field. The remaining arguments (`b0`, `s0`, `states`,
# `actions`) are part of the signature but are not used by this method.
function BetaZero.accuracy(pomdp::LightDarkPOMDP, b0, s0, states, actions, returns)
    final_return = returns[end]
    is_correct = final_return == pomdp.correct_r
    return is_correct
end
# Assemble the BetaZero solver. The two parameter groups are built first as locals of a
# `let` block (so no additional globals are introduced) and then passed to the constructor.
solver = let
    # Top-level BetaZero parameters.
    algorithm_params = BetaZeroParameters(n_iterations=50, n_data_gen=50)

    # Neural network parameters: the POMDP and the belief updater are passed positionally,
    # followed by the keyword settings.
    network_params = BetaZeroNetworkParameters(
        pomdp, up;
        training_epochs=50,
        n_samples=100_000,
        batchsize=1024,
        learning_rate=1e-4,
        λ_regularization=1e-5,
        use_dropout=true,
        p_dropout=0.2,
    )

    BetaZeroSolver(
        pomdp=pomdp,
        updater=up,
        params=algorithm_params,
        nn_params=network_params,
        verbose=true,
        collect_metrics=true,
        plot_incremental_data_gen=true,
    )
end
# Run the solver on the POMDP; the result is bound to `policy`.
policy = solve(solver, pomdp)
# Persist the policy and the solver to "policy.bson" and "solver.bson" (relative paths;
# presumably BSON-serialized, judging by the file extension — confirm in BetaZero's docs).
save_policy(policy, "policy.bson")
save_solver(solver, "solver.bson")