Skip to content

Commit 4f26eec

Browse files
authored
Fix encoder transformer and remove absolute index encoder (#1257)
* fixed issue when using transformer and only cyclic encodings * fixed transformers issue with incosistent input TimeSeries lists * added tests for transformers * removed absolute encoder in favor of relative encoder * updated docs and IndexGenerator tests * fixed explainability tests with encoder updates
1 parent 5424a6d commit 4f26eec

18 files changed

+361
-298
lines changed

darts/models/forecasting/block_rnn_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
add_encoders={
240240
'cyclic': {'future': ['month']},
241241
'datetime_attribute': {'future': ['hour', 'dayofweek']},
242-
'position': {'past': ['absolute'], 'future': ['relative']},
242+
'position': {'past': ['relative'], 'future': ['relative']},
243243
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
244244
'transformer': Scaler()
245245
}

darts/models/forecasting/catboost_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
add_encoders={
6565
'cyclic': {'future': ['month']},
6666
'datetime_attribute': {'future': ['hour', 'dayofweek']},
67-
'position': {'past': ['absolute'], 'future': ['relative']},
67+
'position': {'past': ['relative'], 'future': ['relative']},
6868
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
6969
'transformer': Scaler()
7070
}

darts/models/forecasting/gradient_boosted_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
add_encoders={
6969
'cyclic': {'future': ['month']},
7070
'datetime_attribute': {'future': ['hour', 'dayofweek']},
71-
'position': {'past': ['absolute'], 'future': ['relative']},
71+
'position': {'past': ['relative'], 'future': ['relative']},
7272
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
7373
'transformer': Scaler()
7474
}

darts/models/forecasting/linear_regression_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
add_encoders={
6767
'cyclic': {'future': ['month']},
6868
'datetime_attribute': {'future': ['hour', 'dayofweek']},
69-
'position': {'past': ['absolute'], 'future': ['relative']},
69+
'position': {'past': ['relative'], 'future': ['relative']},
7070
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
7171
'transformer': Scaler()
7272
}

darts/models/forecasting/nbeats.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def __init__(
659659
add_encoders={
660660
'cyclic': {'future': ['month']},
661661
'datetime_attribute': {'future': ['hour', 'dayofweek']},
662-
'position': {'past': ['absolute'], 'future': ['relative']},
662+
'position': {'past': ['relative'], 'future': ['relative']},
663663
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
664664
'transformer': Scaler()
665665
}

darts/models/forecasting/nhits.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def __init__(
595595
add_encoders={
596596
'cyclic': {'future': ['month']},
597597
'datetime_attribute': {'future': ['hour', 'dayofweek']},
598-
'position': {'past': ['absolute'], 'future': ['relative']},
598+
'position': {'past': ['relative'], 'future': ['relative']},
599599
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
600600
'transformer': Scaler()
601601
}

darts/models/forecasting/random_forest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
add_encoders={
7272
'cyclic': {'future': ['month']},
7373
'datetime_attribute': {'future': ['hour', 'dayofweek']},
74-
'position': {'past': ['absolute'], 'future': ['relative']},
74+
'position': {'past': ['relative'], 'future': ['relative']},
7575
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
7676
'transformer': Scaler()
7777
}

darts/models/forecasting/regression_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787
add_encoders={
8888
'cyclic': {'future': ['month']},
8989
'datetime_attribute': {'future': ['hour', 'dayofweek']},
90-
'position': {'past': ['absolute'], 'future': ['relative']},
90+
'position': {'past': ['relative'], 'future': ['relative']},
9191
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
9292
'transformer': Scaler()
9393
}

darts/models/forecasting/rnn_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def __init__(
318318
add_encoders={
319319
'cyclic': {'future': ['month']},
320320
'datetime_attribute': {'future': ['hour', 'dayofweek']},
321-
'position': {'past': ['absolute'], 'future': ['relative']},
321+
'position': {'past': ['relative'], 'future': ['relative']},
322322
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
323323
'transformer': Scaler()
324324
}

darts/models/forecasting/tcn_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def __init__(
360360
add_encoders={
361361
'cyclic': {'future': ['month']},
362362
'datetime_attribute': {'future': ['hour', 'dayofweek']},
363-
'position': {'past': ['absolute'], 'future': ['relative']},
363+
'position': {'past': ['relative'], 'future': ['relative']},
364364
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
365365
'transformer': Scaler()
366366
}

darts/models/forecasting/tft_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ def __init__(
799799
add_encoders={
800800
'cyclic': {'future': ['month']},
801801
'datetime_attribute': {'future': ['hour', 'dayofweek']},
802-
'position': {'past': ['absolute'], 'future': ['relative']},
802+
'position': {'past': ['relative'], 'future': ['relative']},
803803
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
804804
'transformer': Scaler()
805805
}

darts/models/forecasting/torch_forecasting_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def __init__(
187187
add_encoders={
188188
'cyclic': {'future': ['month']},
189189
'datetime_attribute': {'future': ['hour', 'dayofweek']},
190-
'position': {'past': ['absolute'], 'future': ['relative']},
190+
'position': {'past': ['relative'], 'future': ['relative']},
191191
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
192192
'transformer': Scaler()
193193
}

darts/models/forecasting/transformer_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def __init__(
448448
add_encoders={
449449
'cyclic': {'future': ['month']},
450450
'datetime_attribute': {'future': ['hour', 'dayofweek']},
451-
'position': {'past': ['absolute'], 'future': ['relative']},
451+
'position': {'past': ['relative'], 'future': ['relative']},
452452
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
453453
'transformer': Scaler()
454454
}

darts/tests/explainability/test_shap_explainer.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ShapExplainerTestCase(DartsBaseTestClass):
2828
add_encoders = {
2929
"cyclic": {"past": ["month", "day"]},
3030
"datetime_attribute": {"future": ["hour", "dayofweek"]},
31-
"position": {"past": ["absolute"], "future": ["relative"]},
31+
"position": {"past": ["relative"], "future": ["relative"]},
3232
"custom": {"past": [lambda idx: (idx.year - 1950) / 50]},
3333
"transformer": Scaler(scaler),
3434
}
@@ -281,14 +281,14 @@ def test_explain(self):
281281
self.assertEqual(len(explanation), 537)
282282

283283
# list of foregrounds: encoders have to be corrected first.
284-
# results = shap_explain.explain(
285-
# foreground_series=[self.target_ts, self.target_ts[:100]],
286-
# foreground_past_covariates=[self.past_cov_ts, self.past_cov_ts[:40]],
287-
# foreground_future_covariates=[self.fut_cov_ts, self.fut_cov_ts[:40]],
288-
# )
289-
# ts_res = results.get_explanation(horizon=2, component="power")
290-
291-
# self.assertEqual(len(ts_res), 2)
284+
results = shap_explain.explain(
285+
foreground_series=[self.target_ts, self.target_ts[:100]],
286+
foreground_past_covariates=[self.past_cov_ts, self.past_cov_ts[:40]],
287+
foreground_future_covariates=[self.fut_cov_ts, self.fut_cov_ts[:40]],
288+
)
289+
ts_res = results.get_explanation(horizon=2, component="power")
290+
291+
self.assertEqual(len(ts_res), 2)
292292
# explain with a new foreground, minimum required. We should obtain one
293293
# timeseries with only one time element
294294
results = shap_explain.explain(
@@ -322,30 +322,30 @@ def test_explain(self):
322322
"0_past_cov_lag-3",
323323
"1_past_cov_lag-3",
324324
"2_past_cov_lag-3",
325-
"month_sin_past_cov_lag-3",
326-
"month_cos_past_cov_lag-3",
327-
"day_sin_past_cov_lag-3",
328-
"day_cos_past_cov_lag-3",
329-
"absolute_idx_past_cov_lag-3",
330-
"custom_past_cov_lag-3",
325+
"darts_enc_pc_cyc_month_sin_past_cov_lag-3",
326+
"darts_enc_pc_cyc_month_cos_past_cov_lag-3",
327+
"darts_enc_pc_cyc_day_sin_past_cov_lag-3",
328+
"darts_enc_pc_cyc_day_cos_past_cov_lag-3",
329+
"darts_enc_pc_pos_relative_past_cov_lag-3",
330+
"darts_enc_pc_cus_custom_past_cov_lag-3",
331331
"0_past_cov_lag-2",
332332
"1_past_cov_lag-2",
333333
"2_past_cov_lag-2",
334-
"month_sin_past_cov_lag-2",
335-
"month_cos_past_cov_lag-2",
336-
"day_sin_past_cov_lag-2",
337-
"day_cos_past_cov_lag-2",
338-
"absolute_idx_past_cov_lag-2",
339-
"custom_past_cov_lag-2",
334+
"darts_enc_pc_cyc_month_sin_past_cov_lag-2",
335+
"darts_enc_pc_cyc_month_cos_past_cov_lag-2",
336+
"darts_enc_pc_cyc_day_sin_past_cov_lag-2",
337+
"darts_enc_pc_cyc_day_cos_past_cov_lag-2",
338+
"darts_enc_pc_pos_relative_past_cov_lag-2",
339+
"darts_enc_pc_cus_custom_past_cov_lag-2",
340340
"0_past_cov_lag-1",
341341
"1_past_cov_lag-1",
342342
"2_past_cov_lag-1",
343-
"month_sin_past_cov_lag-1",
344-
"month_cos_past_cov_lag-1",
345-
"day_sin_past_cov_lag-1",
346-
"day_cos_past_cov_lag-1",
347-
"absolute_idx_past_cov_lag-1",
348-
"custom_past_cov_lag-1",
343+
"darts_enc_pc_cyc_month_sin_past_cov_lag-1",
344+
"darts_enc_pc_cyc_month_cos_past_cov_lag-1",
345+
"darts_enc_pc_cyc_day_sin_past_cov_lag-1",
346+
"darts_enc_pc_cyc_day_cos_past_cov_lag-1",
347+
"darts_enc_pc_pos_relative_past_cov_lag-1",
348+
"darts_enc_pc_cus_custom_past_cov_lag-1",
349349
"0_fut_cov_lag0",
350350
"1_fut_cov_lag0",
351351
"hour_fut_cov_lag0",

darts/tests/models/forecasting/test_covariate_index_generators.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,23 @@ class CovariateIndexGeneratorTestCase(DartsBaseTestClass):
7474
def helper_test_index_types(self, ig: CovariateIndexGenerator):
7575
"""test the index type of generated index"""
7676
# pd.DatetimeIndex
77-
idx = ig.generate_train_series(self.target_time, self.cov_time_train)
77+
idx, _ = ig.generate_train_series(self.target_time, self.cov_time_train)
7878
self.assertTrue(isinstance(idx, pd.DatetimeIndex))
79-
idx = ig.generate_inference_series(
79+
idx, _ = ig.generate_inference_series(
8080
self.n_short, self.target_time, self.cov_time_inf_short
8181
)
8282
self.assertTrue(isinstance(idx, pd.DatetimeIndex))
83-
idx = ig.generate_train_series(self.target_time, None)
83+
idx, _ = ig.generate_train_series(self.target_time, None)
8484
self.assertTrue(isinstance(idx, pd.DatetimeIndex))
8585

8686
# pd.RangeIndex
87-
idx = ig.generate_train_series(self.target_int, self.cov_int_train)
87+
idx, _ = ig.generate_train_series(self.target_int, self.cov_int_train)
8888
self.assertTrue(isinstance(idx, pd.RangeIndex))
89-
idx = ig.generate_inference_series(
89+
idx, _ = ig.generate_inference_series(
9090
self.n_short, self.target_int, self.cov_int_inf_short
9191
)
9292
self.assertTrue(isinstance(idx, pd.RangeIndex))
93-
idx = ig.generate_train_series(self.target_int, None)
93+
idx, _ = ig.generate_train_series(self.target_int, None)
9494
self.assertTrue(isinstance(idx, pd.RangeIndex))
9595

9696
def helper_test_index_generator_train(self, ig: CovariateIndexGenerator):
@@ -100,24 +100,24 @@ def helper_test_index_generator_train(self, ig: CovariateIndexGenerator):
100100
"""
101101
# pd.DatetimeIndex
102102
# generated index must be equal to input covariate index
103-
idx = ig.generate_train_series(self.target_time, self.cov_time_train)
103+
idx, _ = ig.generate_train_series(self.target_time, self.cov_time_train)
104104
self.assertTrue(idx.equals(self.cov_time_train.time_index))
105105
# generated index must be equal to input covariate index
106-
idx = ig.generate_train_series(self.target_time, self.cov_time_train_short)
106+
idx, _ = ig.generate_train_series(self.target_time, self.cov_time_train_short)
107107
self.assertTrue(idx.equals(self.cov_time_train_short.time_index))
108108
# generated index must be equal to input target index when no covariates are defined
109-
idx = ig.generate_train_series(self.target_time, None)
109+
idx, _ = ig.generate_train_series(self.target_time, None)
110110
self.assertTrue(idx.equals(self.cov_time_train.time_index))
111111

112112
# integer index
113113
# generated index must be equal to input covariate index
114-
idx = ig.generate_train_series(self.target_int, self.cov_int_train)
114+
idx, _ = ig.generate_train_series(self.target_int, self.cov_int_train)
115115
self.assertTrue(idx.equals(self.cov_int_train.time_index))
116116
# generated index must be equal to input covariate index
117-
idx = ig.generate_train_series(self.target_time, self.cov_int_train_short)
117+
idx, _ = ig.generate_train_series(self.target_time, self.cov_int_train_short)
118118
self.assertTrue(idx.equals(self.cov_int_train_short.time_index))
119119
# generated index must be equal to input target index when no covariates are defined
120-
idx = ig.generate_train_series(self.target_int, None)
120+
idx, _ = ig.generate_train_series(self.target_int, None)
121121
self.assertTrue(idx.equals(self.cov_int_train.time_index))
122122

123123
def helper_test_index_generator_inference(self, ig, is_past=False):
@@ -134,7 +134,7 @@ def helper_test_index_generator_inference(self, ig, is_past=False):
134134
"""
135135

136136
# check generated inference index without passing covariates when n <= output_chunk_length
137-
idx = ig.generate_inference_series(self.n_short, self.target_time, None)
137+
idx, _ = ig.generate_inference_series(self.n_short, self.target_time, None)
138138
if is_past:
139139
n_out = self.input_chunk_length
140140
last_idx = self.target_time.end_time()
@@ -146,7 +146,7 @@ def helper_test_index_generator_inference(self, ig, is_past=False):
146146
self.assertTrue(idx[-1] == last_idx)
147147

148148
# check generated inference index without passing covariates when n > output_chunk_length
149-
idx = ig.generate_inference_series(self.n_long, self.target_time, None)
149+
idx, _ = ig.generate_inference_series(self.n_long, self.target_time, None)
150150
if is_past:
151151
n_out = self.input_chunk_length + self.n_long - self.output_chunk_length
152152
last_idx = (
@@ -160,19 +160,19 @@ def helper_test_index_generator_inference(self, ig, is_past=False):
160160
self.assertTrue(len(idx) == n_out)
161161
self.assertTrue(idx[-1] == last_idx)
162162

163-
idx = ig.generate_inference_series(
163+
idx, _ = ig.generate_inference_series(
164164
self.n_short, self.target_time, self.cov_time_inf_short
165165
)
166166
self.assertTrue(idx.equals(self.cov_time_inf_short.time_index))
167-
idx = ig.generate_inference_series(
167+
idx, _ = ig.generate_inference_series(
168168
self.n_long, self.target_time, self.cov_time_inf_long
169169
)
170170
self.assertTrue(idx.equals(self.cov_time_inf_long.time_index))
171-
idx = ig.generate_inference_series(
171+
idx, _ = ig.generate_inference_series(
172172
self.n_short, self.target_int, self.cov_int_inf_short
173173
)
174174
self.assertTrue(idx.equals(self.cov_int_inf_short.time_index))
175-
idx = ig.generate_inference_series(
175+
idx, _ = ig.generate_inference_series(
176176
self.n_long, self.target_int, self.cov_int_inf_long
177177
)
178178
self.assertTrue(idx.equals(self.cov_int_inf_long.time_index))

0 commit comments

Comments
 (0)