Skip to content

Commit efffb28

Browse files
yaxuan999you-n-g
andauthored
added KRNN and Sandwich models and their example results based on Alpha360 (#1414)
* Update README.md updated the result of KRNN and Sandwich models based on Alpha360 * Update README.md * Update README.md * Add files via upload * Update README.md * Update README.md * Update README.md * Add files via upload * Delete pytorch_krnn.py * Delete pytorch_sandwich.py * Add files via upload * Update pytorch_sandwich.py * Update pytorch_krnn.py * Update pytorch_sandwich.py * Update pytorch_krnn.py * Update README.md * Update README.md * Update requirements.txt * Update requirements.txt * Update README.md * Update README.md * Update pytorch_sandwich.py * Update link on index --------- Co-authored-by: Young <[email protected]>
1 parent 19a0eb7 commit efffb28

10 files changed

+1096
-0
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Recent released features
1212
| Feature | Status |
1313
| -- | ------ |
14+
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
1415
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
1516
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
1617
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
@@ -353,6 +354,8 @@ Here is a list of models built on `Qlib`.
353354
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
354355
- [IGMTF based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/IGMTF/)
355356
- [HIST based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/HIST/)
357+
- [KRNN based on pytorch](examples/benchmarks/KRNN/)
358+
- [Sandwich based on pytorch](examples/benchmarks/Sandwich/)
356359

357360
Your PR of new Quant models is highly welcomed.
358361

examples/benchmarks/KRNN/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# KRNN
2+
* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py](https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py)
3+
4+
5+
# Introductions about the settings/configs.
6+
* Torch_geometric is used in the original model in FOST, but we didn't use it.
7+
* make use your CUDA version matches the torch version to allow the usage of GPU, we use CUDA==10.2 and torch.__version__==1.12.1
8+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
numpy==1.23.4
2+
pandas==1.5.2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
qlib_init:
2+
provider_uri: "~/.qlib/qlib_data/cn_data"
3+
region: cn
4+
market: &market csi300
5+
benchmark: &benchmark SH000300
6+
data_handler_config: &data_handler_config
7+
start_time: 2008-01-01
8+
end_time: 2020-08-01
9+
fit_start_time: 2008-01-01
10+
fit_end_time: 2014-12-31
11+
instruments: *market
12+
infer_processors:
13+
- class: RobustZScoreNorm
14+
kwargs:
15+
fields_group: feature
16+
clip_outlier: true
17+
- class: Fillna
18+
kwargs:
19+
fields_group: feature
20+
learn_processors:
21+
- class: DropnaLabel
22+
- class: CSRankNorm
23+
kwargs:
24+
fields_group: label
25+
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
26+
port_analysis_config: &port_analysis_config
27+
strategy:
28+
class: TopkDropoutStrategy
29+
module_path: qlib.contrib.strategy
30+
kwargs:
31+
signal:
32+
- <MODEL>
33+
- <DATASET>
34+
topk: 50
35+
n_drop: 5
36+
backtest:
37+
start_time: 2017-01-01
38+
end_time: 2020-08-01
39+
account: 100000000
40+
benchmark: *benchmark
41+
exchange_kwargs:
42+
limit_threshold: 0.095
43+
deal_price: close
44+
open_cost: 0.0005
45+
close_cost: 0.0015
46+
min_cost: 5
47+
task:
48+
model:
49+
class: KRNN
50+
module_path: qlib.contrib.model.pytorch_krnn
51+
kwargs:
52+
fea_dim: 6
53+
cnn_dim: 8
54+
cnn_kernel_size: 3
55+
rnn_dim: 8
56+
rnn_dups: 2
57+
rnn_layers: 2
58+
n_epochs: 200
59+
lr: 0.001
60+
early_stop: 20
61+
batch_size: 2000
62+
metric: loss
63+
GPU: 0
64+
dataset:
65+
class: DatasetH
66+
module_path: qlib.data.dataset
67+
kwargs:
68+
handler:
69+
class: Alpha360
70+
module_path: qlib.contrib.data.handler
71+
kwargs: *data_handler_config
72+
segments:
73+
train: [2008-01-01, 2014-12-31]
74+
valid: [2015-01-01, 2016-12-31]
75+
test: [2017-01-01, 2020-08-01]
76+
record:
77+
- class: SignalRecord
78+
module_path: qlib.workflow.record_temp
79+
kwargs:
80+
model: <MODEL>
81+
dataset: <DATASET>
82+
- class: SigAnaRecord
83+
module_path: qlib.workflow.record_temp
84+
kwargs:
85+
ana_long_short: False
86+
ann_scaler: 252
87+
- class: PortAnaRecord
88+
module_path: qlib.workflow.record_temp
89+
kwargs:
90+
config: *port_analysis_config
91+

