Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add plotting for cartpole and mountaincar with Plots.jl #309

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ function __init__()
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" include(
"environments/3rd_party/AcrobotEnv.jl",
)
@require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include(
"plots.jl",
)


end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,31 +123,3 @@ function (env::CartPoleEnv)(a)
end

"""
    Random.seed!(env::CartPoleEnv, seed)

Reseed `env` by forwarding `seed` to `Random.seed!` on the environment's own
random number generator (`env.rng`). Returns whatever that call returns.
"""
function Random.seed!(env::CartPoleEnv, seed)
    return Random.seed!(env.rng, seed)
end

"""
    plotendofepisode(x, y, d)

Draw a single marker at `(x, y)` when `d` is `true`; do nothing when it is
`false`. Always returns `nothing`.
"""
function plotendofepisode(x, y, d)
    # Nothing to draw unless the flag is set.
    d || return nothing

    # Marker attributes have to be configured before the marker is emitted.
    setmarkercolorind(7)
    setmarkertype(-1)
    setmarkersize(6)
    polymarker([x], [y])
    return nothing
end

"""
    GR.plot(env::CartPoleEnv)

Redraw the workstation with the current state of `env`: the cart, the pole,
an arrow derived from `env.action`, and the end-of-episode marker when
`env.done` is `true`.
"""
function GR.plot(env::CartPoleEnv)
    # Only the first and third state entries (position and angle) are drawn.
    cartx, _, angle, _ = env.state
    action = env.action
    done = env.done
    polelen = 2 * env.params.halflength
    xlimit = env.params.xthreshold

    # Start from a clean canvas whose window spans the allowed cart positions.
    clearws()
    setviewport(0, 1, 0, 1)
    setwindow(-xlimit, xlimit, -0.1, polelen + 0.1)

    # Cart: a unit-wide rectangle hanging just below y = 0.
    fillarea(
        [cartx - 0.5, cartx - 0.5, cartx + 0.5, cartx + 0.5],
        [-0.05, 0, 0, -0.05],
    )

    # Pole: a line from the cart to the pole tip.
    setlinecolorind(4)
    setlinewidth(3)
    polyline([cartx, cartx + polelen * sin(angle)], [0, polelen * cos(angle)])

    # Action arrow: drawn from x + 0.5 to x + 0.7 when the action equals 1,
    # otherwise from x - 0.5 to x - 0.7.
    setlinecolorind(2)
    isone = action == 1
    drawarrow(cartx + isone - 0.5, -0.025, cartx + 1.4 * isone - 0.7, -0.025)

    plotendofepisode(xlimit - 0.2, polelen, done)
    return updatews()
end
Original file line number Diff line number Diff line change
Expand Up @@ -135,37 +135,3 @@ function _step!(env::MountainCarEnv, force)
env.state[2] = v
nothing
end

# adapted from https://github.com/JuliaML/Reinforce.jl/blob/master/src/envs/mountain_car.jl
"""
    height(xs)

Return `sin(3 * xs) * 0.45 + 0.55` for a scalar position `xs`. Broadcast
(`height.(xs)`) to evaluate it over a collection.
"""
height(xs) = 0.55 + 0.45 * sin(3 * xs)
"""
    rotate(xs, ys, θ)

Rotate the point(s) `(xs, ys)` about the origin by the angle `θ` (radians,
counter-clockwise) and return the rotated coordinates as a tuple
`(xs_rotated, ys_rotated)`. `xs` and `ys` may be scalars or arrays.
"""
function rotate(xs, ys, θ)
    c, s = cos(θ), sin(θ)
    return xs * c - ys * s, ys * c + xs * s
end
"""
    translate(xs, ys, t)

Shift `xs` by `t[1]` and `ys` by `t[2]` (element-wise) and return the shifted
coordinates as a tuple.
"""
function translate(xs, ys, t)
    shiftx, shifty = t[1], t[2]
    return xs .+ shiftx, ys .+ shifty
end

"""
    GR.plot(env::MountainCarEnv)

Redraw the workstation with the current state of `env`: the track profile
given by `height`, the car as a small rotated rectangle at position
`env.state[1]`, and the end-of-episode marker when `env.done` is `true`.
"""
function GR.plot(env::MountainCarEnv)
    lo, hi = env.params.min_pos, env.params.max_pos
    carx = env.state[1]
    done = env.done

    # Clean canvas with a small margin around the track.
    clearws()
    setviewport(0, 1, 0, 1)
    setwindow(lo - 0.1, hi + 0.2, -0.1, height(hi) + 0.2)

    # Track: sample the height profile at 100 points.
    trackx = LinRange(lo, hi, 100)
    polyline(trackx, height.(trackx))

    # Car: a rectangle built around the origin, lifted by a small clearance,
    # rotated by cos(3x) and then moved onto the track at the car position.
    tilt = cos(3 * carx)
    carwidth = 0.05
    carheight = carwidth / 2
    clearance = 0.2 * carheight
    halfwidth = carwidth / 2
    bodyx = [-halfwidth, -halfwidth, halfwidth, halfwidth]
    bodyy = [0, carheight, carheight, 0] .+ clearance
    bodyx, bodyy = rotate(bodyx, bodyy, tilt)
    bodyx, bodyy = translate(bodyx, bodyy, [carx, height(carx)])
    fillarea(bodyx, bodyy)

    plotendofepisode(hi + 0.1, 0, done)
    return updatews()
