From ea00fdfc0cc4cae92013211d63767bc081b2b603 Mon Sep 17 00:00:00 2001
From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com>
Date: Fri, 9 Jun 2023 08:14:11 +0200
Subject: [PATCH] Sketch out optimise! refactor (#899)

---
 src/ReinforcementLearningCore/src/core/run.jl    |  9 +++++++--
 src/ReinforcementLearningCore/src/core/stages.jl |  4 +++-
 .../src/policies/agent/base.jl                   | 12 ++++++------
 .../src/policies/agent/multi_agent.jl            | 11 ++++++++---
 .../src/policies/q_based_policy.jl               |  2 +-
 .../src/algorithms/dqns/prioritized_dqn.jl       |  4 ++--
 .../src/algorithms/dqns/rainbow.jl               |  4 ++--
 .../src/algorithms/policy_gradient/mpo.jl        |  1 +
 .../src/algorithms/policy_gradient/trpo.jl       |  6 +++---
 .../src/algorithms/policy_gradient/vpg.jl        |  8 ++++----
 10 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/src/ReinforcementLearningCore/src/core/run.jl b/src/ReinforcementLearningCore/src/core/run.jl
index 1779bd4dc..1dae0dcad 100644
--- a/src/ReinforcementLearningCore/src/core/run.jl
+++ b/src/ReinforcementLearningCore/src/core/run.jl
@@ -89,23 +89,26 @@ function _run(policy::AbstractPolicy,
     while !is_stop
         reset!(env)
         push!(policy, PreEpisodeStage(), env)
+        optimise!(policy, PreActStage())
         push!(hook, PreEpisodeStage(), policy, env)
+
         while !reset_condition(policy, env) # one episode
             push!(policy, PreActStage(), env)
+            optimise!(policy, PreActStage())
             push!(hook, PreActStage(), policy, env)

             action = RLBase.plan!(policy, env)
             act!(env, action)

-            optimise!(policy)
-
             push!(policy, PostActStage(), env)
+            optimise!(policy, PostActStage())
             push!(hook, PostActStage(), policy, env)

             if check_stop(stop_condition, policy, env)
                 is_stop = true
                 push!(policy, PreActStage(), env)
+                optimise!(policy, PreActStage())
                 push!(hook, PreActStage(), policy, env)
                 RLBase.plan!(policy, env) # let the policy see the last observation
                 break
@@ -113,7 +116,9 @@ function _run(policy::AbstractPolicy,
         end # end of an episode

         push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
+        optimise!(policy, PostEpisodeStage())
         push!(hook, PostEpisodeStage(), policy, env)
+
     end
     push!(policy, PostExperimentStage(), env)
     push!(hook, PostExperimentStage(), policy, env)
diff --git a/src/ReinforcementLearningCore/src/core/stages.jl b/src/ReinforcementLearningCore/src/core/stages.jl
index 22ab0bfae..52f2a30a9 100644
--- a/src/ReinforcementLearningCore/src/core/stages.jl
+++ b/src/ReinforcementLearningCore/src/core/stages.jl
@@ -19,4 +19,6 @@ struct PostActStage <: AbstractStage end
 Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing
 Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing

-RLBase.optimise!(::AbstractPolicy) = nothing
+RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing
+
+RLBase.optimise!(policy::P, ::S, batch) where {P<:AbstractPolicy,S<:AbstractStage} = nothing
diff --git a/src/ReinforcementLearningCore/src/policies/agent/base.jl b/src/ReinforcementLearningCore/src/policies/agent/base.jl
index 4f81f7819..dc7798dc2 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/base.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/base.jl
@@ -39,16 +39,16 @@ end
 Agent(;policy, trajectory, cache = SRT()) = Agent(policy, trajectory, cache)

-RLBase.optimise!(agent::Agent) = optimise!(TrajectoryStyle(agent.trajectory), agent)
-RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent) =
-    optimise!(agent.policy, agent.trajectory)
+RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
+RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} =
+    optimise!(agent.policy, stage, agent.trajectory)

 # already spawn a task to optimise inner policy when initializing the agent
-RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent) = nothing
+RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing

-function RLBase.optimise!(policy::AbstractPolicy, trajectory::Trajectory)
+function RLBase.optimise!(policy::AbstractPolicy, stage::S, trajectory::Trajectory) where {S<:AbstractStage}
     for batch in trajectory
-        optimise!(policy, batch)
+        optimise!(policy, stage, batch)
     end
 end
diff --git a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
index bad13c2e1..ee42acc11 100644
--- a/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
+++ b/src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl
@@ -110,6 +110,7 @@ function Base.run(
     while !is_stop
         reset!(env)
         push!(multiagent_policy, PreEpisodeStage(), env)
+        optimise!(multiagent_policy, PreEpisodeStage())
         push!(multiagent_hook, PreEpisodeStage(), multiagent_policy, env)

         while !(reset_condition(multiagent_policy, env) || is_stop) # one episode
@@ -117,19 +118,22 @@
             policy = multiagent_policy[player] # Select appropriate policy
             hook = multiagent_hook[player] # Select appropriate hook
             push!(policy, PreActStage(), env)
+            optimise!(policy, PreActStage())
             push!(hook, PreActStage(), policy, env)

             action = RLBase.plan!(policy, env)
             act!(env, action)

-            optimise!(policy)
+            push!(policy, PostActStage(), env)
+            optimise!(policy, PostActStage())
             push!(hook, PostActStage(), policy, env)

             if check_stop(stop_condition, policy, env)
                 is_stop = true
                 push!(multiagent_policy, PreActStage(), env)
+                optimise!(multiagent_policy, PreActStage())
                 push!(multiagent_hook, PreActStage(), policy, env)
                 RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
                 break
@@ -142,6 +146,7 @@
         end # end of an episode

         push!(multiagent_policy, PostEpisodeStage(), env) # let the policy see the last observation
+        optimise!(multiagent_policy, PostEpisodeStage())
         push!(multiagent_hook, PostEpisodeStage(), multiagent_policy, env)
     end
     push!(multiagent_policy, PostExperimentStage(), env)
@@ -225,8 +230,8 @@ function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEn
     return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
 end

-function RLBase.optimise!(multiagent::MultiAgentPolicy)
+function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}
     for policy in multiagent
-        RLCore.optimise!(policy)
+        RLCore.optimise!(policy, stage)
     end
 end
diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl
index 79881471e..cb3376038 100644
--- a/src/ReinforcementLearningCore/src/policies/q_based_policy.jl
+++ b/src/ReinforcementLearningCore/src/policies/q_based_policy.jl
@@ -37,4 +37,4 @@ end
 RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
     prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env))

-RLBase.optimise!(p::QBasedPolicy{L,Ex}, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer} = optimise!(p.learner, x)
+RLBase.optimise!(p::QBasedPolicy{L,Ex}, stage::S, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer,S<:AbstractStage} = optimise!(p.learner, x)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl
index 1d587afe6..5a5f4a7b8 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl
@@ -71,9 +71,9 @@ function RLBase.optimise!(
     k => p′
 end

-function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, trajectory::Trajectory)
+function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, ::PostActStage, trajectory::Trajectory)
     for batch in trajectory
-        k, p = optimise!(policy, batch) |> send_to_host
+        k, p = optimise!(policy, PostActStage(), batch) |> send_to_host
         trajectory[:priority, k] = p
     end
 end
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl
index c99c5b209..a10a86416 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl
@@ -139,9 +139,9 @@ function project_distribution(supports, weights, target_support, delta_z, vmin,
     reshape(sum(projection, dims=1), n_atoms, batch_size)
 end

-function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, trajectory::Trajectory)
+function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, ::PostActStage, trajectory::Trajectory)
     for batch in trajectory
-        res = optimise!(policy, batch) |> send_to_host
+        res = optimise!(policy, PostActStage(), batch) |> send_to_host
         if !isnothing(res)
             k, p = res
             trajectory[:priority, k] = p
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl
index 979422bd7..4decca982 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl
@@ -93,6 +93,7 @@ end

 function RLBase.optimise!(
     p::MPOPolicy,
+    ::PostActStage,
     batches::NamedTuple{
         (:actor, :critic),
         <: Tuple{
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl
index eb9efca86..572abd962 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl
@@ -39,16 +39,16 @@ function Base.push!(p::Agent{<:TRPO}, ::PostEpisodeStage, env::AbstractEnv)
     empty!(p.trajectory.container)
 end

-RLBase.optimise!(::Agent{<:TRPO}) = nothing
+RLBase.optimise!(::Agent{<:TRPO}, ::PostActStage) = nothing

-function RLBase.optimise!(π::TRPO, episode::Episode)
+function RLBase.optimise!(π::TRPO, ::PostActStage, episode::Episode)
     gain = discount_rewards(episode[:reward][:], π.γ)
     for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
         optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
     end
 end

-function RLBase.optimise!(p::TRPO, batch::NamedTuple{(:state, :action, :gain)})
+function RLBase.optimise!(p::TRPO, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)})
     A = p.approximator
     B = p.baseline
     s, a, g = map(Array, batch) # !!! FIXME
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl
index b243bf158..d9a3ec0c7 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl
@@ -36,16 +36,16 @@ function update!(p::Agent{<:VPG}, ::PostEpisodeStage, env::AbstractEnv)
     empty!(p.trajectory.container)
 end

-RLBase.optimise!(::Agent{<:VPG}) = nothing
+RLBase.optimise!(::Agent{<:VPG}, ::PostActStage) = nothing

-function RLBase.optimise!(π::VPG, episode::Episode)
+function RLBase.optimise!(π::VPG, ::PostActStage, episode::Episode)
     gain = discount_rewards(episode[:reward][:], π.γ)
     for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
-        optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
+        optimise!(π, PostActStage(), (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
     end
 end

-function RLBase.optimise!(p::VPG, batch::NamedTuple{(:state, :action, :gain)})
+function RLBase.optimise!(p::VPG, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)})
     A = p.approximator
     B = p.baseline
     s, a, g = map(Array, batch) # !!! FIXME
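
The hunks above leave per-stage behaviour to dispatch: the run loop now calls optimise!(policy, stage) at every stage, and the defaults added in stages.jl make that a no-op unless a policy opts in for a particular stage. Below is a minimal usage sketch (not part of the patch) of a custom policy that only updates its learner on PostActStage batches; SketchPolicy and its learner field are hypothetical, and the imports assume the usual ReinforcementLearningCore exports of RLBase, AbstractPolicy and the stage types.

using ReinforcementLearningCore  # assumed to re-export RLBase, AbstractPolicy and the stage types

# Hypothetical policy used only for illustration.
struct SketchPolicy{L} <: AbstractPolicy
    learner::L
end

# Learn only right after acting; every other stage hits the stage-aware
# no-op fallbacks added in stages.jl.
RLBase.optimise!(p::SketchPolicy, ::PostActStage, batch::NamedTuple) =
    optimise!(p.learner, batch)  # delegate to the learner, mirroring the QBasedPolicy change

With the default SyncTrajectoryStyle, wrapping such a policy in an Agent means the run loop's optimise!(agent, stage) call iterates agent.trajectory and forwards each batch to this method via optimise!(policy, stage, batch), so the policy needs no per-stage bookkeeping of its own.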