Skip to content

Commit a93296b

Browse files
dennisbaderhrzn
andauthored
Feat/encoders extension (#1093)
* added encoders to regression models * added unit tests for encoders in regression model * remove torch flavor checks for local forecasting model tests * reset ptl trainer when loading torch models * reduced estimators for RandomForest test case Co-authored-by: Julien Herzen <[email protected]>
1 parent df4e2d7 commit a93296b

10 files changed

+1317
-951
lines changed

darts/models/forecasting/catboost_model.py

+22
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
lags_past_covariates: Union[int, List[int]] = None,
2424
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
2525
output_chunk_length: int = 1,
26+
add_encoders: Optional[dict] = None,
2627
likelihood: str = None,
2728
quantiles: List = None,
2829
random_state: Optional[int] = None,
@@ -48,6 +49,26 @@ def __init__(
4849
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
4950
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
5051
be useful if the covariates don't extend far enough into the future.
52+
add_encoders
53+
A large number of past and future covariates can be automatically generated with `add_encoders`.
54+
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
55+
will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
56+
transform the generated covariates. This happens all under one hood and only needs to be specified at
57+
model creation.
58+
Read :meth:`SequentialEncoder <darts.utils.data.encoders.SequentialEncoder>` to find out more about
59+
``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
60+
61+
.. highlight:: python
62+
.. code-block:: python
63+
64+
add_encoders={
65+
'cyclic': {'future': ['month']},
66+
'datetime_attribute': {'future': ['hour', 'dayofweek']},
67+
'position': {'past': ['absolute'], 'future': ['relative']},
68+
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
69+
'transformer': Scaler()
70+
}
71+
..
5172
likelihood
5273
Can be set to 'quantile', 'poisson' or 'gaussian'. If set, the model will be probabilistic,
5374
allowing sampling at prediction time. When set to 'gaussian', the model will use CatBoost's
@@ -96,6 +117,7 @@ def __init__(
96117
lags_past_covariates=lags_past_covariates,
97118
lags_future_covariates=lags_future_covariates,
98119
output_chunk_length=output_chunk_length,
120+
add_encoders=add_encoders,
99121
model=CatBoostRegressor(**kwargs),
100122
)
101123

darts/models/forecasting/ensemble_model.py

+5
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ def _stack_ts_multiseq(self, predictions_list):
114114
# stacks multiple sequences of timeseries elementwise
115115
return [self._stack_ts_seq(ts_list) for ts_list in zip(*predictions_list)]
116116

117+
def _model_encoder_settings(self):
118+
raise NotImplementedError(
119+
"Encoders are not supported by EnsembleModels. Instead add encoder to the underlying `models`."
120+
)
121+
117122
def _make_multiple_predictions(
118123
self,
119124
n: int,

darts/models/forecasting/forecasting_model.py

+35
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
_parallel_apply,
3333
_with_sanity_checks,
3434
)
35+
from darts.utils.data.encoders import SequentialEncoder
3536
from darts.utils.timeseries_generation import (
3637
_build_forecast_series,
3738
_generate_new_dates,
@@ -932,6 +933,13 @@ class GlobalForecastingModel(ForecastingModel, ABC):
932933
_expect_past_covariates, _expect_future_covariates = False, False
933934
past_covariate_series, future_covariate_series = None, None
934935

936+
def __init__(self, add_encoders: Optional[dict] = None):
937+
super().__init__()
938+
939+
# by default models do not use encoders
940+
self.add_encoders = add_encoders
941+
self.encoders: Optional[SequentialEncoder] = None
942+
935943
@abstractmethod
936944
def fit(
937945
self,
@@ -1084,6 +1092,33 @@ def _supports_non_retrainable_historical_forecasts(self) -> bool:
10841092
"""GlobalForecastingModel supports historical forecasts without retraining the model"""
10851093
return True
10861094

1095+
@property
1096+
@abstractmethod
1097+
def _model_encoder_settings(self) -> Tuple[int, int, bool, bool]:
1098+
"""Abstract property that returns model specific encoder settings that are used to initialize the encoders.
1099+
1100+
Must return Tuple (input_chunk_length, output_chunk_length, takes_past_covariates, takes_future_covariates)
1101+
"""
1102+
pass
1103+
1104+
def initialize_encoders(self) -> SequentialEncoder:
1105+
"""instantiates the SequentialEncoder object based on self._model_encoder_settings and parameter
1106+
``add_encoders`` used at model creation"""
1107+
(
1108+
input_chunk_length,
1109+
output_chunk_length,
1110+
takes_past_covariates,
1111+
takes_future_covariates,
1112+
) = self._model_encoder_settings
1113+
1114+
return SequentialEncoder(
1115+
add_encoders=self.add_encoders,
1116+
input_chunk_length=input_chunk_length,
1117+
output_chunk_length=output_chunk_length,
1118+
takes_past_covariates=takes_past_covariates,
1119+
takes_future_covariates=takes_future_covariates,
1120+
)
1121+
10871122

10881123
class DualCovariatesForecastingModel(ForecastingModel, ABC):
10891124
"""The base class for the forecasting models that are not global, but support future covariates.

darts/models/forecasting/gradient_boosted_model.py

+22
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
lags_past_covariates: Union[int, List[int]] = None,
2828
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
2929
output_chunk_length: int = 1,
30+
add_encoders: Optional[dict] = None,
3031
likelihood: str = None,
3132
quantiles: List[float] = None,
3233
random_state: Optional[int] = None,
@@ -52,6 +53,26 @@ def __init__(
5253
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
5354
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
5455
be useful if the covariates don't extend far enough into the future.
56+
add_encoders
57+
A large number of past and future covariates can be automatically generated with `add_encoders`.
58+
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
59+
will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
60+
transform the generated covariates. This happens all under one hood and only needs to be specified at
61+
model creation.
62+
Read :meth:`SequentialEncoder <darts.utils.data.encoders.SequentialEncoder>` to find out more about
63+
``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
64+
65+
.. highlight:: python
66+
.. code-block:: python
67+
68+
add_encoders={
69+
'cyclic': {'future': ['month']},
70+
'datetime_attribute': {'future': ['hour', 'dayofweek']},
71+
'position': {'past': ['absolute'], 'future': ['relative']},
72+
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
73+
'transformer': Scaler()
74+
}
75+
..
5576
likelihood
5677
Can be set to `quantile` or `poisson`. If set, the model will be probabilistic, allowing sampling at
5778
prediction time.
@@ -87,6 +108,7 @@ def __init__(
87108
lags_past_covariates=lags_past_covariates,
88109
lags_future_covariates=lags_future_covariates,
89110
output_chunk_length=output_chunk_length,
111+
add_encoders=add_encoders,
90112
model=lgb.LGBMRegressor(**kwargs),
91113
)
92114

darts/models/forecasting/linear_regression_model.py

+22
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
lags_past_covariates: Union[int, List[int]] = None,
2626
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
2727
output_chunk_length: int = 1,
28+
add_encoders: Optional[dict] = None,
2829
likelihood: str = None,
2930
quantiles: List[float] = None,
3031
random_state: Optional[int] = None,
@@ -50,6 +51,26 @@ def __init__(
5051
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
5152
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
5253
be useful if the covariates don't extend far enough into the future.
54+
add_encoders
55+
A large number of past and future covariates can be automatically generated with `add_encoders`.
56+
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
57+
will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
58+
transform the generated covariates. This happens all under one hood and only needs to be specified at
59+
model creation.
60+
Read :meth:`SequentialEncoder <darts.utils.data.encoders.SequentialEncoder>` to find out more about
61+
``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
62+
63+
.. highlight:: python
64+
.. code-block:: python
65+
66+
add_encoders={
67+
'cyclic': {'future': ['month']},
68+
'datetime_attribute': {'future': ['hour', 'dayofweek']},
69+
'position': {'past': ['absolute'], 'future': ['relative']},
70+
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
71+
'transformer': Scaler()
72+
}
73+
..
5374
likelihood
5475
Can be set to `quantile` or `poisson`. If set, the model will be probabilistic, allowing sampling at
5576
prediction time. If set to `quantile`, the `sklearn.linear_model.QuantileRegressor` is used. Similarly, if
@@ -94,6 +115,7 @@ def __init__(
94115
lags_past_covariates=lags_past_covariates,
95116
lags_future_covariates=lags_future_covariates,
96117
output_chunk_length=output_chunk_length,
118+
add_encoders=add_encoders,
97119
model=model,
98120
)
99121

darts/models/forecasting/random_forest.py

+22
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
lags_past_covariates: Union[int, List[int]] = None,
3232
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
3333
output_chunk_length: int = 1,
34+
add_encoders: Optional[dict] = None,
3435
n_estimators: Optional[int] = 100,
3536
max_depth: Optional[int] = None,
3637
**kwargs,
@@ -55,6 +56,26 @@ def __init__(
5556
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
5657
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
5758
be useful if the covariates don't extend far enough into the future.
59+
add_encoders
60+
A large number of past and future covariates can be automatically generated with `add_encoders`.
61+
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
62+
will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
63+
transform the generated covariates. This happens all under one hood and only needs to be specified at
64+
model creation.
65+
Read :meth:`SequentialEncoder <darts.utils.data.encoders.SequentialEncoder>` to find out more about
66+
``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
67+
68+
.. highlight:: python
69+
.. code-block:: python
70+
71+
add_encoders={
72+
'cyclic': {'future': ['month']},
73+
'datetime_attribute': {'future': ['hour', 'dayofweek']},
74+
'position': {'past': ['absolute'], 'future': ['relative']},
75+
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
76+
'transformer': Scaler()
77+
}
78+
..
5879
n_estimators : int
5980
The number of trees in the forest.
6081
max_depth : int
@@ -74,6 +95,7 @@ def __init__(
7495
lags_past_covariates=lags_past_covariates,
7596
lags_future_covariates=lags_future_covariates,
7697
output_chunk_length=output_chunk_length,
98+
add_encoders=add_encoders,
7799
model=RandomForestRegressor(**kwargs),
78100
)
79101

darts/models/forecasting/regression_model.py

+80-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
lags_past_covariates: Union[int, List[int]] = None,
4949
lags_future_covariates: Union[Tuple[int, int], List[int]] = None,
5050
output_chunk_length: int = 1,
51+
add_encoders: Optional[dict] = None,
5152
model=None,
5253
):
5354
"""Regression Model
@@ -71,14 +72,34 @@ def __init__(
7172
Number of time steps predicted at once by the internal regression model. Does not have to equal the forecast
7273
horizon `n` used in `predict()`. However, setting `output_chunk_length` equal to the forecast horizon may
7374
be useful if the covariates don't extend far enough into the future.
75+
add_encoders
76+
A large number of past and future covariates can be automatically generated with `add_encoders`.
77+
This can be done by adding multiple pre-defined index encoders and/or custom user-made functions that
78+
will be used as index encoders. Additionally, a transformer such as Darts' :class:`Scaler` can be added to
79+
transform the generated covariates. This happens all under one hood and only needs to be specified at
80+
model creation.
81+
Read :meth:`SequentialEncoder <darts.utils.data.encoders.SequentialEncoder>` to find out more about
82+
``add_encoders``. Default: ``None``. An example showing some of ``add_encoders`` features:
83+
84+
.. highlight:: python
85+
.. code-block:: python
86+
87+
add_encoders={
88+
'cyclic': {'future': ['month']},
89+
'datetime_attribute': {'future': ['hour', 'dayofweek']},
90+
'position': {'past': ['absolute'], 'future': ['relative']},
91+
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
92+
'transformer': Scaler()
93+
}
94+
..
7495
model
7596
Scikit-learn-like model with ``fit()`` and ``predict()`` methods. Also possible to use model that doesn't
7697
support multi-output regression for multivariate timeseries, in which case one regressor
7798
will be used per component in the multivariate series.
7899
If None, defaults to: ``sklearn.linear_model.LinearRegression(n_jobs=-1)``.
79100
"""
80101

81-
super().__init__()
102+
super().__init__(add_encoders=add_encoders)
82103

83104
self.model = model
84105
self.lags = {}
@@ -200,6 +221,46 @@ def __init__(
200221
)
201222
self.output_chunk_length = output_chunk_length
202223

224+
@property
225+
def _model_encoder_settings(self) -> Tuple[int, int, bool, bool]:
226+
lags_covariates = {
227+
lag for key in ["past", "future"] for lag in self.lags.get(key, [])
228+
}
229+
if lags_covariates:
230+
# for lags < 0 we need to take `n` steps backwards from past and/or historic future covariates
231+
# for minimum lag = -1 -> steps_back_inclusive = 1
232+
# inclusive means n steps back including the end of the target series
233+
n_steps_back_inclusive = abs(min(min(lags_covariates), 0))
234+
# for lags >= 0 we need to take `n` steps ahead from future covariates
235+
# for maximum lag = 0 -> output_chunk_length = 1
236+
# exclusive means n steps ahead after the last step of the target series
237+
n_steps_ahead_exclusive = max(max(lags_covariates), 0) + 1
238+
takes_past_covariates = "past" in self.lags
239+
takes_future_covariates = "future" in self.lags
240+
else:
241+
n_steps_back_inclusive = 0
242+
n_steps_ahead_exclusive = 0
243+
takes_past_covariates = False
244+
takes_future_covariates = False
245+
return (
246+
n_steps_back_inclusive,
247+
n_steps_ahead_exclusive,
248+
takes_past_covariates,
249+
takes_future_covariates,
250+
)
251+
252+
def _get_encoders_n(self, n):
253+
"""Returns the `n` encoder prediction steps specific to RegressionModels.
254+
This will generate slightly more past covariates than the minimum requirement when using past and future
255+
covariate lags simultaneously. This is because encoders were written for TorchForecastingModels where we only
256+
needed `n` future covariates. For RegressionModel we need `n + max_future_lag`
257+
"""
258+
_, n_steps_ahead, _, takes_future_covariates = self._model_encoder_settings
259+
if not takes_future_covariates:
260+
return n
261+
else:
262+
return n + (n_steps_ahead - 1)
263+
203264
@property
204265
def min_train_series_length(self) -> int:
205266
return max(
@@ -319,6 +380,7 @@ def _fit_model(
319380
Function that fit the model. Deriving classes can override this method for adding additional parameters (e.g.,
320381
adding validation data), keeping the sanity checks on series performed by fit().
321382
"""
383+
322384
training_samples, training_labels = self._create_lagged_data(
323385
target_series, past_covariates, future_covariates, max_samples_per_ts
324386
)
@@ -361,6 +423,15 @@ def fit(
361423
**kwargs
362424
Additional keyword arguments passed to the `fit` method of the model.
363425
"""
426+
427+
self.encoders = self.initialize_encoders()
428+
if self.encoders.encoding_available:
429+
past_covariates, future_covariates = self.encoders.encode_train(
430+
target=series,
431+
past_covariate=past_covariates,
432+
future_covariate=future_covariates,
433+
)
434+
364435
super().fit(
365436
series=series,
366437
past_covariates=past_covariates,
@@ -477,6 +548,14 @@ def predict(
477548
logger,
478549
)
479550

551+
if self.encoders.encoding_available:
552+
past_covariates, future_covariates = self.encoders.encode_inference(
553+
n=self._get_encoders_n(n),
554+
target=series,
555+
past_covariate=past_covariates,
556+
future_covariate=future_covariates,
557+
)
558+
480559
super().predict(n, series, past_covariates, future_covariates, num_samples)
481560

482561
if series is None:

0 commit comments

Comments
 (0)