Skip to content

Commit 5ceac68

Browse files
piaz97brunneduhrzn
authored
Apply statsmodels-based ARIMA/VARIMA to new TS (unit8co#1036)
* Added new base class and adjusted ARIMA + tests * [ARIMA] Added docstrings and tests * Adapted VARIMA as well + tests * Keeping training state after forecasting new TS, refactoring * Updated docstrings * Fixed some formatting and added one last test * Restored deleted check * Fixed a logic issue with current training_series param * Cleaning * Update darts/models/forecasting/forecasting_model.py Co-authored-by: Dustin Brunner <[email protected]> * Update darts/models/forecasting/forecasting_model.py Co-authored-by: Dustin Brunner <[email protected]> * Added VARIMA prob forecasting support * Added missing build.gradle * Apply suggestions from code review (copy=False) Co-authored-by: Julien Herzen <[email protected]> * Replaced ignore_axes -> ignore_axis * Added backtest with retrain=False support * Small fixes - Removed a print statement - Moved an error message that was misleading - Change type of staVARMA to ndarray, otherwise the result will be a multiindexed pandas df instead of the expected ndarray * Added some missing values(copy=False) Co-authored-by: Dustin Brunner <[email protected]> Co-authored-by: Julien Herzen <[email protected]>
1 parent 7ad25b4 commit 5ceac68

File tree

4 files changed

+390
-28
lines changed

4 files changed

+390
-28
lines changed

darts/models/forecasting/arima.py

+40-8
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
from statsmodels.tsa.arima.model import ARIMA as staARIMA
1717

1818
from darts.logging import get_logger
19-
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
19+
from darts.models.forecasting.forecasting_model import (
20+
TransferableDualCovariatesForecastingModel,
21+
)
2022
from darts.timeseries import TimeSeries
2123

2224
logger = get_logger(__name__)
2325

2426

25-
class ARIMA(DualCovariatesForecastingModel):
27+
class ARIMA(TransferableDualCovariatesForecastingModel):
2628
def __init__(
2729
self,
2830
p: int = 12,
@@ -66,11 +68,14 @@ def __str__(self):
6668
return f"SARIMA{self.order}x{self.seasonal_order}"
6769

6870
def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
69-
7071
super()._fit(series, future_covariates)
72+
73+
# storing to restore the statsmodels model results object
74+
self.training_historic_future_covariates = future_covariates
75+
7176
m = staARIMA(
72-
self.training_series.values(),
73-
exog=future_covariates.values() if future_covariates else None,
77+
series.values(copy=False),
78+
exog=future_covariates.values(copy=False) if future_covariates else None,
7479
order=self.order,
7580
seasonal_order=self.seasonal_order,
7681
trend=self.trend,
@@ -82,6 +87,8 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
8287
def _predict(
8388
self,
8489
n: int,
90+
series: Optional[TimeSeries] = None,
91+
historic_future_covariates: Optional[TimeSeries] = None,
8592
future_covariates: Optional[TimeSeries] = None,
8693
num_samples: int = 1,
8794
) -> TimeSeries:
@@ -93,18 +100,43 @@ def _predict(
93100
"your model."
94101
)
95102

96-
super()._predict(n, future_covariates, num_samples)
103+
super()._predict(
104+
n, series, historic_future_covariates, future_covariates, num_samples
105+
)
106+
107+
# updating statsmodels results object state with the new ts and covariates
108+
if series is not None:
109+
self.model = self.model.apply(
110+
series.values(copy=False),
111+
exog=historic_future_covariates.values(copy=False)
112+
if historic_future_covariates
113+
else None,
114+
)
97115

98116
if num_samples == 1:
99117
forecast = self.model.forecast(
100-
steps=n, exog=future_covariates.values() if future_covariates else None
118+
steps=n,
119+
exog=future_covariates.values(copy=False)
120+
if future_covariates
121+
else None,
101122
)
102123
else:
103124
forecast = self.model.simulate(
104125
nsimulations=n,
105126
repetitions=num_samples,
106127
initial_state=self.model.states.predicted[-1, :],
107-
exog=future_covariates.values() if future_covariates else None,
128+
exog=future_covariates.values(copy=False)
129+
if future_covariates
130+
else None,
131+
)
132+
133+
# restoring statsmodels results object state
134+
if series is not None:
135+
self.model = self.model.apply(
136+
self._orig_training_series.values(copy=False),
137+
exog=self.training_historic_future_covariates.values(copy=False)
138+
if self.training_historic_future_covariates
139+
else None,
108140
)
109141

110142
return self._build_forecast_series(forecast)

darts/models/forecasting/forecasting_model.py

+140-5
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ class DualCovariatesForecastingModel(ForecastingModel, ABC):
10831083
Among other things, it lets Darts forecasting models wrap around statsmodels models
10841084
having a `future_covariates` parameter, which corresponds to future-known covariates.
10851085
1086-
All implementations have to implement the `fit()` and `predict()` methods defined below.
1086+
All implementations have to implement the `_fit()` and `_predict()` methods defined below.
10871087
"""
10881088

10891089
_expect_covariate = False
@@ -1137,6 +1137,7 @@ def predict(
11371137
n: int,
11381138
future_covariates: Optional[TimeSeries] = None,
11391139
num_samples: int = 1,
1140+
**kwargs,
11401141
) -> TimeSeries:
11411142
"""Forecasts values for `n` time steps after the end of the training series.
11421143
@@ -1159,8 +1160,7 @@ def predict(
11591160
TimeSeries, a single time series containing the `n` next points after then end of the training series.
11601161
"""
11611162

1162-
if future_covariates is None:
1163-
super().predict(n, num_samples)
1163+
super().predict(n, num_samples)
11641164

11651165
if self._expect_covariate and future_covariates is None:
11661166
raise_log(
@@ -1170,6 +1170,12 @@ def predict(
11701170
)
11711171
)
11721172

1173+
raise_if(
1174+
not self._expect_covariate and future_covariates is not None,
1175+
"The model has been trained without `future_covariates` variable, but the "
1176+
"`future_covariates` parameter provided to `predict()` is not None.",
1177+
)
1178+
11731179
if future_covariates is not None:
11741180
start = self.training_series.end_time() + self.training_series.freq
11751181

@@ -1194,13 +1200,13 @@ def predict(
11941200
]
11951201

11961202
raise_if_not(
1197-
len(future_covariates) == n and self._expect_covariate,
1203+
len(future_covariates) == n,
11981204
invalid_time_span_error,
11991205
logger,
12001206
)
12011207

12021208
return self._predict(
1203-
n, future_covariates=future_covariates, num_samples=num_samples
1209+
n, future_covariates=future_covariates, num_samples=num_samples, **kwargs
12041210
)
12051211

12061212
@abstractmethod
@@ -1234,3 +1240,132 @@ def _predict_wrapper(
12341240
return self.predict(
12351241
n, future_covariates=future_covariates, num_samples=num_samples
12361242
)
1243+
1244+
1245+
class TransferableDualCovariatesForecastingModel(DualCovariatesForecastingModel, ABC):
1246+
"""The base class for the forecasting models that are not global, but support future covariates, and can
1247+
additionally be applied to new data unrelated to the original series used for fitting the model. Currently,
1248+
all the derived classes wrap statsmodels models.
1249+
1250+
All implementations have to implement the `_fit()`, `_predict()` methods.
1251+
"""
1252+
1253+
def predict(
1254+
self,
1255+
n: int,
1256+
series: Optional[TimeSeries] = None,
1257+
future_covariates: Optional[TimeSeries] = None,
1258+
num_samples: int = 1,
1259+
**kwargs,
1260+
) -> TimeSeries:
1261+
"""If the `series` parameter is not set, forecasts values for `n` time steps after the end of the training
1262+
series. If some future covariates were specified during the training, they must also be specified here.
1263+
1264+
If the `series` parameter is set, forecasts values for `n` time steps after the end of the new target
1265+
series. If some future covariates were specified during the training, they must also be specified here.
1266+
1267+
Parameters
1268+
----------
1269+
n
1270+
Forecast horizon - the number of time steps after the end of the series for which to produce predictions.
1271+
series
1272+
Optionally, a new target series whose future values will be predicted. Defaults to `None`, meaning that the
1273+
model will forecast the future value of the training series.
1274+
future_covariates
1275+
The time series of future-known covariates which can be fed as input to the model. It must correspond to
1276+
the covariate time series that has been used with the :func:`fit()` method for training.
1277+
1278+
If `series` is not set, it must contain at least the next `n` time steps/indices after the end of the
1279+
training target series. If `series` is set, it must contain at least the time steps/indices corresponding
1280+
to the new target series (historic future covariates), plus the next `n` time steps/indices after the end.
1281+
num_samples
1282+
Number of times a prediction is sampled from a probabilistic model. Should be left set to 1
1283+
for deterministic models.
1284+
1285+
Returns
1286+
-------
1287+
TimeSeries, a single time series containing the `n` next points after then end of the training series.
1288+
"""
1289+
1290+
if self._expect_covariate and future_covariates is None:
1291+
raise_log(
1292+
ValueError(
1293+
"The model has been trained with `future_covariates` variable. Some matching "
1294+
"`future_covariates` variables have to be provided to `predict()`."
1295+
)
1296+
)
1297+
1298+
historic_future_covariates = None
1299+
1300+
if series is not None and future_covariates:
1301+
raise_if_not(
1302+
future_covariates.start_time() <= series.start_time()
1303+
and future_covariates.end_time() >= series.end_time() + n * series.freq,
1304+
"The provided `future_covariates` related to the new target series must contain at least the same time"
1305+
"steps/indices as the target `series` + `n`.",
1306+
logger,
1307+
)
1308+
# splitting the future covariates
1309+
(
1310+
historic_future_covariates,
1311+
future_covariates,
1312+
) = future_covariates.split_after(series.end_time())
1313+
1314+
# in case future covariate have more values on the left end side that we don't need
1315+
if not series.has_same_time_as(historic_future_covariates):
1316+
historic_future_covariates = historic_future_covariates.slice_intersect(
1317+
series
1318+
)
1319+
1320+
# DualCovariatesForecastingModel performs some checks on self.training_series. We temporary replace that with
1321+
# the new ts
1322+
if series is not None:
1323+
self._orig_training_series = self.training_series
1324+
self.training_series = series
1325+
1326+
result = super().predict(
1327+
n=n,
1328+
series=series,
1329+
historic_future_covariates=historic_future_covariates,
1330+
future_covariates=future_covariates,
1331+
num_samples=num_samples,
1332+
**kwargs,
1333+
)
1334+
1335+
# restoring the original training ts
1336+
if series is not None:
1337+
self.training_series = self._orig_training_series
1338+
1339+
return result
1340+
1341+
@abstractmethod
1342+
def _predict(
1343+
self,
1344+
n: int,
1345+
series: Optional[TimeSeries] = None,
1346+
historic_future_covariates: Optional[TimeSeries] = None,
1347+
future_covariates: Optional[TimeSeries] = None,
1348+
num_samples: int = 1,
1349+
) -> TimeSeries:
1350+
"""Forecasts values for a certain number of time steps after the end of the series.
1351+
TransferableDualCovariatesForecastingModel must implement the predict logic in this method.
1352+
"""
1353+
pass
1354+
1355+
def _predict_wrapper(
1356+
self,
1357+
n: int,
1358+
series: TimeSeries,
1359+
past_covariates: Optional[TimeSeries],
1360+
future_covariates: Optional[TimeSeries],
1361+
num_samples: int,
1362+
) -> TimeSeries:
1363+
return self.predict(
1364+
n=n,
1365+
series=series,
1366+
future_covariates=future_covariates,
1367+
num_samples=num_samples,
1368+
)
1369+
1370+
def _supports_non_retrainable_historical_forecasts(self) -> bool:
1371+
return True

0 commit comments

Comments
 (0)