diff --git a/src/ReinforcementLearning.jl b/src/ReinforcementLearning.jl
index 5f1bf7c9a..21dbe6ac5 100644
--- a/src/ReinforcementLearning.jl
+++ b/src/ReinforcementLearning.jl
@@ -1,9 +1,11 @@
 module ReinforcementLearning
 
 using DataStructures, Parameters, SparseArrays, LinearAlgebra, Distributed,
-Statistics, Dates, Compat, Requires, StatsBase
+Statistics, Dates, Requires, StatsBase
 import Statistics: mean
-import ReinforcementLearningBase: interact!, getstate, reset!, plotenv, sample
+import ReinforcementLearningBase: interact!, getstate, reset!, plotenv,
+actionspace, sample
+
 using Random: seed!
 
 function __init__()
diff --git a/src/buffers.jl b/src/buffers.jl
index 88895068c..b535f8d3f 100644
--- a/src/buffers.jl
+++ b/src/buffers.jl
@@ -88,7 +88,7 @@ function ArrayCircularBuffer(arraytype, datatype, elemshape, capacity)
                                    convert(Dims, (elemshape..., capacity)))),
                        capacity, 0, 0, false)
 end
-import Base.push!, Base.view, Compat.lastindex, Base.getindex
+import Base.push!, Base.view, Base.lastindex, Base.getindex
 for N in 2:5
     @eval @__MODULE__() begin
         function push!(a::ArrayCircularBuffer{<:AbstractArray{T, $N}}, x) where T
diff --git a/src/callbacks.jl b/src/callbacks.jl
index fc6cfdd68..c9bf428e3 100644
--- a/src/callbacks.jl
+++ b/src/callbacks.jl
@@ -219,15 +219,15 @@ function callback!(c::EvaluateGreedy, rlsetup, sraw, a, r, done)
         rlsetup.fillbuffer = false
         c.rlsetupcallbacks = rlsetup.callbacks
         rlsetup.callbacks = [c]
-        c.rlsetuppolicy = rlsetup.policy
-        rlsetup.policy = greedypolicy(rlsetup.policy)
+        c.rlsetuppolicy = deepcopy(rlsetup.policy)
+        greedify!(rlsetup.policy)
     end
 end
 getvalue(c::EvaluateGreedy) = c.values
 export EvaluateGreedy, Step, Episode
 
-greedypolicy(p::EpsilonGreedyPolicy{T}) where T = EpsilonGreedyPolicy{T}(0.)
-greedypolicy(p::SoftmaxPolicy) = SoftmaxPolicy(Inf)
+greedify!(p::EpsilonGreedyPolicy) = p.ϵ = 0
+greedify!(p::SoftmaxPolicy) = p.β = Inf
 
 import FileIO:save
 """
@@ -328,4 +328,4 @@ function callback!(c::Visualize, rlsetup, s, a, r, done)
     plotenv(rlsetup.environment)
     sleep(c.wait)
 end
-plotenv(env, s, a, r, d) = warn("Visualization not implemented for environments of type $(typeof(env)).")
+plotenv(env) = warn("Visualization not implemented for environments of type $(typeof(env)).")
diff --git a/src/forced.jl b/src/forced.jl
index 955bc50a3..de7e35fd4 100644
--- a/src/forced.jl
+++ b/src/forced.jl
@@ -15,10 +15,11 @@ export ForcedEpisode
 ForcedEpisode(states, dones, rewards) = ForcedEpisode(1, states, dones, rewards)
 function interact!(env::ForcedEpisode, a)
     env.t += 1
-    env.states[env.t], env.rewards[env.t], env.dones[env.t]
+    (observation = env.states[env.t], reward = env.rewards[env.t],
+     isdone = env.dones[env.t])
 end
 function reset!(env::ForcedEpisode)
     env.t = 1
-    env.states[1]
+    (observation = env.states[1],)
 end
-getstate(env::ForcedEpisode) = (env.states[env.t], env.dones[env.t])
+getstate(env::ForcedEpisode) = (observation = env.states[env.t], isdone = env.dones[env.t])
diff --git a/src/learn.jl b/src/learn.jl
index 08c87a258..7b58cc00d 100644
--- a/src/learn.jl
+++ b/src/learn.jl
@@ -4,7 +4,7 @@
     s, r, done = preprocess(preprocessor, s0, r0, done0)
     if fillbuffer; pushreturn!(buffer, r, done) end
     if done
-        s0 = reset!(environment)
+        s0, = reset!(environment)
         s = preprocessstate(preprocessor, s0)
     end
     if fillbuffer; pushstate!(buffer, s) end
@@ -16,7 +16,7 @@ end
    @unpack learner, policy, buffer, preprocessor, environment, fillbuffer = rlsetup
    if isempty(buffer.actions)
        sraw, done = getstate(environment)
-       if done; sraw = reset!(environment); end
+       if done; sraw, = reset!(environment); end
        s = preprocessstate(preprocessor, sraw)
        if fillbuffer; pushstate!(buffer, s) end
        a = policy(s)
diff --git a/src/policies.jl b/src/policies.jl
index ff2279df6..bc8f0fae5 100644
--- a/src/policies.jl
+++ b/src/policies.jl
@@ -73,8 +73,13 @@ function EpsilonGreedyPolicy(ϵ, actionspace::Ta, Q::Tf;
     EpsilonGreedyPolicy{kind, Ta, Tf}(ϵ, actionspace, Q)
 end
 export EpsilonGreedyPolicy
-(p::EpsilonGreedyPolicy)(s) = rand() < p.ϵ ? sample(p.actionspace) :
-                              samplegreedyaction(p, p.Q(s))
+function (p::EpsilonGreedyPolicy)(s)
+    if rand() < p.ϵ
+        rand(1:p.actionspace.n) # sample(actionspace) does not work currently because DQN expects actions in 1:n
+    else
+        samplegreedyaction(p, p.Q(s))
+    end
+end
 
 
 import Base.maximum, Base.isequal
diff --git a/src/rlsetup.jl b/src/rlsetup.jl
index 3e9d6e713..3639b8ed9 100644
--- a/src/rlsetup.jl
+++ b/src/rlsetup.jl
@@ -5,7 +5,7 @@
     stoppingcriterion::Ts
     preprocessor::Tpp = NoPreprocessor()
     buffer::Tb = defaultbuffer(learner, environment, preprocessor)
-    policy::Tp = defaultpolicy(learner, environment.actionspace, buffer)
+    policy::Tp = defaultpolicy(learner, actionspace(environment), buffer)
     callbacks::Array{Any, 1} = []
     islearning::Bool = true
     fillbuffer::Bool = islearning
@@ -16,7 +16,7 @@
     stoppingcriterion::Ts
     preprocessor::Tpp = NoPreprocessor()
     buffer::Tb = defaultbuffer(learner, environment, preprocessor)
-    policy::Tp = defaultpolicy(learner, environment.actionspace, buffer)
+    policy::Tp = defaultpolicy(learner, actionspace(environment), buffer)
     callbacks::Array{Any, 1} = []
     islearning::Bool = true
     fillbuffer::Bool = islearning
diff --git a/test/learn.jl b/test/learn.jl
index bfaa12d20..09185394b 100644
--- a/test/learn.jl
+++ b/test/learn.jl
@@ -1,4 +1,4 @@
-import ReinforcementLearning: reset!
+import ReinforcementLearningBase: reset!
 function testlearn()
     mdp = MDP()
     learner = Sarsa()
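Note for reviewers: the patch switches the environment interface to named-tuple returns (observation, reward, isdone) and replaces the field access environment.actionspace with the function actionspace(environment). The following minimal sketch shows how a user-defined environment might implement the updated interface; the type CountingEnv, its fields, and the placeholder action space are hypothetical and only serve as illustration, not as part of this patch.

import ReinforcementLearningBase: interact!, getstate, reset!, actionspace

# Hypothetical toy environment used only to illustrate the named-tuple interface:
# the observation is a step counter and the episode ends after `horizon` steps.
mutable struct CountingEnv
    t::Int
    horizon::Int
end
CountingEnv(horizon = 10) = CountingEnv(0, horizon)

# Placeholder: a real environment would return an action-space object from
# ReinforcementLearningBase here rather than a plain range.
actionspace(env::CountingEnv) = 1:2

function interact!(env::CountingEnv, a)
    env.t += 1
    (observation = env.t, reward = float(a), isdone = env.t >= env.horizon)
end

getstate(env::CountingEnv) = (observation = env.t, isdone = env.t >= env.horizon)

function reset!(env::CountingEnv)
    env.t = 0
    (observation = env.t,)
end

Because reset! now returns a named tuple, call sites destructure only the observation, e.g. `s0, = reset!(environment)` in src/learn.jl above.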