Skip to content

Commit 31159b3

Browse files
authored
Add dedicated LocalForecastingModel and renaming some model classes (#1327)
1 parent 74bafa1 commit 31159b3

20 files changed

+132
-71
lines changed

darts/dataprocessing/encoders/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
# Time Axis Encoders
1+
"""
2+
Time Axis Encoders
3+
------------------
4+
"""
5+
26
from .encoders import (
37
FutureCallableIndexEncoder,
48
FutureCyclicEncoder,

darts/models/forecasting/arima.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
from darts.logging import get_logger
1919
from darts.models.forecasting.forecasting_model import (
20-
TransferableDualCovariatesForecastingModel,
20+
TransferableFutureCovariatesLocalForecastingModel,
2121
)
2222
from darts.timeseries import TimeSeries
2323

2424
logger = get_logger(__name__)
2525

2626

27-
class ARIMA(TransferableDualCovariatesForecastingModel):
27+
class ARIMA(TransferableFutureCovariatesLocalForecastingModel):
2828
def __init__(
2929
self,
3030
p: int = 12,

darts/models/forecasting/auto_arima.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from pmdarima import AutoARIMA as PmdAutoARIMA
99

1010
from darts.logging import get_logger, raise_if
11-
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
11+
from darts.models.forecasting.forecasting_model import (
12+
FutureCovariatesLocalForecastingModel,
13+
)
1214
from darts.timeseries import TimeSeries
1315

1416
logger = get_logger(__name__)
1517

1618

17-
class AutoARIMA(DualCovariatesForecastingModel):
19+
class AutoARIMA(FutureCovariatesLocalForecastingModel):
1820
def __init__(self, *autoarima_args, **autoarima_kwargs):
1921
"""Auto-ARIMA
2022

darts/models/forecasting/baselines.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
from darts.logging import get_logger, raise_if_not
1313
from darts.models.forecasting.ensemble_model import EnsembleModel
1414
from darts.models.forecasting.forecasting_model import (
15-
ForecastingModel,
1615
GlobalForecastingModel,
16+
LocalForecastingModel,
1717
)
1818
from darts.timeseries import TimeSeries
1919

2020
logger = get_logger(__name__)
2121

2222

23-
class NaiveMean(ForecastingModel):
23+
class NaiveMean(LocalForecastingModel):
2424
def __init__(self):
2525
"""Naive Mean Model
2626
@@ -44,7 +44,7 @@ def predict(self, n: int, num_samples: int = 1):
4444
return self._build_forecast_series(forecast)
4545

4646

47-
class NaiveSeasonal(ForecastingModel):
47+
class NaiveSeasonal(LocalForecastingModel):
4848
def __init__(self, K: int = 1):
4949
"""Naive Seasonal Model
5050
@@ -84,7 +84,7 @@ def predict(self, n: int, num_samples: int = 1):
8484
return self._build_forecast_series(forecast)
8585

8686

87-
class NaiveDrift(ForecastingModel):
87+
class NaiveDrift(LocalForecastingModel):
8888
def __init__(self):
8989
"""Naive Drift Model
9090
@@ -117,7 +117,7 @@ def predict(self, n: int, num_samples: int = 1):
117117

118118
class NaiveEnsembleModel(EnsembleModel):
119119
def __init__(
120-
self, models: Union[List[ForecastingModel], List[GlobalForecastingModel]]
120+
self, models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]]
121121
):
122122
"""Naive combination model
123123

darts/models/forecasting/croston.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from statsforecast.models import CrostonClassic, CrostonOptimized, CrostonSBA
1010

1111
from darts.logging import raise_if, raise_if_not
12-
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
12+
from darts.models.forecasting.forecasting_model import (
13+
FutureCovariatesLocalForecastingModel,
14+
)
1315
from darts.timeseries import TimeSeries
1416

1517

16-
class Croston(DualCovariatesForecastingModel):
18+
class Croston(FutureCovariatesLocalForecastingModel):
1719
def __init__(
1820
self, version: str = "classic", alpha_d: float = None, alpha_p: float = None
1921
):

darts/models/forecasting/ensemble_model.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from darts.logging import get_logger, raise_if, raise_if_not
1010
from darts.models.forecasting.forecasting_model import (
11-
ForecastingModel,
1211
GlobalForecastingModel,
12+
LocalForecastingModel,
1313
)
1414
from darts.timeseries import TimeSeries
1515

@@ -29,7 +29,7 @@ class EnsembleModel(GlobalForecastingModel):
2929
"""
3030

3131
def __init__(
32-
self, models: Union[List[ForecastingModel], List[GlobalForecastingModel]]
32+
self, models: Union[List[LocalForecastingModel], List[GlobalForecastingModel]]
3333
):
3434
raise_if_not(
3535
isinstance(models, list) and models,
@@ -38,17 +38,15 @@ def __init__(
3838
)
3939

4040
is_local_ensemble = all(
41-
isinstance(model, ForecastingModel)
42-
and not isinstance(model, GlobalForecastingModel)
43-
for model in models
41+
isinstance(model, LocalForecastingModel) for model in models
4442
)
4543
self.is_global_ensemble = all(
4644
isinstance(model, GlobalForecastingModel) for model in models
4745
)
4846

4947
raise_if_not(
5048
is_local_ensemble or self.is_global_ensemble,
51-
"All models must either be GlobalForecastingModel instances, or none of them should be.",
49+
"All models must be of the same type: either GlobalForecastingModel, or LocalForecastingModel.",
5250
logger,
5351
)
5452

@@ -76,12 +74,12 @@ def fit(
7674
"""
7775
raise_if(
7876
not self.is_global_ensemble and not isinstance(series, TimeSeries),
79-
"The models are not GlobalForecastingModel's and do not support training on multiple series.",
77+
"The models are of type LocalForecastingModel, which does not support training on multiple series.",
8078
logger,
8179
)
8280
raise_if(
8381
not self.is_global_ensemble and past_covariates is not None,
84-
"The models are not GlobalForecastingModel's and do not support past covariates.",
82+
"The models are of type LocalForecastingModel, which does not support past covariates.",
8583
logger,
8684
)
8785

darts/models/forecasting/exponential_smoothing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
import statsmodels.tsa.holtwinters as hw
1010

1111
from darts.logging import get_logger
12-
from darts.models.forecasting.forecasting_model import ForecastingModel
12+
from darts.models.forecasting.forecasting_model import LocalForecastingModel
1313
from darts.timeseries import TimeSeries
1414
from darts.utils.utils import ModelMode, SeasonalityMode
1515

1616
logger = get_logger(__name__)
1717

1818

19-
class ExponentialSmoothing(ForecastingModel):
19+
class ExponentialSmoothing(LocalForecastingModel):
2020
def __init__(
2121
self,
2222
trend: Optional[ModelMode] = ModelMode.ADDITIVE,

darts/models/forecasting/fft.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from statsmodels.tsa.stattools import acf
1111

1212
from darts.logging import get_logger
13-
from darts.models.forecasting.forecasting_model import ForecastingModel
13+
from darts.models.forecasting.forecasting_model import LocalForecastingModel
1414
from darts.timeseries import TimeSeries
1515
from darts.utils.missing_values import fill_missing_values
1616

@@ -210,7 +210,7 @@ def _crop_to_match_seasons(
210210
return series
211211

212212

213-
class FFT(ForecastingModel):
213+
class FFT(LocalForecastingModel):
214214
def __init__(
215215
self,
216216
nr_freqs_to_keep: Optional[int] = 10,

darts/models/forecasting/forecasting_model.py

+45-15
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def fit(self, series: TimeSeries) -> "ForecastingModel":
133133
self
134134
Fitted model.
135135
"""
136-
if not isinstance(self, DualCovariatesForecastingModel):
136+
if not isinstance(self, FutureCovariatesLocalForecastingModel):
137137
series._assert_univariate()
138138
raise_if_not(
139139
len(series) >= self.min_train_series_length,
@@ -1068,6 +1068,20 @@ def load(path: Union[str, BinaryIO]) -> "ForecastingModel":
10681068
return model
10691069

10701070

1071+
class LocalForecastingModel(ForecastingModel, ABC):
1072+
"""The base class for "local" forecasting models, handling only single univariate time series.
1073+
1074+
Local Forecasting Models (LFM) are models that can be trained on a single univariate target series only. In Darts,
1075+
most models in this category tend to be simpler statistical models (such as ETS or FFT). LFMs usually train on
1076+
the entire target series supplied when calling :func:`fit()` at once. They can also predict in one go with
1077+
:func:`predict()` for any number of predictions `n` after the end of the training series.
1078+
1079+
All implementations must implement the `_fit()` and `_predict()` methods.
1080+
"""
1081+
1082+
pass
1083+
1084+
10711085
class GlobalForecastingModel(ForecastingModel, ABC):
10721086
"""The base class for "global" forecasting models, handling several time series and optional covariates.
10731087
@@ -1080,7 +1094,7 @@ class GlobalForecastingModel(ForecastingModel, ABC):
10801094
The name "global" stems from the fact that a training set of a forecasting model of this class is not constrained
10811095
to a temporally contiguous, "local", time series.
10821096
1083-
All implementations have to implement the :func:`fit()` and :func:`predict()` methods defined below.
1097+
All implementations must implement the :func:`fit()` and :func:`predict()` methods.
10841098
The :func:`fit()` method is meant to train the model on one or several training time series, along with optional
10851099
covariates.
10861100
@@ -1371,12 +1385,18 @@ def _get_encoders_n(self, n) -> int:
13711385
return n
13721386

13731387

1374-
class DualCovariatesForecastingModel(ForecastingModel, ABC):
1375-
"""The base class for the forecasting models that are not global, but support future covariates.
1376-
Among other things, it lets Darts forecasting models wrap around statsmodels models
1377-
having a `future_covariates` parameter, which corresponds to future-known covariates.
1388+
class FutureCovariatesLocalForecastingModel(LocalForecastingModel, ABC):
1389+
"""The base class for future covariates "local" forecasting models, handling single uni- or multivariate target
1390+
and optional future covariates time series.
13781391
1379-
All implementations have to implement the `_fit()` and `_predict()` methods defined below.
1392+
Future Covariates Local Forecasting Models (FC-LFM) are models that can be trained on a single uni- or multivariate
1393+
target and optional future covariates series. In Darts, most models in this category tend to be simpler statistical
1394+
models (such as ARIMA). FC-LFMs usually train on the entire target and future covariates series supplied when
1395+
calling :func:`fit()` at once. They can also predict in one go with :func:`predict()` for any number of predictions
1396+
`n` after the end of the training series. When using future covariates, the values for the future `n` prediction
1397+
steps must be given in the covariate series.
1398+
1399+
All implementations must implement the :func:`_fit()` and :func:`_predict()` methods.
13801400
"""
13811401

13821402
_expect_covariate = False
@@ -1535,12 +1555,22 @@ def _predict_wrapper(
15351555
)
15361556

15371557

1538-
class TransferableDualCovariatesForecastingModel(DualCovariatesForecastingModel, ABC):
1539-
"""The base class for the forecasting models that are not global, but support future covariates, and can
1540-
additionally be applied to new data unrelated to the original series used for fitting the model. Currently,
1541-
all the derived classes wrap statsmodels models.
1558+
class TransferableFutureCovariatesLocalForecastingModel(
1559+
FutureCovariatesLocalForecastingModel, ABC
1560+
):
1561+
"""The base class for transferable future covariates "local" forecasting models, handling single uni- or
1562+
multivariate target and optional future covariates time series. Additionally, at prediction time, it can be
1563+
applied to new data unrelated to the original series used for fitting the model.
1564+
1565+
Transferable Future Covariates Local Forecasting Models (TFC-LFM) are models that can be trained on a single uni-
1566+
or multivariate target and optional future covariates series. Additionally, at prediction time, it can be applied
1567+
to new data unrelated to the original series used for fitting the model. Currently in Darts, all models in this
1568+
category wrap to statsmodel models such as VARIMA. TFC-LFMs usually train on the entire target and future covariates
1569+
series supplied when calling :func:`fit()` at once. They can also predict in one go with :func:`predict()`
1570+
for any number of predictions `n` after the end of the training series. When using future covariates, the values
1571+
for the future `n` prediction steps must be given in the covariate series.
15421572
1543-
All implementations have to implement the `_fit()`, `_predict()` methods.
1573+
All implementations must implement the :func:`_fit()` and :func:`_predict()` methods.
15441574
"""
15451575

15461576
def predict(
@@ -1610,8 +1640,8 @@ def predict(
16101640
series
16111641
)
16121642

1613-
# DualCovariatesForecastingModel performs some checks on self.training_series. We temporary replace that with
1614-
# the new ts
1643+
# FutureCovariatesLocalForecastingModel performs some checks on self.training_series. We temporary replace
1644+
# that with the new ts
16151645
if series is not None:
16161646
self._orig_training_series = self.training_series
16171647
self.training_series = series
@@ -1641,7 +1671,7 @@ def _predict(
16411671
num_samples: int = 1,
16421672
) -> TimeSeries:
16431673
"""Forecasts values for a certain number of time steps after the end of the series.
1644-
TransferableDualCovariatesForecastingModel must implement the predict logic in this method.
1674+
TransferableFutureCovariatesLocalForecastingModel must implement the predict logic in this method.
16451675
"""
16461676
pass
16471677

darts/models/forecasting/kalman_forecaster.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
from darts.logging import get_logger
1919
from darts.models.filtering.kalman_filter import KalmanFilter
20-
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
20+
from darts.models.forecasting.forecasting_model import (
21+
FutureCovariatesLocalForecastingModel,
22+
)
2123
from darts.timeseries import TimeSeries
2224

2325
logger = get_logger(__name__)
2426

2527

26-
class KalmanForecaster(DualCovariatesForecastingModel):
28+
class KalmanForecaster(FutureCovariatesLocalForecastingModel):
2729
def __init__(self, dim_x: int = 1, kf: Optional[Kalman] = None):
2830
"""Kalman filter Forecaster
2931

darts/models/forecasting/prophet_model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@
1212
import prophet
1313

1414
from darts.logging import execute_and_suppress_output, get_logger, raise_if
15-
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
15+
from darts.models.forecasting.forecasting_model import (
16+
FutureCovariatesLocalForecastingModel,
17+
)
1618
from darts.timeseries import TimeSeries
1719

1820
logger = get_logger(__name__)
1921
logger.level = logging.WARNING # set to warning to suppress prophet logs
2022

2123

22-
class Prophet(DualCovariatesForecastingModel):
24+
class Prophet(FutureCovariatesLocalForecastingModel):
2325
def __init__(
2426
self,
2527
add_seasonalities: Optional[Union[dict, List[dict]]] = None,

darts/models/forecasting/regression_ensemble_model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from darts.logging import get_logger, raise_if, raise_if_not
1010
from darts.models.forecasting.ensemble_model import EnsembleModel
1111
from darts.models.forecasting.forecasting_model import (
12-
ForecastingModel,
1312
GlobalForecastingModel,
13+
LocalForecastingModel,
1414
)
1515
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
1616
from darts.models.forecasting.regression_model import RegressionModel
@@ -22,7 +22,9 @@
2222
class RegressionEnsembleModel(EnsembleModel):
2323
def __init__(
2424
self,
25-
forecasting_models: Union[List[ForecastingModel], List[GlobalForecastingModel]],
25+
forecasting_models: Union[
26+
List[LocalForecastingModel], List[GlobalForecastingModel]
27+
],
2628
regression_train_n_points: int,
2729
regression_model=None,
2830
):

darts/models/forecasting/sf_auto_arima.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
from statsforecast.models import AutoARIMA as SFAutoARIMA
1010

1111
from darts import TimeSeries
12-
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
12+
from darts.models.forecasting.forecasting_model import (
13+
FutureCovariatesLocalForecastingModel,
14+
)
1315

1416

15-
class StatsForecastAutoARIMA(DualCovariatesForecastingModel):
17+
class StatsForecastAutoARIMA(FutureCovariatesLocalForecastingModel):
1618
def __init__(self, *autoarima_args, **autoarima_kwargs):
1719
"""Auto-ARIMA based on `Statsforecasts package
1820
<https://github.com/Nixtla/statsforecast>`_.

darts/models/forecasting/sf_ets.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from statsforecast.models import ETS
99

1010
from darts import TimeSeries
11-
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
11+
from darts.models.forecasting.forecasting_model import (
12+
FutureCovariatesLocalForecastingModel,
13+
)
1214

1315

14-
class StatsForecastETS(DualCovariatesForecastingModel):
16+
class StatsForecastETS(FutureCovariatesLocalForecastingModel):
1517
def __init__(self, *ets_args, **ets_kwargs):
1618
"""ETS based on `Statsforecasts package
1719
<https://github.com/Nixtla/statsforecast>`_.

0 commit comments

Comments
 (0)