examples/benchmarks/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
6868
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
6969
| IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 |
7070
| HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 |
71+
| KRNN | Alpha360 | 0.0173±0.01 | 0.1210±0.06 | 0.0270±0.01 | 0.2018±0.04 | -0.0465±0.05 | -0.5415±0.62 | -0.2919±0.13 |
72+
| Sandwich | Alpha360 | 0.0258±0.00 | 0.1924±0.04 | 0.0337±0.00 | 0.2624±0.03 | 0.0005±0.03 | 0.0001±0.33 | -0.1752±0.05 |
7173

7274

7375
- The selected 20 features are based on the feature importance of a lightgbm-based model.
+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Sandwich
2+
* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py](https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py)
3+
4+
5+
# Introductions about the settings/configs.
6+
* Torch_geometric is used in the original model in FOST, but we didn't use it.
7+
make use your CUDA version matches the torch version to allow the usage of GPU, we use CUDA==10.2 and torch.version==1.12.1
8+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
numpy==1.23.4
2+
pandas==1.5.2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
qlib_init:
2+
provider_uri: "~/.qlib/qlib_data/cn_data"
3+
region: cn
4+
market: &market csi300
5+
benchmark: &benchmark SH000300
6+
data_handler_config: &data_handler_config
7+
start_time: 2008-01-01
8+
end_time: 2020-08-01
9+
fit_start_time: 2008-01-01
10+
fit_end_time: 2014-12-31
11+
instruments: *market
12+
infer_processors:
13+
- class: RobustZScoreNorm
14+
kwargs:
15+
fields_group: feature
16+
clip_outlier: true
17+
- class: Fillna
18+
kwargs:
19+
fields_group: feature
20+
learn_processors:
21+
- class: DropnaLabel
22+
- class: CSRankNorm
23+
kwargs:
24+
fields_group: label
25+
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
26+
port_analysis_config: &port_analysis_config
27+
strategy:
28+
class: TopkDropoutStrategy
29+
module_path: qlib.contrib.strategy
30+
kwargs:
31+
signal:
32+
- <MODEL>
33+
- <DATASET>
34+
topk: 50
35+
n_drop: 5
36+
backtest:
37+
start_time: 2017-01-01
38+
end_time: 2020-08-01
39+
account: 100000000
40+
benchmark: *benchmark
41+
exchange_kwargs:
42+
limit_threshold: 0.095
43+
deal_price: close
44+
open_cost: 0.0005
45+
close_cost: 0.0015
46+
min_cost: 5
47+
task:
48+
model:
49+
class: Sandwich
50+
module_path: qlib.contrib.model.pytorch_sandwich
51+
kwargs:
52+
fea_dim: 6
53+
cnn_dim_1: 16
54+
cnn_dim_2: 16
55+
cnn_kernel_size: 3
56+
rnn_dim_1: 8
57+
rnn_dim_2: 8
58+
rnn_dups: 2
59+
rnn_layers: 2
60+
n_epochs: 200
61+
lr: 0.001
62+
early_stop: 20
63+
batch_size: 2000
64+
metric: loss
65+
GPU: 0
66+
dataset:
67+
class: DatasetH
68+
module_path: qlib.data.dataset
69+
kwargs:
70+
handler:
71+
class: Alpha360
72+
module_path: qlib.contrib.data.handler
73+
kwargs: *data_handler_config
74+
segments:
75+
train: [2008-01-01, 2014-12-31]
76+
valid: [2015-01-01, 2016-12-31]
77+
test: [2017-01-01, 2020-08-01]
78+
record:
79+
- class: SignalRecord
80+
module_path: qlib.workflow.record_temp
81+
kwargs:
82+
model: <MODEL>
83+
dataset: <DATASET>
84+
- class: SigAnaRecord
85+
module_path: qlib.workflow.record_temp
86+
kwargs:
87+
ana_long_short: False
88+
ann_scaler: 252
89+
- class: PortAnaRecord
90+
module_path: qlib.workflow.record_temp
91+
kwargs:
92+
config: *port_analysis_config
93+

0 commit comments

Comments
 (0)