Skip to content

Commit f5895d9

Browse files
JanFidor, madtoinou, and dennisbader
authored and committed
Fix/ensemble historical forecasts (unit8co#1616)
* add correct extreme_lags override and test
* add required extreme_lags override
* delete logging print
* change lag priorities
* add a test + use switch to tuple
* fix extreme lags from other PR
* make RegressionEnsembleModel work
* small unit test fix

---------

Co-authored-by: madtoinou <[email protected]>
Co-authored-by: Dennis Bader <[email protected]>
1 parent 96f35e6 commit f5895d9

10 files changed

+266
-20
lines changed

darts/models/forecasting/baselines.py

-4
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,6 @@ def predict(self, n: int, num_samples: int = 1, verbose: bool = False):
7979
forecast = np.array([self.last_k_vals[i % self.K, :] for i in range(n)])
8080
return self._build_forecast_series(forecast)
8181

82-
@property
83-
def extreme_lags(self):
84-
return -self.K, 0, None, None, None, None
85-
8682

8783
class NaiveDrift(LocalForecastingModel):
8884
def __init__(self):

darts/models/forecasting/ensemble_model.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from abc import abstractmethod
66
from functools import reduce
7-
from typing import List, Optional, Sequence, Union
7+
from typing import List, Optional, Sequence, Tuple, Union
88

99
from darts.logging import get_logger, raise_if, raise_if_not
1010
from darts.models.forecasting.forecasting_model import (
@@ -201,5 +201,31 @@ def min_train_series_length(self) -> int:
201201
def min_train_samples(self) -> int:
202202
return max(model.min_train_samples for model in self.models)
203203

204+
@property
205+
def extreme_lags(
206+
self,
207+
) -> Tuple[
208+
Optional[int],
209+
Optional[int],
210+
Optional[int],
211+
Optional[int],
212+
Optional[int],
213+
Optional[int],
214+
]:
215+
def find_max_lag_or_none(lag_id, aggregator) -> Optional[int]:
216+
max_lag = None
217+
for model in self.models:
218+
curr_lag = model.extreme_lags[lag_id]
219+
if max_lag is None:
220+
max_lag = curr_lag
221+
elif curr_lag is not None:
222+
max_lag = aggregator(max_lag, curr_lag)
223+
return max_lag
224+
225+
lag_aggregators = (min, max, min, max, min, max)
226+
return tuple(
227+
find_max_lag_or_none(i, agg) for i, agg in enumerate(lag_aggregators)
228+
)
229+
204230
def _is_probabilistic(self) -> bool:
205231
return all([model._is_probabilistic() for model in self.models])

darts/models/forecasting/forecasting_model.py

+36-6
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def min_train_samples(self) -> int:
281281
return 1
282282

283283
@property
284+
@abstractmethod
284285
def extreme_lags(
285286
self,
286287
) -> Tuple[
@@ -335,8 +336,7 @@ def extreme_lags(
335336
>>> model.extreme_lags
336337
(-10, 6, None, None, 4, 6)
337338
"""
338-
339-
return -1, 0, None, None, None, None
339+
pass
340340

341341
@property
342342
def _training_sample_time_index_length(self) -> int:
@@ -1914,6 +1914,23 @@ def fit(self, series: TimeSeries) -> "LocalForecastingModel":
19141914
super().fit(series)
19151915
series._assert_deterministic()
19161916

1917+
@property
1918+
def extreme_lags(
1919+
self,
1920+
) -> Tuple[
1921+
Optional[int],
1922+
Optional[int],
1923+
Optional[int],
1924+
Optional[int],
1925+
Optional[int],
1926+
Optional[int],
1927+
]:
1928+
# TODO: LocalForecastingModels do not yet handle extreme lags properly, especially
1929+
# TransferableFutureCovariatesLocalForecastingModel (where there is a difference between fit and predict mode).
1930+
# In general, Local models train on the entire series (input=output), unlike Global models
1931+
# that use an input to predict an output.
1932+
return -self.min_train_series_length, -1, None, None, None, None
1933+
19171934

19181935
class GlobalForecastingModel(ForecastingModel, ABC):
19191936
"""The base class for "global" forecasting models, handling several time series and optional covariates.
@@ -2315,6 +2332,23 @@ def _supress_generate_predict_encoding(self) -> bool:
23152332
"""Controls whether encodings should be generated in :func:`FutureCovariatesLocalForecastingModel.predict()`"""
23162333
return False
23172334

2335+
@property
2336+
def extreme_lags(
2337+
self,
2338+
) -> Tuple[
2339+
Optional[int],
2340+
Optional[int],
2341+
Optional[int],
2342+
Optional[int],
2343+
Optional[int],
2344+
Optional[int],
2345+
]:
2346+
# TODO: LocalForecastingModels do not yet handle extreme lags properly, especially
2347+
# TransferableFutureCovariatesLocalForecastingModel (where there is a difference between fit and predict mode).
2348+
# In general, Local models train on the entire series (input=output), unlike Global models
2349+
# that use an input to predict an output.
2350+
return -self.min_train_series_length, -1, None, None, 0, 0
2351+
23182352

23192353
class TransferableFutureCovariatesLocalForecastingModel(
23202354
FutureCovariatesLocalForecastingModel, ABC
@@ -2492,7 +2526,3 @@ def _supports_non_retrainable_historical_forecasts(self) -> bool:
24922526
@property
24932527
def _supress_generate_predict_encoding(self) -> bool:
24942528
return True
2495-
2496-
@property
2497-
def extreme_lags(self):
2498-
return -1, 0, None, None, 0, 0

darts/models/forecasting/regression_ensemble_model.py

+14
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,17 @@ def ensemble(
170170
for serie, prediction in zip(series, predictions)
171171
]
172172
return seq2series(ensembled) if is_single_series else ensembled
173+
174+
@property
175+
def extreme_lags(
176+
self,
177+
) -> Tuple[
178+
Optional[int],
179+
Optional[int],
180+
Optional[int],
181+
Optional[int],
182+
Optional[int],
183+
Optional[int],
184+
]:
185+
extreme_lags_ = super().extreme_lags
186+
return extreme_lags_[0] - self.train_n_points, *extreme_lags_[1:]

darts/models/forecasting/regression_model.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,16 @@ def _model_encoder_settings(
274274
)
275275

276276
@property
277-
def extreme_lags(self):
277+
def extreme_lags(
278+
self,
279+
) -> Tuple[
280+
Optional[int],
281+
Optional[int],
282+
Optional[int],
283+
Optional[int],
284+
Optional[int],
285+
Optional[int],
286+
]:
278287
min_target_lag = self.lags.get("target")[0] if "target" in self.lags else None
279288
max_target_lag = self.output_chunk_length - 1
280289
min_past_cov_lag = self.lags.get("past")[0] if "past" in self.lags else None
@@ -285,7 +294,6 @@ def extreme_lags(self):
285294
max_future_cov_lag = (
286295
self.lags.get("future")[-1] if "future" in self.lags else None
287296
)
288-
289297
return (
290298
min_target_lag,
291299
max_target_lag,

darts/models/forecasting/theta.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import math
7-
from typing import List, Optional
7+
from typing import List, Optional, Tuple
88

99
import numpy as np
1010
import statsmodels.tsa.holtwinters as hw
@@ -179,6 +179,19 @@ def min_train_series_length(self) -> int:
179179
else:
180180
return 3
181181

182+
@property
183+
def extreme_lags(
184+
self,
185+
) -> Tuple[
186+
Optional[int],
187+
Optional[int],
188+
Optional[int],
189+
Optional[int],
190+
Optional[int],
191+
Optional[int],
192+
]:
193+
return -self.min_train_series_length, 0, None, None, None, None
194+
182195

183196
class FourTheta(LocalForecastingModel):
184197
def __init__(

darts/models/forecasting/torch_forecasting_model.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -2097,7 +2097,16 @@ def _model_encoder_settings(
20972097
)
20982098

20992099
@property
2100-
def extreme_lags(self):
2100+
def extreme_lags(
2101+
self,
2102+
) -> Tuple[
2103+
Optional[int],
2104+
Optional[int],
2105+
Optional[int],
2106+
Optional[int],
2107+
Optional[int],
2108+
Optional[int],
2109+
]:
21012110
return (
21022111
-self.input_chunk_length,
21032112
self.output_chunk_length - 1,
@@ -2186,7 +2195,16 @@ def _model_encoder_settings(
21862195
)
21872196

21882197
@property
2189-
def extreme_lags(self):
2198+
def extreme_lags(
2199+
self,
2200+
) -> Tuple[
2201+
Optional[int],
2202+
Optional[int],
2203+
Optional[int],
2204+
Optional[int],
2205+
Optional[int],
2206+
Optional[int],
2207+
]:
21902208
return (
21912209
-self.input_chunk_length,
21922210
self.output_chunk_length - 1,
@@ -2266,7 +2284,16 @@ def _model_encoder_settings(
22662284
)
22672285

22682286
@property
2269-
def extreme_lags(self):
2287+
def extreme_lags(
2288+
self,
2289+
) -> Tuple[
2290+
Optional[int],
2291+
Optional[int],
2292+
Optional[int],
2293+
Optional[int],
2294+
Optional[int],
2295+
Optional[int],
2296+
]:
22702297
return (
22712298
-self.input_chunk_length,
22722299
self.output_chunk_length - 1,
@@ -2343,7 +2370,16 @@ def _model_encoder_settings(
23432370
)
23442371

23452372
@property
2346-
def extreme_lags(self):
2373+
def extreme_lags(
2374+
self,
2375+
) -> Tuple[
2376+
Optional[int],
2377+
Optional[int],
2378+
Optional[int],
2379+
Optional[int],
2380+
Optional[int],
2381+
Optional[int],
2382+
]:
23472383
return (
23482384
-self.input_chunk_length,
23492385
self.output_chunk_length - 1,
@@ -2421,7 +2457,16 @@ def _model_encoder_settings(
24212457
)
24222458

24232459
@property
2424-
def extreme_lags(self):
2460+
def extreme_lags(
2461+
self,
2462+
) -> Tuple[
2463+
Optional[int],
2464+
Optional[int],
2465+
Optional[int],
2466+
Optional[int],
2467+
Optional[int],
2468+
Optional[int],
2469+
]:
24252470
return (
24262471
-self.input_chunk_length,
24272472
self.output_chunk_length - 1,

darts/tests/models/forecasting/test_ensemble_models.py

+30
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,30 @@ def test_untrained_models(self):
5757
new_model = model_ens.untrained_model()
5858
assert not new_model.models[0]._fit_called
5959

60+
def test_extreme_lag_inference(self):
61+
ensemble = NaiveEnsembleModel([NaiveDrift()])
62+
assert ensemble.extreme_lags == (
63+
-3,
64+
-1,
65+
None,
66+
None,
67+
None,
68+
None,
69+
) # test if default is okay
70+
71+
model1 = LinearRegressionModel(
72+
lags=3, lags_past_covariates=[-3, -5], lags_future_covariates=[7, 8]
73+
)
74+
model2 = LinearRegressionModel(
75+
lags=5, lags_past_covariates=6, lags_future_covariates=[6, 9]
76+
)
77+
78+
ensemble = NaiveEnsembleModel(
79+
[model1, model2]
80+
) # test if infers extreme lags is okay
81+
expected = (-5, 0, -6, -1, 6, 9)
82+
assert expected == ensemble.extreme_lags
83+
6084
def test_input_models_local_models(self):
6185
with self.assertRaises(ValueError):
6286
NaiveEnsembleModel([])
@@ -78,6 +102,12 @@ def test_call_predict_local_models(self):
78102
pred1 = naive_ensemble.predict(5)
79103
assert self.series1.components == pred1.components
80104

105+
def test_call_backtest_naive_ensemble_local_models(self):
106+
ensemble = NaiveEnsembleModel([NaiveSeasonal(5), Theta(2, 5)])
107+
ensemble.fit(self.series1)
108+
assert ensemble.extreme_lags == (-10, 0, None, None, None, None)
109+
ensemble.backtest(self.series1)
110+
81111
def test_predict_ensemble_local_models(self):
82112
naive = NaiveSeasonal(K=5)
83113
theta = Theta()

0 commit comments

Comments
 (0)