
Commit 6060771

add stock trading env (#428)

* add stock trading env
* test DDPG & PPO
* update words
* update version

1 parent: d5a397a

File tree: 14 files changed (+228 -13 lines)

.cspell/cspell.json (+2 -1)

@@ -98,7 +98,8 @@
     "boxoban",
     "DATADEPS",
     "umaze",
-    "pybullet"
+    "pybullet",
+    "turbulences"
   ],
   "ignoreWords": [],
   "minWordLength": 5,

NEWS.md (+17)

@@ -27,12 +27,21 @@

 ### ReinforcementLearningBase.jl

+#### v0.9.6
+
+- Implement `Base.:(==)` for `Space`. [#428](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/428)
+
 #### v0.9.5

 - Add default `Base.:(==)` and `Base.hash` method for `AbstractEnv`. [#348](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/348)

 ### ReinforcementLearningCore.jl

+#### v0.8.3
+
+- Add two extra optional keyword arguments (`min_σ` and `max_σ`) in
+  `GaussianNetwork` to clip the output of `logσ`. [#428](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/428)
+
 #### v0.8.2

 - Add GaussianNetwork and DuelingNetwork into ReinforcementLearningCore.jl as general components. [#370](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/370)

@@ -60,6 +69,14 @@

 ### ReinforcementLearningEnvironments.jl

+#### v0.6.3
+
+- Add `StockTradingEnv` from the paper [Deep Reinforcement Learning for
+  Automated Stock Trading: An Ensemble
+  Strategy](https://github.com/AI4Finance-LLC/Deep-Reinforcement-Learning-for-Automated-Stock-Trading-Ensemble-Strategy-ICAIF-2020).
+  This environment is a good testbed for multi-continuous action space
+  algorithms. [#428](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/428)
+
 #### v0.6.2

 - Add `SequentialEnv` environment wrapper to turn a simultaneous environment

docs/Manifest.toml (+1 -1)

@@ -975,7 +975,7 @@ uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 version = "0.8.2"

 [[ReinforcementLearningEnvironments]]
-deps = ["IntervalSets", "MacroTools", "Markdown", "Random", "ReinforcementLearningBase", "Requires", "StatsBase"]
+deps = ["DelimitedFiles", "IntervalSets", "LinearAlgebra", "MacroTools", "Markdown", "Pkg", "Random", "ReinforcementLearningBase", "Requires", "StatsBase"]
 path = "../src/ReinforcementLearningEnvironments"
 uuid = "25e41dd2-4622-11e9-1641-f1adca772921"
 version = "0.6.2"

src/ReinforcementLearningBase/Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "ReinforcementLearningBase"
 uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 authors = ["Johanni Brea <[email protected]>", "Jun Tian <[email protected]>"]
-version = "0.9.6"
+version = "0.9.7"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ReinforcementLearningBase/src/base.jl (+1)

@@ -26,6 +26,7 @@ struct Space{T}
     s::T
 end

+Base.:(==)(x::Space, y::Space) = x.s == y.s
 Base.similar(s::Space, args...) = Space(similar(s.s, args...))
 Base.getindex(s::Space, args...) = getindex(s.s, args...)
 Base.setindex!(s::Space, args...) = setindex!(s.s, args...)
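Why this matters, as a minimal sketch (the `Space` values below are illustrative): `Space` previously fell back to the default `==`, which compares the wrapped field by identity, so two spaces with equal contents still compared unequal.

```julia
using ReinforcementLearningBase
using IntervalSets

a = Space([-1f0..1f0, -1f0..1f0])
b = Space([-1f0..1f0, -1f0..1f0])

# Without the new method, `a == b` compared the two wrapped vectors by
# identity and returned `false` even though their contents are equal.
# With `Base.:(==)(x::Space, y::Space) = x.s == y.s` the comparison
# delegates to the contents and returns `true`.
a == b
```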

src/ReinforcementLearningCore/Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "ReinforcementLearningCore"
 uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 authors = ["Jun Tian <[email protected]>"]
-version = "0.8.2"
+version = "0.8.3"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl (+8 -5)

@@ -68,16 +68,18 @@ end
 #####

 """
-    GaussianNetwork(;pre=identity, μ, logσ)
+    GaussianNetwork(;pre=identity, μ, logσ, min_σ=0f0, max_σ=Inf32)

-Returns `μ` and `logσ` when called.
-Create a distribution to sample from
-using `Normal.(μ, exp.(logσ))`.
+Returns `μ` and `logσ` when called. Create a distribution to sample from using
+`Normal.(μ, exp.(logσ))`. `min_σ` and `max_σ` are used to clip the output from
+`logσ`.
 """
 Base.@kwdef struct GaussianNetwork{P,U,S}
     pre::P = identity
     μ::U
     logσ::S
+    min_σ::Float32 = 0f0
+    max_σ::Float32 = Inf32
 end

 Flux.@functor GaussianNetwork

@@ -91,7 +93,8 @@ This function is compatible with a multidimensional action space. When outputtin
 """
 function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
     x = model.pre(state)
-    μ, logσ = model.μ(x), model.logσ(x)
+    μ, raw_logσ = model.μ(x), model.logσ(x)
+    logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ))
     if is_sampling
         π_dist = Normal.(μ, exp.(logσ))
         z = rand.(rng, π_dist)
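A minimal sketch of the new clipping behavior (the layer sizes and σ bounds below are illustrative, not part of the commit):

```julia
using Flux, Random
using ReinforcementLearningCore

model = GaussianNetwork(
    pre = Dense(4, 16, relu),  # hypothetical sizes, for illustration only
    μ = Dense(16, 2),
    logσ = Dense(16, 2),
    min_σ = 1f-2,  # keep σ ≥ 0.01
    max_σ = 1f0,   # keep σ ≤ 1.0
)

# With the default kwargs the call returns `μ` and the clamped `logσ`.
μ, logσ = model(Random.GLOBAL_RNG, rand(Float32, 4, 32))

# logσ is clamped to [log(min_σ), log(max_σ)], so σ = exp.(logσ)
# stays inside [0.01, 1.0] (up to floating-point rounding):
extrema(exp.(logσ))
```

With the defaults (`min_σ=0f0`, `max_σ=Inf32`) the clamp bounds are `(-Inf, Inf)`, so existing users see no behavior change.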
src/ReinforcementLearningEnvironments/Artifacts.toml (new file, +6)

@@ -0,0 +1,6 @@
[stock_trading_data]
git-tree-sha1 = "c2ef05aa70df44749bd43b2ab9a558ea6829b32b"

[[stock_trading_data.download]]
url = "https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/releases/download/v0.9.0/stock_trading_data.tar.gz"
sha256 = "2abc589a9dfb5b2134ee531152bd361b08629938ea3bf53fe56270517d732c89"
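For context, a brief sketch of how an artifact entry like this is consumed (mirroring `load_default_stock_data` in the new environment file below; it assumes the active project carries this Artifacts.toml):

```julia
using Pkg.Artifacts
using DelimitedFiles

# On first use, `artifact"stock_trading_data"` downloads the tarball from the
# URL above, verifies it against the recorded hashes, and returns the local
# directory it was unpacked into.
dir = artifact"stock_trading_data"
prices, header = readdlm(joinpath(dir, "prices.csv"), ',', header=true)
```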

src/ReinforcementLearningEnvironments/Manifest.toml (+1 -1)

@@ -160,7 +160,7 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 deps = ["AbstractTrees", "CommonRLInterface", "Markdown", "Random", "Test"]
 path = "../ReinforcementLearningBase"
 uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
-version = "0.9.5"
+version = "0.9.6"

 [[Requires]]
 deps = ["UUIDs"]

src/ReinforcementLearningEnvironments/Project.toml (+5 -2)

@@ -1,12 +1,15 @@
 name = "ReinforcementLearningEnvironments"
 uuid = "25e41dd2-4622-11e9-1641-f1adca772921"
 authors = ["Jun Tian <[email protected]>"]
-version = "0.6.2"
+version = "0.6.3"

 [deps]
+DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"

@@ -30,4 +33,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["ArcadeLearningEnvironment", "OpenSpiel", "OrdinaryDiffEq", "PyCall", "StableRNGs", "Statistics", "Test"]
+test = ["ArcadeLearningEnvironment", "OpenSpiel", "OrdinaryDiffEq", "PyCall", "StableRNGs", "Statistics", "Test"]
src/ReinforcementLearningEnvironments/src/environments/environments.jl (+1 -1)

@@ -1,4 +1,4 @@
+include("wrappers/wrappers.jl")
 include("examples/examples.jl")
 include("non_interactive/non_interactive.jl")
-include("wrappers/wrappers.jl")
 include("3rd_party/structs.jl")
src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl (new file, +175)

@@ -0,0 +1,175 @@
export StockTradingEnv, StockTradingEnvWithTurbulence

using Pkg.Artifacts
using DelimitedFiles
using LinearAlgebra: dot
using IntervalSets

function load_default_stock_data(s)
    if s == "prices.csv" || s == "features.csv"
        data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header=true)
        collect(data')
    elseif s == "turbulence.csv"
        readdlm(joinpath(artifact"stock_trading_data", "turbulence.csv")) |> vec
    else
        @error "unknown dataset $s"
    end
end

mutable struct StockTradingEnv{F<:AbstractMatrix{Float64},P<:AbstractMatrix{Float64}} <: AbstractEnv
    features::F
    prices::P
    HMAX_NORMALIZE::Float32
    TRANSACTION_FEE_PERCENT::Float32
    REWARD_SCALING::Float32
    initial_account_balance::Float32
    state::Vector{Float32}
    total_cost::Float32
    day::Int
    first_day::Int
    last_day::Int
    daily_reward::Float32
end

_n_stocks(env::StockTradingEnv) = size(env.prices, 1)
_prices(env::StockTradingEnv) = @view(env.state[2:1+_n_stocks(env)])
_holds(env::StockTradingEnv) = @view(env.state[2+_n_stocks(env):_n_stocks(env)*2+1])
_features(env::StockTradingEnv) = @view(env.state[_n_stocks(env)*2+2:end])
_balance(env::StockTradingEnv) = @view env.state[1]
_total_asset(env::StockTradingEnv) = env.state[1] + dot(_prices(env), _holds(env))

"""
    StockTradingEnv(;kw...)

This environment is originally provided in [Deep Reinforcement Learning for Automated Stock Trading: An Ensemble Strategy](https://github.com/AI4Finance-LLC/Deep-Reinforcement-Learning-for-Automated-Stock-Trading-Ensemble-Strategy-ICAIF-2020)

# Keyword Arguments

- `initial_account_balance=1_000_000`.
"""
function StockTradingEnv(;
    initial_account_balance=1_000_000f0,
    features=nothing,
    prices=nothing,
    first_day=nothing,
    last_day=nothing,
    HMAX_NORMALIZE=100f0,
    TRANSACTION_FEE_PERCENT=0.001f0,
    REWARD_SCALING=1f-4
)
    prices = isnothing(prices) ? load_default_stock_data("prices.csv") : prices
    features = isnothing(features) ? load_default_stock_data("features.csv") : features

    @assert size(prices, 2) == size(features, 2)

    first_day = isnothing(first_day) ? 1 : first_day
    last_day = isnothing(last_day) ? size(prices, 2) : last_day
    day = first_day

    # [balance, stock_prices..., stock_holds..., features...]
    state = zeros(Float32, 1 + size(prices, 1) * 2 + size(features, 1))

    env = StockTradingEnv(
        features,
        prices,
        HMAX_NORMALIZE,
        TRANSACTION_FEE_PERCENT,
        REWARD_SCALING,
        initial_account_balance,
        state,
        0f0,
        day,
        first_day,
        last_day,
        0f0
    )

    _balance(env)[] = initial_account_balance
    _prices(env) .= @view prices[:, day]
    _features(env) .= @view features[:, day]

    env
end

function (env::StockTradingEnv)(actions)
    init_asset = _total_asset(env)

    # sell first
    for (i, s) in enumerate(actions)
        if s < 0
            sell = min(-env.HMAX_NORMALIZE * s, _holds(env)[i])
            _holds(env)[i] -= sell
            gain = _prices(env)[i] * sell
            cost = gain * env.TRANSACTION_FEE_PERCENT
            _balance(env)[] += gain - cost
            env.total_cost += cost
        end
    end

    # then buy
    # better to shuffle?
    for (i, b) in enumerate(actions)
        if b > 0
            max_buy = div(_balance(env)[], _prices(env)[i])
            buy = min(b * env.HMAX_NORMALIZE, max_buy)
            _holds(env)[i] += buy
            deduction = buy * _prices(env)[i]
            cost = deduction * env.TRANSACTION_FEE_PERCENT
            _balance(env)[] -= deduction + cost
            env.total_cost += cost
        end
    end

    env.day += 1
    _prices(env) .= @view env.prices[:, env.day]
    _features(env) .= @view env.features[:, env.day]

    env.daily_reward = _total_asset(env) - init_asset
end

RLBase.reward(env::StockTradingEnv) = env.daily_reward * env.REWARD_SCALING
RLBase.is_terminated(env::StockTradingEnv) = env.day >= env.last_day
RLBase.state(env::StockTradingEnv) = env.state

function RLBase.reset!(env::StockTradingEnv)
    env.day = env.first_day
    _balance(env)[] = env.initial_account_balance
    _prices(env) .= @view env.prices[:, env.day]
    _features(env) .= @view env.features[:, env.day]
    env.total_cost = 0f0
    env.daily_reward = 0f0
end

RLBase.state_space(env::StockTradingEnv) = Space(fill(-Inf32..Inf32, length(state(env))))
RLBase.action_space(env::StockTradingEnv) = Space(fill(-1f0..1f0, length(_holds(env))))

RLBase.ChanceStyle(::StockTradingEnv) = DETERMINISTIC

# wrapper

struct StockTradingEnvWithTurbulence{E<:StockTradingEnv} <: AbstractEnvWrapper
    env::E
    turbulences::Vector{Float64}
    turbulence_threshold::Float64
end

function StockTradingEnvWithTurbulence(;
    turbulence_threshold=140.0,
    turbulences=nothing,
    kw...
)
    # Fall back to the bundled turbulence data only when none is provided.
    turbulences = isnothing(turbulences) ? load_default_stock_data("turbulence.csv") : turbulences

    StockTradingEnvWithTurbulence(
        StockTradingEnv(; kw...),
        turbulences,
        turbulence_threshold
    )
end

function (w::StockTradingEnvWithTurbulence)(actions)
    if w.turbulences[w.env.day] >= w.turbulence_threshold
        # In turbulent markets, stop buying and liquidate all holdings
        # (selling is capped at the current holds inside the base env).
        actions .= ifelse.(actions .< 0, -Inf32, 0)
    end
    w.env(actions)
end
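A short usage sketch (not part of the commit): drive the wrapped environment with uniformly random actions, assuming `rand(::Space)` samples each interval elementwise as RLBase provides.

```julia
using ReinforcementLearningBase
using ReinforcementLearningEnvironments

env = StockTradingEnvWithTurbulence()  # fetches the bundled artifact data on first use

reset!(env)
rewards = Float32[]
while !is_terminated(env)
    # One action in [-1, 1] per stock: positive entries buy, negative entries
    # sell, each scaled by HMAX_NORMALIZE shares inside the environment.
    env(rand(action_space(env)))
    push!(rewards, reward(env))
end
sum(rewards)  # scaled cumulative daily change in total assets
```

The observation is the flat `state` vector laid out as `[balance, prices..., holds..., features...]`, which is what makes the `_prices`/`_holds`/`_features` views above line up.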

src/ReinforcementLearningEnvironments/src/environments/examples/examples.jl (+1)

@@ -11,3 +11,4 @@ include("CartPoleEnv.jl")
 include("MountainCarEnv.jl")
 include("PendulumEnv.jl")
 include("BitFlippingEnv.jl")
+include("StockTradingEnv.jl")
src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl (new file, +8)

@@ -0,0 +1,8 @@
@testset "StockTradingEnv" begin

    env = StockTradingEnvWithTurbulence()

    RLBase.test_interfaces!(env)
    RLBase.test_runnable!(env)
end
