Skip to content

Commit 865a6a5

Browse files
JanFidoralexcolpitts96
authored andcommitted
add already existing forecasts param (unit8co#1597)
1 parent 5dd11bb commit 865a6a5

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

darts/models/forecasting/forecasting_model.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,7 @@ def historical_forecasts(
922922
def backtest(
923923
self,
924924
series: Union[TimeSeries, Sequence[TimeSeries]],
925+
historical_forecasts: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
925926
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
926927
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
927928
num_samples: int = 1,
@@ -942,12 +943,12 @@ def backtest(
942943
"""Compute error values that the model would have produced when
943944
used on (potentially multiple) `series`.
944945
945-
It repeatedly builds a training set from the beginning of `series`. It trains the
946-
current model on the training set, emits a forecast of length equal to forecast_horizon, and then moves
947-
the end of the
948-
training set forward by `stride` time steps. A metric (given by the `metric` function) is then evaluated
949-
on the forecast and the actual values. Finally, the method returns a `reduction` (the mean by default)
950-
of all these metric scores.
946+
If `historical_forecasts` are provided, the metric (given by the `metric` function) is evaluated directly on
947+
the forecast and the actual values. Otherwise, it repeatedly builds a training set from the beginning of
948+
`series`. It trains the current model on the training set, emits a forecast of length equal to
949+
`forecast_horizon`, and then moves the end of the training set forward by `stride` time steps. The metric is
950+
then evaluated on the forecast and the actual values. Finally, the method returns a `reduction` (the mean by
951+
default) of all these metric scores.
951952
952953
By default, this method uses each historical forecast (whole) to compute error scores.
953954
If `last_points_only` is set to True, it will use only the last point of each historical
@@ -964,6 +965,11 @@ def backtest(
964965
----------
965966
series
966967
The (or a sequence of) target time series to use to successively train and evaluate the historical forecasts
968+
historical_forecasts
969+
Optionally, the (or a sequence of) historical forecasts time series to be evaluated. Corresponds to
970+
the output of :meth:`historical_forecasts() <ForecastingModel.historical_forecasts>`. If provided, will
971+
skip historical forecasting and ignore parameters `num_samples`, `train_length`, `start`,
972+
`forecast_horizon`, `stride`, `retrain`, `overlap_end`, and `last_points_only`.
967973
past_covariates
968974
Optionally, one (or a sequence of) past-observed covariate series.
969975
This applies only if the model supports past covariates.
@@ -1035,21 +1041,23 @@ def backtest(
10351041
The (sequence of) error score on a series, or list of list containing error scores for each
10361042
provided series and each sample.
10371043
"""
1038-
1039-
forecasts = self.historical_forecasts(
1040-
series=series,
1041-
past_covariates=past_covariates,
1042-
future_covariates=future_covariates,
1043-
num_samples=num_samples,
1044-
train_length=train_length,
1045-
start=start,
1046-
forecast_horizon=forecast_horizon,
1047-
stride=stride,
1048-
retrain=retrain,
1049-
overlap_end=overlap_end,
1050-
last_points_only=last_points_only,
1051-
verbose=verbose,
1052-
)
1044+
if historical_forecasts is None:
1045+
forecasts = self.historical_forecasts(
1046+
series=series,
1047+
past_covariates=past_covariates,
1048+
future_covariates=future_covariates,
1049+
num_samples=num_samples,
1050+
train_length=train_length,
1051+
start=start,
1052+
forecast_horizon=forecast_horizon,
1053+
stride=stride,
1054+
retrain=retrain,
1055+
overlap_end=overlap_end,
1056+
last_points_only=last_points_only,
1057+
verbose=verbose,
1058+
)
1059+
else:
1060+
forecasts = historical_forecasts
10531061

10541062
series = series2seq(series)
10551063
if len(series) == 1:

darts/tests/models/forecasting/test_backtesting.py

+16
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,22 @@ def test_backtest_forecasting(self):
116116
)
117117
self.assertEqual(score, 1.0)
118118

119+
# univariate model + univariate series + historical_forecasts precalculated
120+
forecasts = NaiveDrift().historical_forecasts(
121+
linear_series,
122+
start=pd.Timestamp("20000201"),
123+
forecast_horizon=3,
124+
last_points_only=False,
125+
)
126+
precalculated_forecasts_score = NaiveDrift().backtest(
127+
linear_series,
128+
historical_forecasts=forecasts,
129+
start=pd.Timestamp("20000201"),
130+
forecast_horizon=3,
131+
metric=r2_score,
132+
)
133+
self.assertEqual(score, precalculated_forecasts_score)
134+
119135
# very large train length should not affect the backtest
120136
score = NaiveDrift().backtest(
121137
linear_series,

0 commit comments

Comments
 (0)