Commit

Merge branch 'main' into main
jeremiahpslewis authored Jun 9, 2023
2 parents 04ecd28 + ea00fdf commit 32604f2
Showing 10 changed files with 37 additions and 24 deletions.
9 changes: 7 additions & 2 deletions src/ReinforcementLearningCore/src/core/run.jl
@@ -89,31 +89,36 @@ function _run(policy::AbstractPolicy,
while !is_stop
reset!(env)
push!(policy, PreEpisodeStage(), env)
optimise!(policy, PreActStage())
push!(hook, PreEpisodeStage(), policy, env)


while !reset_condition(policy, env) # one episode
push!(policy, PreActStage(), env)
optimise!(policy, PreActStage())
push!(hook, PreActStage(), policy, env)

action = RLBase.plan!(policy, env)
act!(env, action)

optimise!(policy)

push!(policy, PostActStage(), env)
optimise!(policy, PostActStage())
push!(hook, PostActStage(), policy, env)

if check_stop(stop_condition, policy, env)
is_stop = true
push!(policy, PreActStage(), env)
optimise!(policy, PreActStage())
push!(hook, PreActStage(), policy, env)
RLBase.plan!(policy, env) # let the policy see the last observation
break
end
end # end of an episode

push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
optimise!(policy, PostEpisodeStage())
push!(hook, PostEpisodeStage(), policy, env)

end
push!(policy, PostExperimentStage(), env)
push!(hook, PostExperimentStage(), policy, env)
4 changes: 3 additions & 1 deletion src/ReinforcementLearningCore/src/core/stages.jl
@@ -19,4 +19,6 @@ struct PostActStage <: AbstractStage end
Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing
Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing

RLBase.optimise!(::AbstractPolicy) = nothing
RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing

RLBase.optimise!(policy::P, ::S, batch) where {P<:AbstractPolicy, S<:AbstractStage} = nothing
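
Editor's note: with these fallbacks in place, a stage-tagged `optimise!` call is a safe no-op for any policy that does not explicitly opt in. A minimal sketch of opting in, assuming the usual exports of ReinforcementLearningBase and ReinforcementLearningCore; `MyStepwisePolicy` is a hypothetical type, not part of this commit:

import ReinforcementLearningBase as RLBase
using ReinforcementLearningCore

struct MyStepwisePolicy <: RLBase.AbstractPolicy end   # hypothetical policy

# Opt in to optimisation right after each action; every other stage falls
# through to the generic `nothing` fallbacks defined above.
function RLBase.optimise!(p::MyStepwisePolicy, ::PostActStage)
    # a real policy would take one gradient step here
    nothing
end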
12 changes: 6 additions & 6 deletions src/ReinforcementLearningCore/src/policies/agent/base.jl
@@ -39,16 +39,16 @@ end

Agent(;policy, trajectory, cache = SRT()) = Agent(policy, trajectory, cache)

RLBase.optimise!(agent::Agent) = optimise!(TrajectoryStyle(agent.trajectory), agent)
RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent) =
optimise!(agent.policy, agent.trajectory)
RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} =
optimise!(agent.policy, stage, agent.trajectory)

# already spawn a task to optimise inner policy when initializing the agent
RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent) = nothing
RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing

function RLBase.optimise!(policy::AbstractPolicy, trajectory::Trajectory)
function RLBase.optimise!(policy::AbstractPolicy, stage::S, trajectory::Trajectory) where {S<:AbstractStage}
for batch in trajectory
optimise!(policy, batch)
optimise!(policy, stage, batch)
end
end

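Editor's note: with the agent-level methods above, the stage travels from the run loop down to every sampled batch. For a synchronous trajectory, `optimise!(agent, stage)` dispatches on the trajectory style and ends in a per-batch call. A hedged sketch of the one method a wrapped policy then needs; `MyBatchPolicy` is hypothetical, not part of this commit:

import ReinforcementLearningBase as RLBase
using ReinforcementLearningCore

struct MyBatchPolicy <: RLBase.AbstractPolicy end   # hypothetical policy

# Call path for an Agent with a synchronous trajectory (see base.jl above):
#   optimise!(agent, PostActStage())
#     -> optimise!(SyncTrajectoryStyle(), agent, PostActStage())   via TrajectoryStyle
#     -> optimise!(agent.policy, PostActStage(), agent.trajectory)
#     -> optimise!(agent.policy, PostActStage(), batch)            for each batch
function RLBase.optimise!(p::MyBatchPolicy, ::PostActStage, batch)
    # a real policy would compute a loss on `batch` and update its parameters
    nothing
end
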
11 changes: 8 additions & 3 deletions src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
@@ -110,26 +110,30 @@ function Base.run(
while !is_stop
reset!(env)
push!(multiagent_policy, PreEpisodeStage(), env)
optimise!(multiagent_policy, PreEpisodeStage())
push!(multiagent_hook, PreEpisodeStage(), multiagent_policy, env)

while !(reset_condition(multiagent_policy, env) || is_stop) # one episode
for player in CurrentPlayerIterator(env)
policy = multiagent_policy[player] # Select appropriate policy
hook = multiagent_hook[player] # Select appropriate hook
push!(policy, PreActStage(), env)
optimise!(policy, PreActStage())
push!(hook, PreActStage(), policy, env)

action = RLBase.plan!(policy, env)
act!(env, action)

optimise!(policy)


push!(policy, PostActStage(), env)
optimise!(policy, PostActStage())
push!(hook, PostActStage(), policy, env)

if check_stop(stop_condition, policy, env)
is_stop = true
push!(multiagent_policy, PreActStage(), env)
optimise!(multiagent_policy, PreActStage())
push!(multiagent_hook, PreActStage(), policy, env)
RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
break
@@ -142,6 +146,7 @@ end # end of an episode
end # end of an episode

push!(multiagent_policy, PostEpisodeStage(), env) # let the policy see the last observation
optimise!(multiagent_policy, PostEpisodeStage())
push!(multiagent_hook, PostEpisodeStage(), multiagent_policy, env)
end
push!(multiagent_policy, PostExperimentStage(), env)
@@ -225,8 +230,8 @@ function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
end

function RLBase.optimise!(multiagent::MultiAgentPolicy)
function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}
for policy in multiagent
RLCore.optimise!(policy)
RLCore.optimise!(policy, stage)
end
end
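
Editor's note: `MultiAgentPolicy` simply forwards the stage to each wrapped policy, so every player keeps its own stage-specific behaviour. A small usage sketch, assuming a `MultiAgentPolicy` built from a NamedTuple of per-player policies; the player names are hypothetical:

import ReinforcementLearningBase as RLBase
using ReinforcementLearningCore

# Hypothetical two-player setup; RandomPolicy defines no optimise! method of
# its own, so both forwarded calls hit the generic no-op fallback.
multiagent = MultiAgentPolicy((;
    player1 = RandomPolicy(),
    player2 = RandomPolicy(),
))

RLBase.optimise!(multiagent, PostActStage())   # forwards the stage to each policy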
@@ -37,4 +37,4 @@ end
RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env))

RLBase.optimise!(p::QBasedPolicy{L,Ex}, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer} = optimise!(p.learner, x)
RLBase.optimise!(p::QBasedPolicy{L,Ex}, stage::S, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer, S<:AbstractStage} = optimise!(p.learner, x)
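
Editor's note: the stage is consumed at the `QBasedPolicy` layer; the learner underneath still receives only the batch, so existing learner methods keep their signature. A hedged sketch with a hypothetical learner, not part of this commit:

import ReinforcementLearningBase as RLBase
using ReinforcementLearningCore

struct MyLearner <: AbstractLearner end   # hypothetical learner

# QBasedPolicy strips the stage before delegating, so the learner-level
# signature stays (learner, batch).
RLBase.optimise!(learner::MyLearner, batch::NamedTuple) = nothing   # gradient update would go here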
@@ -71,9 +71,9 @@ function RLBase.optimise!(
k => p′
end

function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, trajectory::Trajectory)
function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, ::PostActStage, trajectory::Trajectory)
for batch in trajectory
k, p = optimise!(policy, batch) |> send_to_host
k, p = optimise!(policy, PostActStage(), batch) |> send_to_host
trajectory[:priority, k] = p
end
end
4 changes: 2 additions & 2 deletions src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl
@@ -139,9 +139,9 @@ function project_distribution(supports, weights, target_support, delta_z, vmin,
reshape(sum(projection, dims=1), n_atoms, batch_size)
end

function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, trajectory::Trajectory)
function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, ::PostActStage, trajectory::Trajectory)
for batch in trajectory
res = optimise!(policy, batch) |> send_to_host
res = optimise!(policy, PostActStage(), batch) |> send_to_host
if !isnothing(res)
k, p = res
trajectory[:priority, k] = p
@@ -93,6 +93,7 @@ end

function RLBase.optimise!(
p::MPOPolicy,
::PostActStage,
batches::NamedTuple{
(:actor, :critic),
<: Tuple{
@@ -39,16 +39,16 @@ function Base.push!(p::Agent{<:TRPO}, ::PostEpisodeStage, env::AbstractEnv)
empty!(p.trajectory.container)
end

RLBase.optimise!(::Agent{<:TRPO}) = nothing
RLBase.optimise!(::Agent{<:TRPO}, ::PostActStage) = nothing

function RLBase.optimise!(π::TRPO, episode::Episode)
function RLBase.optimise!(π::TRPO, ::PostActStage, episode::Episode)
gain = discount_rewards(episode[:reward][:], π.γ)
for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
end
end

function RLBase.optimise!(p::TRPO, batch::NamedTuple{(:state, :action, :gain)})
function RLBase.optimise!(p::TRPO, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)})
A = p.approximator
B = p.baseline
s, a, g = map(Array, batch) # !!! FIXME
@@ -36,16 +36,16 @@ function update!(p::Agent{<:VPG}, ::PostEpisodeStage, env::AbstractEnv)
empty!(p.trajectory.container)
end

RLBase.optimise!(::Agent{<:VPG}) = nothing
RLBase.optimise!(::Agent{<:VPG}, ::PostActStage) = nothing

function RLBase.optimise!(π::VPG, episode::Episode)
function RLBase.optimise!(π::VPG, ::PostActStage, episode::Episode)
gain = discount_rewards(episode[:reward][:], π.γ)
for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
optimise!(π, PostActStage(), (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
end
end

function RLBase.optimise!(p::VPG, batch::NamedTuple{(:state, :action, :gain)})
function RLBase.optimise!(p::VPG, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)})
A = p.approximator
B = p.baseline
s, a, g = map(Array, batch) # !!! FIXME
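Editor's note: episode-based algorithms such as VPG and TRPO use the stage argument to silence the per-step path driven by the `Agent` wrapper and run their update over a whole episode instead. A hedged sketch of that pattern; the policy type is hypothetical, and the VPG/TRPO methods above are the authoritative versions:

import ReinforcementLearningBase as RLBase
using ReinforcementLearningCore

struct MyEpisodicPolicy <: RLBase.AbstractPolicy end   # hypothetical policy

# Silence the generic per-batch path the Agent wrapper would otherwise drive...
RLBase.optimise!(::Agent{<:MyEpisodicPolicy}, ::PostActStage) = nothing

# ...and do the real work once a complete episode is available.
function RLBase.optimise!(p::MyEpisodicPolicy, ::PostActStage, episode)
    # e.g. compute discounted returns over `episode`, then take gradient steps
    nothing
end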
