Commit

Merge pull request #18 from JuliaReinforcementLearning/fix_examples
fix examples
jbrea authored Sep 21, 2018
2 parents 5342d7c + 0beba5f commit 2126960
Showing 8 changed files with 26 additions and 18 deletions.
6 changes: 4 additions & 2 deletions src/ReinforcementLearning.jl
@@ -1,9 +1,11 @@
module ReinforcementLearning

using DataStructures, Parameters, SparseArrays, LinearAlgebra, Distributed,
-Statistics, Dates, Compat, Requires, StatsBase
+Statistics, Dates, Requires, StatsBase
import Statistics: mean
-import ReinforcementLearningBase: interact!, getstate, reset!, plotenv, sample
+import ReinforcementLearningBase: interact!, getstate, reset!, plotenv,
+actionspace, sample


using Random: seed!
function __init__()
2 changes: 1 addition & 1 deletion src/buffers.jl
@@ -88,7 +88,7 @@ function ArrayCircularBuffer(arraytype, datatype, elemshape, capacity)
convert(Dims, (elemshape..., capacity)))),
capacity, 0, 0, false)
end
-import Base.push!, Base.view, Compat.lastindex, Base.getindex
+import Base.push!, Base.view, Base.lastindex, Base.getindex
for N in 2:5
@eval @__MODULE__() begin
function push!(a::ArrayCircularBuffer{<:AbstractArray{T, $N}}, x) where T
10 changes: 5 additions & 5 deletions src/callbacks.jl
@@ -219,15 +219,15 @@ function callback!(c::EvaluateGreedy, rlsetup, sraw, a, r, done)
rlsetup.fillbuffer = false
c.rlsetupcallbacks = rlsetup.callbacks
rlsetup.callbacks = [c]
-c.rlsetuppolicy = rlsetup.policy
-rlsetup.policy = greedypolicy(rlsetup.policy)
+c.rlsetuppolicy = deepcopy(rlsetup.policy)
+greedify!(rlsetup.policy)
end
end
getvalue(c::EvaluateGreedy) = c.values

export EvaluateGreedy, Step, Episode
-greedypolicy(p::EpsilonGreedyPolicy{T}) where T = EpsilonGreedyPolicy{T}(0.)
-greedypolicy(p::SoftmaxPolicy) = SoftmaxPolicy(Inf)
+greedify!(p::EpsilonGreedyPolicy) = p.ϵ = 0
+greedify!(p::SoftmaxPolicy) = p.β = Inf

import FileIO:save
"""
@@ -328,4 +328,4 @@ function callback!(c::Visualize, rlsetup, s, a, r, done)
plotenv(rlsetup.environment)
sleep(c.wait)
end
-plotenv(env, s, a, r, d) = warn("Visualization not implemented for environments of type $(typeof(env)).")
+plotenv(env) = warn("Visualization not implemented for environments of type $(typeof(env)).")
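
Note on the callbacks.jl change: EvaluateGreedy now keeps a `deepcopy` of the current policy and calls `greedify!` to make the live policy act greedily in place, rather than constructing a separate greedy policy with `greedypolicy`. A minimal sketch of the idea, using a simplified stand-in for the package's `EpsilonGreedyPolicy`:

```julia
# Simplified stand-in; the real EpsilonGreedyPolicy also carries an action space and a Q function.
mutable struct EpsilonGreedyPolicy
    ϵ::Float64
end

greedify!(p::EpsilonGreedyPolicy) = p.ϵ = 0    # in place: explore with probability 0 from now on

p = EpsilonGreedyPolicy(0.1)
saved = deepcopy(p)   # as in the callback: keep the exploring policy for later
greedify!(p)          # evaluation now runs greedily (p.ϵ == 0)
# after evaluation the callback can swap the saved copy back into rlsetup.policy
```
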
7 changes: 4 additions & 3 deletions src/forced.jl
@@ -15,10 +15,11 @@ export ForcedEpisode
ForcedEpisode(states, dones, rewards) = ForcedEpisode(1, states, dones, rewards)
function interact!(env::ForcedEpisode, a)
env.t += 1
-env.states[env.t], env.rewards[env.t], env.dones[env.t]
+(observation = env.states[env.t], reward = env.rewards[env.t],
+isdone = env.dones[env.t])
end
function reset!(env::ForcedEpisode)
env.t = 1
-env.states[1]
+(observation = env.states[1],)
end
-getstate(env::ForcedEpisode) = (env.states[env.t], env.dones[env.t])
+getstate(env::ForcedEpisode) = (observation = env.states[env.t], isdone = env.dones[env.t])
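
Note on the forced.jl change: `interact!`, `reset!`, and `getstate` now return named tuples (`observation`, `reward`, `isdone`) instead of positional tuples. A short usage sketch with made-up episode data:

```julia
using ReinforcementLearning   # provides ForcedEpisode and interact!/reset!

# Hypothetical three-step episode: states, dones, rewards (constructor order from the diff above).
env = ForcedEpisode([[0.0], [1.0], [2.0]],
                    [false, false, true],
                    [0.0, 1.0, 5.0])

step = interact!(env, 1)                    # the action is ignored; the episode is forced
step.observation, step.reward, step.isdone  # access by name rather than by position

init = reset!(env)
init.observation                            # back to the first state
```
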
4 changes: 2 additions & 2 deletions src/learn.jl
@@ -4,7 +4,7 @@
s, r, done = preprocess(preprocessor, s0, r0, done0)
if fillbuffer; pushreturn!(buffer, r, done) end
if done
-s0 = reset!(environment)
+s0, = reset!(environment)
s = preprocessstate(preprocessor, s0)
end
if fillbuffer; pushstate!(buffer, s) end
@@ -16,7 +16,7 @@ end
@unpack learner, policy, buffer, preprocessor, environment, fillbuffer = rlsetup
if isempty(buffer.actions)
sraw, done = getstate(environment)
-if done; sraw = reset!(environment); end
+if done; sraw, = reset!(environment); end
s = preprocessstate(preprocessor, sraw)
if fillbuffer; pushstate!(buffer, s) end
a = policy(s)
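
Note on the learn.jl change: since `reset!` now returns a named tuple, the loop unpacks only its first field with a trailing comma, `s0, = reset!(environment)`. A quick illustration of that idiom (the tuple below is only a stand-in for the real return value):

```julia
fake_reset() = (observation = [0.0, 1.0],)   # stand-in for reset!(environment)

s0, = fake_reset()        # trailing comma: iterate the named tuple and keep the first value
@assert s0 == [0.0, 1.0]  # the field name `observation` is not needed for this unpacking
```
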
9 changes: 7 additions & 2 deletions src/policies.jl
@@ -73,8 +73,13 @@ function EpsilonGreedyPolicy(ϵ, actionspace::Ta, Q::Tf;
EpsilonGreedyPolicy{kind, Ta, Tf}(ϵ, actionspace, Q)
end
export EpsilonGreedyPolicy
-(p::EpsilonGreedyPolicy)(s) = rand() < p.ϵ ? sample(p.actionspace) :
-samplegreedyaction(p, p.Q(s))
+function (p::EpsilonGreedyPolicy)(s)
+    if rand() < p.ϵ
+        rand(1:p.actionspace.n) # sample(actionspace) does not work currently because DQN expects actions in 1:n
+    else
+        samplegreedyaction(p, p.Q(s))
+    end
+end


import Base.maximum, Base.isequal
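
Note on the policies.jl change: the exploring branch now draws a uniform index with `rand(1:p.actionspace.n)` because, per the inline comment, DQN expects actions in `1:n`, so `sample(p.actionspace)` is sidestepped for now. A self-contained sketch of the resulting behaviour, using simplified stand-ins for the package's types:

```julia
struct DiscreteSpace              # stand-in action space; only the field `n` is used
    n::Int
end

struct ToyEpsilonGreedy{Tf}       # stand-in for EpsilonGreedyPolicy, to keep the sketch self-contained
    ϵ::Float64
    actionspace::DiscreteSpace
    Q::Tf                         # maps a state to a vector of action values
end

samplegreedyaction(p, values) = argmax(values)   # simplified: no random tie-breaking

function (p::ToyEpsilonGreedy)(s)
    if rand() < p.ϵ
        rand(1:p.actionspace.n)            # uniform exploration over action indices 1:n
    else
        samplegreedyaction(p, p.Q(s))
    end
end

p = ToyEpsilonGreedy(0.1, DiscreteSpace(4), s -> [0.0, 1.0, 0.5, -1.0])
p(nothing)   # ≈ 90% of the time returns 2 (the greedy index), otherwise a random index in 1:4
```
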
4 changes: 2 additions & 2 deletions src/rlsetup.jl
@@ -5,7 +5,7 @@
stoppingcriterion::Ts
preprocessor::Tpp = NoPreprocessor()
buffer::Tb = defaultbuffer(learner, environment, preprocessor)
-policy::Tp = defaultpolicy(learner, environment.actionspace, buffer)
+policy::Tp = defaultpolicy(learner, actionspace(environment), buffer)
callbacks::Array{Any, 1} = []
islearning::Bool = true
fillbuffer::Bool = islearning
@@ -16,7 +16,7 @@
stoppingcriterion::Ts
preprocessor::Tpp = NoPreprocessor()
buffer::Tb = defaultbuffer(learner, environment, preprocessor)
-policy::Tp = defaultpolicy(learner, environment.actionspace, buffer)
+policy::Tp = defaultpolicy(learner, actionspace(environment), buffer)
callbacks::Array{Any, 1} = []
islearning::Bool = true
fillbuffer::Bool = islearning
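
Note on the rlsetup.jl change: the default policy is now built from `actionspace(environment)`, the generic imported from ReinforcementLearningBase at the top of this commit, rather than from the `environment.actionspace` field, so environments that compute their action space on demand also work. A hypothetical example of why the function call is the more general interface:

```julia
import ReinforcementLearningBase: actionspace   # assumes ReinforcementLearningBase is installed

struct GridWorld          # hypothetical environment that stores no `actionspace` field
    width::Int
    height::Int
end

# `env.actionspace` would throw here; the generic keeps working because the
# environment derives its space on demand (shown as a plain index range).
actionspace(env::GridWorld) = 1:4   # up, down, left, right

env = GridWorld(5, 5)
actionspace(env)   # 1:4
```
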
2 changes: 1 addition & 1 deletion test/learn.jl
@@ -1,4 +1,4 @@
-import ReinforcementLearning: reset!
+import ReinforcementLearningBase: reset!
function testlearn()
mdp = MDP()
learner = Sarsa()
