# ### Setup
# PKG_SETUP
using Dojo
using DojoEnvironments
using Random
using LinearAlgebra
using Statistics
import LinearAlgebra.normalize
# ### Parameters and structs
rng = MersenneTwister(1)
Base.@kwdef struct HyperParameters{T}
    main_loop_size::Int = 100 # number of ARS training iterations
    horizon::Int = 200 # rollout length in simulation steps
    step_size::T = 0.05 # update step size
    n_directions::Int = 16 # number of random search directions per iteration
    b::Int = 16 # number of best directions used for the update
    noise::T = 0.1 # standard deviation of the policy perturbations
end
mutable struct Policy{T}
    hp::HyperParameters{T}
    θ::Matrix{T} # linear policy: action = θ * state
    function Policy(input_size::Int, output_size::Int, hp::HyperParameters{T}; scale=0.2) where T
        new{T}(hp, scale * randn(output_size, input_size))
    end
end
mutable struct Normalizer{T}
    n::Vector{T} # number of observed samples per state dimension
    mean::Vector{T} # running mean
    mean_diff::Vector{T} # running sum of squared deviations
    var::Vector{T} # running variance estimate
    function Normalizer(num_inputs::Int)
        n = zeros(num_inputs)
        mean = zeros(num_inputs)
        mean_diff = zeros(num_inputs)
        var = zeros(num_inputs)
        new{eltype(n)}(n, mean, mean_diff, var)
    end
end
# ### Observation functions
function normalize(normalizer, inputs)
    obs_std = sqrt.(normalizer.var)
    return (inputs .- normalizer.mean) ./ obs_std
end
function observe!(normalizer, inputs)
    normalizer.n .+= 1
    last_mean = deepcopy(normalizer.mean)
    normalizer.mean .+= (inputs .- normalizer.mean) ./ normalizer.n
    normalizer.mean_diff .+= (inputs .- last_mean) .* (inputs .- normalizer.mean)
    normalizer.var .= max.(1e-2, normalizer.mean_diff ./ normalizer.n)
end
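# The normalizer keeps online (Welford-style) running estimates of the state
# mean and variance. A minimal sanity check of these running statistics is
# shown below; the `demo_*` variables are purely illustrative and not used by
# the training code.
demo_normalizer = Normalizer(3)
for demo_input in ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0])
    observe!(demo_normalizer, demo_input)
end
println(demo_normalizer.mean) # [2.0, 4.0, 6.0]
println(normalize(demo_normalizer, [2.0, 4.0, 6.0])) # [0.0, 0.0, 0.0]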
# ### Environment
env = get_environment(:ant_ars;
    horizon=100,
    gravity=-9.81,
    timestep=0.05,
    dampers=50.0,
    springs=25.0,
    friction_coefficient=0.5,
    contact_feet=true,
    contact_body=true);
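# The full state returned by `get_state` contains the 28-dimensional minimal
# state plus contact-related entries, 37 entries in total; the linear policy
# trained below therefore uses `input_size = 37` and produces 8 control inputs.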
# ### Reset and rollout functions
function reset_state!(env)
    initialize!(env, :ant)
    return
end
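# The rollout below evaluates a policy matrix θ on the environment. The
# per-step reward combines forward progress of the main body, a quadratic
# control cost, a contact penalty, and a constant survival bonus; a rollout
# ends early if the body height leaves [0.2, 1] or the state becomes non-finite.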
function rollout_policy(θ::Matrix, env, normalizer::Normalizer, parameters::HyperParameters; record=false)
    reset_state!(env)
    rewards = 0.0
    for k=1:parameters.horizon
        ## get state
        state = DojoEnvironments.get_state(env)
        x = state[1:28] # minimal state without contacts
        observe!(normalizer, state)
        state = normalize(normalizer, state)
        action = θ * state

        ## single step
        step!(env, x, action; record, k)

        ## get reward
        state_after = DojoEnvironments.get_state(env)
        x_pos_before = x[1]
        x_pos_after = state_after[1]
        forward_reward = 100 * (x_pos_after - x_pos_before) / env.mechanism.timestep
        control_cost = (0.05/10 * action' * action)[1]
        contact_cost = 0.0
        for contact in env.mechanism.contacts
            contact_cost += 0.5 * 1.0e-3 * max(-1, min(1, contact.impulses[2][1]))^2.0
        end
        survive_reward = 0.05
        reward = forward_reward - control_cost - contact_cost + survive_reward
        rewards += reward

        ## check for failure
        if !(all(isfinite.(state_after)) && (state_after[3] >= 0.2) && (state_after[3] <= 1))
            println(" failed")
            break
        end
    end

    return rewards
end
# ### Training functions
function sample_policy(policy::Policy{T}) where T
    δ = [randn(size(policy.θ)) for i = 1:policy.hp.n_directions]
    θp = [policy.θ + policy.hp.noise .* δ[i] for i = 1:policy.hp.n_directions]
    θn = [policy.θ - policy.hp.noise .* δ[i] for i = 1:policy.hp.n_directions]
    return [θp..., θn...], δ
end
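# Each direction δᵢ is evaluated with a mirrored pair of perturbed policies
# θ ± noise⋅δᵢ, so every iteration performs 2⋅n_directions rollouts.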
function update!(policy::Policy, rollouts, σ_r)
    stp = zeros(size(policy.θ))
    for (r_pos, r_neg, d) in rollouts
        stp += (r_pos - r_neg) * d
    end
    policy.θ += policy.hp.step_size * stp ./ (σ_r * policy.hp.b)
    return
end
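# This implements the ARS update θ ← θ + step_size/(b⋅σ_r) ⋅ Σᵢ (rᵢ⁺ - rᵢ⁻)⋅δᵢ,
# where σ_r is the standard deviation of the collected rewards and only the b
# best-performing directions contribute to the sum.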
function train(env, policy::Policy{T}, normalizer::Normalizer{T}, hp::HyperParameters{T}) where T
    println("Training linear policy with Augmented Random Search (ARS)\n ")

    ## pre-allocate for rewards
    rewards = zeros(2 * hp.n_directions)

    for episode = 1:hp.main_loop_size
        ## initialize deltas and rewards
        θs, δs = sample_policy(policy)

        ## evaluate policies
        roll_time = @elapsed begin
            for k = 1:(2 * hp.n_directions)
                rewards[k] = rollout_policy(θs[k], env, normalizer, hp)
            end
        end

        ## reward evaluation
        r_max = [max(rewards[k], rewards[hp.n_directions + k]) for k = 1:hp.n_directions]
        σ_r = std(rewards)
        order = sortperm(r_max, rev = true)[1:hp.b]
        rollouts = [(rewards[k], rewards[hp.n_directions + k], δs[k]) for k = order]

        ## policy update
        update!(policy, rollouts, σ_r)

        ## finish, print:
        println("episode $episode reward_evaluation $(mean(rewards)). Took $(roll_time) seconds")
    end

    return nothing
end
# ### Training
train_times = Float64[]
rewards = Float64[]
policies = Matrix{Float64}[]
N = 2
for i = 1:N
    ## Random policy
    hp = HyperParameters(
        main_loop_size=20,
        horizon=100,
        n_directions=6,
        b=6,
        step_size=0.05)
    input_size = 37
    output_size = 8
    normalizer = Normalizer(input_size)
    policy = Policy(input_size, output_size, hp)

    ## Train policy
    train_time = @elapsed train(env, policy, normalizer, hp)

    ## Evaluate policy
    reward = rollout_policy(policy.θ, env, normalizer, hp)

    ## Cache
    push!(train_times, train_time)
    push!(rewards, reward)
    push!(policies, policy.θ)
end
# ### Training results
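# Of the N independently trained policies, the one with the highest evaluation
# reward is selected.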
max_idx = sortperm(rewards, rev=true) # indices sorted by descending evaluation reward
train_time_best = train_times[max_idx[1]]
rewards_best = rewards[max_idx[1]]
policies_best = policies[max_idx[1]];
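# Since the best policy is just a matrix, it could be written to disk for later
# reuse. A minimal sketch with the DelimitedFiles standard library follows (the
# file name is arbitrary and not part of the original example).
using DelimitedFiles
writedlm("ant_ars_policy.csv", policies_best, ',')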
# ### Controller with best policy
θ = policies_best
input_size = 37
normalizer = Normalizer(input_size)

function controller!(environment, k)
    state = get_state(environment)
    observe!(normalizer, state)
    state = normalize(normalizer, state)
    action = θ * state
    set_input!(environment, action)
end
# ### Visualize policy
reset_state!(env)
simulate!(env, controller!; record=true)
vis = visualize(env)
render(vis)