end
88 changes: 88 additions & 0 deletions src/ReinforcementLearningEnvironments/src/plots.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import .Plots.plot
import .Plots.plot!

"""
    plot(env::CartPoleEnv)

Build and return a Plots.jl figure of the current state of `env`: the cart,
the pole, an arrow derived from `env.action`, and a pink circle near the
top-right corner when `env.done` is `true`.
"""
function plot(env::CartPoleEnv)
    # Only the first and third state entries (position and angle) are drawn.
    cartx, _, angle, _ = env.state
    action, done = env.action, env.done
    polelen = 2 * env.params.halflength
    xlimit = env.params.xthreshold

    # Empty frame spanning the allowed cart positions; later `plot!` calls
    # add to it as the current plot.
    fig = plot(
        xlims=(-xlimit, xlimit),
        ylims=(-0.1, polelen + 0.1),
        legend=false,
        border=:none,
    )

    # Cart: a unit-wide rectangle hanging just below y = 0.
    cartxs = [cartx - 0.5, cartx - 0.5, cartx + 0.5, cartx + 0.5]
    cartys = [-0.05, 0, 0, -0.05]
    plot!(cartxs, cartys; seriestype=:shape)

    # Pole: a line from the cart to the pole tip.
    polexs = [cartx, cartx + polelen * sin(angle)]
    poleys = [0, polelen * cos(angle)]
    plot!(polexs, poleys; linewidth=3)

    # Action arrow: from x + 0.5 to x + 0.7 when the action equals 1,
    # otherwise from x - 0.5 to x - 0.7.
    isone = action == 1
    arrowxs = [cartx + isone - 0.5, cartx + 1.4 * isone - 0.7]
    arrowys = [-0.025, -0.025]
    plot!(arrowxs, arrowys; linewidth=3, arrow=true, color=2)

    # Episode finished: pink circle near the top-right corner.
    if done
        plot!(
            [xlimit - 0.2],
            [polelen];
            marker=:circle,
            markersize=20,
            markerstrokewidth=0.0,
            color=:pink,
        )
    end

    return fig
end


# adapted from https://github.com/JuliaML/Reinforce.jl/blob/master/src/envs/mountain_car.jl
"""
    height(xs)

Return the track height `sin(3 * xs) * 0.45 + 0.55` at the scalar position
`xs`. Broadcast (`height.(xs)`) to evaluate it over a collection.
"""
function height(xs)
    amplitude, offset = 0.45, 0.55
    return sin(3 * xs) * amplitude + offset
end
"""
    rotate(xs, ys, θ)

Rotate the point(s) `(xs, ys)` about the origin by the angle `θ` (radians,
counter-clockwise). Returns the tuple `(xs_rotated, ys_rotated)`; `xs` and
`ys` may be scalars or arrays.
"""
function rotate(xs, ys, θ)
    sinθ = sin(θ)
    cosθ = cos(θ)
    xrot = xs * cosθ - ys * sinθ
    yrot = ys * cosθ + xs * sinθ
    return xrot, yrot
end
"""
    translate(xs, ys, t)

Shift `xs` by `t[1]` and `ys` by `t[2]` (element-wise). Returns the tuple
`(xs_shifted, ys_shifted)`.
"""
function translate(xs, ys, t)
    offsets = (t[1], t[2])
    # Pair each coordinate set with its offset; mapping over tuples yields a tuple.
    return map((coords, δ) -> coords .+ δ, (xs, ys), offsets)
end

"""
    plot(env::MountainCarEnv)

Build and return a Plots.jl figure of the current state of `env`: the track
profile given by `height`, the car as a small rotated rectangle at position
`env.state[1]`, and a pink circle just right of `env.params.max_pos` when
`env.done` is `true`.
"""
function plot(env::MountainCarEnv)
    s = env.state
    d = env.done
    min_pos, max_pos = env.params.min_pos, env.params.max_pos

    # Empty frame with a small margin around the track; later `plot!` calls
    # add to it as the current plot.
    p = plot(
        xlims=(min_pos - 0.1, max_pos + 0.2),
        ylims=(-0.1, height(max_pos) + 0.2),
        legend=false,
        border=:none,
    )

    # plot the terrain, sampled at 100 points
    xs = LinRange(min_pos, max_pos, 100)
    ys = height.(xs)
    plot!(xs, ys)

    # plot the car: a rectangle built around the origin, lifted by a small
    # clearance, rotated by cos(3x) and moved onto the track at the car position
    x = s[1]
    θ = cos(3 * x)
    carwidth = 0.05
    carheight = carwidth / 2
    clearance = 0.2 * carheight
    xs = [-carwidth / 2, -carwidth / 2, carwidth / 2, carwidth / 2]
    ys = [0, carheight, carheight, 0]
    ys .+= clearance
    xs, ys = rotate(xs, ys, θ)
    xs, ys = translate(xs, ys, [x, height(x)])
    plot!(xs, ys; seriestype=:shape)

    # If done, plot a pink circle at y = 0 just right of max_pos (inside the
    # xlims margin). This method has no `xthreshold` or `l`; using those names
    # here raised an UndefVarError as soon as the episode finished.
    if d
        plot!(
            [max_pos + 0.1],
            [0];
            marker=:circle,
            markersize=20,
            markerstrokewidth=0.0,
            color=:pink,
        )
    end

    return p
end