Skip to content

Commit 16f3a9f

Browse files
dennisbaderhrzn
andauthored
fix historical forecasts retraining of TFMs (#1465)
* fix historical forecasts retraining of TFMs * adapt historical_forecasts to not change underlying model object * fix failing tests * remove path update in quickstart notebook Co-authored-by: Julien Herzen <[email protected]>
1 parent ce45c15 commit 16f3a9f

File tree

4 files changed

+273
-168
lines changed

4 files changed

+273
-168
lines changed

darts/models/forecasting/forecasting_model.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -691,17 +691,16 @@ def historical_forecasts(
691691
retrain_func = _retrain_wrapper(
692692
lambda counter: counter % int(retrain) == 0 if retrain else False
693693
)
694-
695694
elif isinstance(retrain, Callable):
696695
retrain_func = _retrain_wrapper(retrain)
697-
698696
else:
699697
raise_log(
700698
ValueError(
701699
"`retrain` argument must be either `bool`, positive `int` or `Callable` (returning `bool`)"
702700
),
703701
logger,
704702
)
703+
705704
retrain_func_signature = tuple(
706705
inspect.signature(retrain_func).parameters.keys()
707706
)
@@ -728,7 +727,6 @@ def historical_forecasts(
728727

729728
forecasts_list = []
730729
for idx, series_ in enumerate(outer_iterator):
731-
732730
past_covariates_ = past_covariates[idx] if past_covariates else None
733731
future_covariates_ = future_covariates[idx] if future_covariates else None
734732

@@ -765,15 +763,12 @@ def historical_forecasts(
765763

766764
# prepare the start parameter -> pd.Timestamp
767765
if start is not None:
768-
769766
historical_forecasts_time_index = drop_before_index(
770767
historical_forecasts_time_index,
771768
series_.get_timestamp_at_point(start),
772769
)
773-
774770
else:
775771
if (retrain is not False) or (not self._fit_called):
776-
777772
if train_length:
778773
historical_forecasts_time_index = drop_before_index(
779774
historical_forecasts_time_index,
@@ -804,9 +799,9 @@ def historical_forecasts(
804799
(not self._fit_called)
805800
and (retrain is False)
806801
and (not train_length),
807-
" The model has not been fitted yet, and `start` and train_length are not specified. "
808-
" The model is not retraining during the historical forecasts. Hence the "
809-
"the first and only training would be done on 2 samples.",
802+
"The model has not been fitted yet, and `start` and train_length are not specified. "
803+
"The model is not retraining during the historical forecasts. Hence the "
804+
"first and only training would be done on 2 samples.",
810805
logger,
811806
)
812807

@@ -837,7 +832,6 @@ def historical_forecasts(
837832

838833
# iterate and forecast
839834
for _counter, pred_time in enumerate(iterator):
840-
841835
# build the training series
842836
if min_timestamp > series_.time_index[0]:
843837
train_series = series_.drop_before(
@@ -866,13 +860,17 @@ def historical_forecasts(
866860
if future_covariates_
867861
else None,
868862
):
869-
self._fit_wrapper(
863+
# avoid fitting the same model multiple times
864+
model = self.untrained_model()
865+
model._fit_wrapper(
870866
series=train_series,
871867
past_covariates=past_covariates_,
872868
future_covariates=future_covariates_,
873869
)
870+
else:
871+
model = self
874872

875-
forecast = self._predict_wrapper(
873+
forecast = model._predict_wrapper(
876874
n=forecast_horizon,
877875
series=train_series,
878876
past_covariates=past_covariates_,
@@ -901,7 +899,6 @@ def historical_forecasts(
901899
hierarchy=series_.hierarchy,
902900
)
903901
)
904-
905902
else:
906903
forecasts_list.append(forecasts)
907904

@@ -1526,7 +1523,7 @@ def _extract_model_creation_params(self):
15261523
return model_params
15271524

15281525
def untrained_model(self):
1529-
return self.__class__(**self.model_params)
1526+
return self.__class__(**copy.deepcopy(self.model_params))
15301527

15311528
@property
15321529
def model_params(self) -> dict:

darts/tests/models/forecasting/test_backtesting.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
import pandas as pd
7+
import pytest
78

89
from darts import TimeSeries
910
from darts.logging import get_logger
@@ -219,12 +220,28 @@ def test_backtest_forecasting(self):
219220
self.assertEqual(pred.end_time(), linear_series.end_time())
220221

221222
# multivariate model + multivariate series
222-
with self.assertRaises(ValueError):
223+
# historical forecasts doesn't overwrite model object -> we can use different input dimensions
224+
tcn_model.backtest(
225+
linear_series_multi,
226+
start=pd.Timestamp("20000125"),
227+
forecast_horizon=3,
228+
verbose=False,
229+
retrain=False,
230+
)
231+
232+
# univariate model
233+
tcn_model = TCNModel(
234+
input_chunk_length=12, output_chunk_length=1, batch_size=1, n_epochs=1
235+
)
236+
tcn_model.fit(linear_series, verbose=False)
237+
# univariate fitted model + multivariate series
238+
with pytest.raises(ValueError):
223239
tcn_model.backtest(
224240
linear_series_multi,
225241
start=pd.Timestamp("20000125"),
226242
forecast_horizon=3,
227243
verbose=False,
244+
retrain=False,
228245
)
229246

230247
tcn_model = TCNModel(

darts/tests/models/forecasting/test_ensemble_models.py

+7
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ def test_untrained_models(self):
5050
with self.assertRaises(ValueError):
5151
NaiveEnsembleModel([model])
5252

53+
# an untrained ensemble model should also give untrained underlying models
54+
model_ens = NaiveEnsembleModel([NaiveDrift()])
55+
model_ens.fit(self.series1)
56+
assert model_ens.models[0]._fit_called
57+
new_model = model_ens.untrained_model()
58+
assert not new_model.models[0]._fit_called
59+
5360
def test_input_models_local_models(self):
5461
with self.assertRaises(ValueError):
5562
NaiveEnsembleModel([])

0 commit comments

Comments
 (0)