Skip to content

Commit fca3993

Browse files
authored
add feature projection for past covariates to TiDEModel (#1993)
1 parent a9b6fbc commit fca3993

File tree

3 files changed

+167
-45
lines changed

3 files changed

+167
-45
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1919
- Added short examples in the docstring of all the models, including covariates usage and some model-specific parameters. [#1956](https://github.com/unit8co/darts/pull/1956) by [Antoine Madrona](https://github.com/madtoinou).
2020
- All `RegressionModel`s now support component/column-specific lags for target, past, and future covariates series. [#1962](https://github.com/unit8co/darts/pull/1962) by [Antoine Madrona](https://github.com/madtoinou).
2121
- Added method `TimeSeries.cumsum()` to get the cumulative sum of the time series along the time axis. [#1988](https://github.com/unit8co/darts/pull/1988) by [Eliot Zubkoff](https://github.com/Eliotdoesprogramming).
22+
- 🔴 Added past covariates feature projection to `TiDEModel` with parameter `temporal_width_past`, following the advice of the model architect. Parameter `temporal_width` was renamed to `temporal_width_future`. Additionally, added the option to bypass the feature projection by setting `temporal_width_past=0` and/or `temporal_width_future=0`. [#1993](https://github.com/unit8co/darts/pull/1993) by [Dennis Bader](https://github.com/dennisbader).
2223

2324
**Fixed**
2425
- Fixed a bug in `TimeSeries.from_dataframe()` when using a pandas.DataFrame with `df.columns.name != None`. [#1938](https://github.com/unit8co/darts/pull/1938) by [Antoine Madrona](https://github.com/madtoinou).

darts/models/forecasting/tide_model.py

+107-43
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
import torch.nn as nn
1010

11-
from darts.logging import get_logger
11+
from darts.logging import get_logger, raise_log
1212
from darts.models.forecasting.pl_forecasting_module import (
1313
PLMixedCovariatesModule,
1414
io_processor,
@@ -77,7 +77,8 @@ def __init__(
7777
decoder_output_dim: int,
7878
hidden_size: int,
7979
temporal_decoder_hidden: int,
80-
temporal_width: int,
80+
temporal_width_past: int,
81+
temporal_width_future: int,
8182
use_layer_norm: bool,
8283
dropout: float,
8384
**kwargs,
@@ -106,7 +107,9 @@ def __init__(
106107
The width of the hidden layers in the encoder/decoder Residual Blocks.
107108
temporal_decoder_hidden
108109
The width of the hidden layers in the temporal decoder.
109-
temporal_width
110+
temporal_width_past
111+
The width of the past covariate embedding space.
112+
temporal_width_future
110113
The width of the future covariate embedding space.
111114
use_layer_norm
112115
Whether to use layer normalization in the Residual Blocks.
@@ -131,6 +134,7 @@ def __init__(
131134

132135
self.input_dim = input_dim
133136
self.output_dim = output_dim
137+
self.past_cov_dim = input_dim - output_dim - future_cov_dim
134138
self.future_cov_dim = future_cov_dim
135139
self.static_cov_dim = static_cov_dim
136140
self.nr_params = nr_params
@@ -141,28 +145,52 @@ def __init__(
141145
self.temporal_decoder_hidden = temporal_decoder_hidden
142146
self.use_layer_norm = use_layer_norm
143147
self.dropout = dropout
144-
self.temporal_width = temporal_width
148+
self.temporal_width_past = temporal_width_past
149+
self.temporal_width_future = temporal_width_future
150+
151+
# past covariates handling: either feature projection, raw features, or no features
152+
self.past_cov_projection = None
153+
if self.past_cov_dim and temporal_width_past:
154+
# residual block for past covariates feature projection
155+
self.past_cov_projection = _ResidualBlock(
156+
input_dim=self.past_cov_dim,
157+
output_dim=temporal_width_past,
158+
hidden_size=hidden_size,
159+
use_layer_norm=use_layer_norm,
160+
dropout=dropout,
161+
)
162+
past_covariates_flat_dim = self.input_chunk_length * temporal_width_past
163+
elif self.past_cov_dim:
164+
# skip projection and use raw features
165+
past_covariates_flat_dim = self.input_chunk_length * self.past_cov_dim
166+
else:
167+
past_covariates_flat_dim = 0
145168

146-
# residual block for input feature projection
147-
# this is only needed when covariates are used
148-
if future_cov_dim:
149-
self.feature_projection = _ResidualBlock(
169+
# future covariates handling: either feature projection, raw features, or no features
170+
self.future_cov_projection = None
171+
if future_cov_dim and self.temporal_width_future:
172+
# residual block for future covariates feature projection
173+
self.future_cov_projection = _ResidualBlock(
150174
input_dim=future_cov_dim,
151-
output_dim=temporal_width,
175+
output_dim=temporal_width_future,
152176
hidden_size=hidden_size,
153177
use_layer_norm=use_layer_norm,
154178
dropout=dropout,
155179
)
180+
historical_future_covariates_flat_dim = (
181+
self.input_chunk_length + self.output_chunk_length
182+
) * temporal_width_future
183+
elif future_cov_dim:
184+
# skip projection and use raw features
185+
historical_future_covariates_flat_dim = (
186+
self.input_chunk_length + self.output_chunk_length
187+
) * future_cov_dim
156188
else:
157-
self.feature_projection = None
189+
historical_future_covariates_flat_dim = 0
158190

159-
# original paper doesn't specify how to use past covariates
160-
# we assume that they pass them raw to the encoder
161-
historical_future_covariates_flat_dim = (
162-
self.input_chunk_length + self.output_chunk_length
163-
) * (self.temporal_width if future_cov_dim > 0 else 0)
164191
encoder_dim = (
165-
self.input_chunk_length * (input_dim - future_cov_dim)
192+
self.input_chunk_length * output_dim
193+
+ past_covariates_flat_dim
166194
+ historical_future_covariates_flat_dim
167195
+ static_cov_dim
168196
)
@@ -210,9 +238,14 @@ def __init__(
210238
),
211239
)
212240

241+
decoder_input_dim = decoder_output_dim * self.nr_params
242+
if temporal_width_future and future_cov_dim:
243+
decoder_input_dim += temporal_width_future
244+
elif future_cov_dim:
245+
decoder_input_dim += future_cov_dim
246+
213247
self.temporal_decoder = _ResidualBlock(
214-
input_dim=decoder_output_dim * self.nr_params
215-
+ (temporal_width if future_cov_dim > 0 else 0),
248+
input_dim=decoder_input_dim,
216249
output_dim=output_dim * self.nr_params,
217250
hidden_size=temporal_decoder_hidden,
218251
use_layer_norm=use_layer_norm,
@@ -246,44 +279,49 @@ def forward(
246279

247280
x_lookback = x[:, :, : self.output_dim]
248281

249-
# future covariates need to be extracted from x and stacked with historical future covariates
250-
if self.future_cov_dim > 0:
251-
x_dynamic_covariates = torch.cat(
282+
# future covariates: feature projection or raw features
283+
# historical future covariates need to be extracted from x and stacked with part of future covariates
284+
if self.future_cov_dim:
285+
x_dynamic_future_covariates = torch.cat(
252286
[
253-
x_future_covariates,
254287
x[
255288
:,
256289
:,
257290
None if self.future_cov_dim == 0 else -self.future_cov_dim :,
258291
],
292+
x_future_covariates,
259293
],
260294
dim=1,
261295
)
262-
263-
# project input features across all input time steps
264-
x_dynamic_covariates_proj = self.feature_projection(x_dynamic_covariates)
265-
296+
if self.temporal_width_future:
297+
# project input features across all input and output time steps
298+
x_dynamic_future_covariates = self.future_cov_projection(
299+
x_dynamic_future_covariates
300+
)
266301
else:
267-
x_dynamic_covariates = None
268-
x_dynamic_covariates_proj = None
302+
x_dynamic_future_covariates = None
269303

270-
# extract past covariates, if they exist
271-
if self.input_dim - self.output_dim - self.future_cov_dim > 0:
272-
x_past_covariates = x[
304+
# past covariates: feature projection or raw features
305+
# the past covariates are embedded in `x`
306+
if self.past_cov_dim:
307+
x_dynamic_past_covariates = x[
273308
:,
274309
:,
275-
self.output_dim : None
276-
if self.future_cov_dim == 0
277-
else -self.future_cov_dim :,
310+
self.output_dim : self.output_dim + self.past_cov_dim,
278311
]
312+
if self.temporal_width_past:
313+
# project input features across all input time steps
314+
x_dynamic_past_covariates = self.past_cov_projection(
315+
x_dynamic_past_covariates
316+
)
279317
else:
280-
x_past_covariates = None
318+
x_dynamic_past_covariates = None
281319

282320
# setup input to encoder
283321
encoded = [
284322
x_lookback,
285-
x_past_covariates,
286-
x_dynamic_covariates_proj,
323+
x_dynamic_past_covariates,
324+
x_dynamic_future_covariates,
287325
x_static_covariates,
288326
]
289327
encoded = [t.flatten(start_dim=1) for t in encoded if t is not None]
@@ -299,7 +337,7 @@ def forward(
299337
# stack and temporally decode with future covariate last output steps
300338
temporal_decoder_input = [
301339
decoded,
302-
x_dynamic_covariates_proj[:, -self.output_chunk_length :, :]
340+
x_dynamic_future_covariates[:, -self.output_chunk_length :, :]
303341
if self.future_cov_dim > 0
304342
else None,
305343
]
@@ -331,7 +369,8 @@ def __init__(
331369
num_decoder_layers: int = 1,
332370
decoder_output_dim: int = 16,
333371
hidden_size: int = 128,
334-
temporal_width: int = 4,
372+
temporal_width_past: int = 4,
373+
temporal_width_future: int = 4,
335374
temporal_decoder_hidden: int = 32,
336375
use_layer_norm: bool = False,
337376
dropout: float = 0.1,
@@ -369,8 +408,12 @@ def __init__(
369408
The dimensionality of the output of the decoder.
370409
hidden_size
371410
The width of the layers in the residual blocks of the encoder and decoder.
372-
temporal_width
373-
The width of the layers in the future covariate projection residual block.
411+
temporal_width_past
412+
The width of the output layer in the past covariate projection residual block. If `0`,
413+
will bypass feature projection and use the raw feature data.
414+
temporal_width_future
415+
The width of the output layer in the future covariate projection residual block. If `0`,
416+
will bypass feature projection and use the raw feature data.
374417
temporal_decoder_hidden
375418
The width of the layers in the temporal decoder.
376419
use_layer_norm
@@ -550,6 +593,13 @@ def encode_year(idx):
550593
`TiDE example notebook <https://unit8co.github.io/darts/examples/18-TiDE-examples.html>`_ presents
551594
techniques that can be used to improve the forecasts quality compared to this simple usage example.
552595
"""
596+
if temporal_width_past < 0 or temporal_width_future < 0:
597+
raise_log(
598+
ValueError(
599+
"`temporal_width_past` and `temporal_width_future` must be >= 0."
600+
),
601+
logger=logger,
602+
)
553603
super().__init__(**self._extract_torch_model_params(**self.model_params))
554604

555605
# extract pytorch lightning module kwargs
@@ -559,7 +609,8 @@ def encode_year(idx):
559609
self.num_decoder_layers = num_decoder_layers
560610
self.decoder_output_dim = decoder_output_dim
561611
self.hidden_size = hidden_size
562-
self.temporal_width = temporal_width
612+
self.temporal_width_past = temporal_width_past
613+
self.temporal_width_future = temporal_width_future
563614
self.temporal_decoder_hidden = temporal_decoder_hidden
564615

565616
self._considers_static_covariates = use_static_covariates
@@ -603,6 +654,18 @@ def _create_model(
603654

604655
nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters
605656

657+
past_cov_dim = input_dim - output_dim - future_cov_dim
658+
if past_cov_dim and self.temporal_width_past >= past_cov_dim:
659+
logger.warning(
660+
f"number of `past_covariates` features is <= `temporal_width_past`, leading to feature expansion."
661+
f"number of covariates: {past_cov_dim}, `temporal_width_past={self.temporal_width_past}`."
662+
)
663+
if future_cov_dim and self.temporal_width_future >= future_cov_dim:
664+
logger.warning(
665+
f"number of `future_covariates` features is <= `temporal_width_future`, leading to feature expansion."
666+
f"number of covariates: {future_cov_dim}, `temporal_width_future={self.temporal_width_future}`."
667+
)
668+
606669
return _TideModule(
607670
input_dim=input_dim,
608671
output_dim=output_dim,
@@ -613,7 +676,8 @@ def _create_model(
613676
num_decoder_layers=self.num_decoder_layers,
614677
decoder_output_dim=self.decoder_output_dim,
615678
hidden_size=self.hidden_size,
616-
temporal_width=self.temporal_width,
679+
temporal_width_past=self.temporal_width_past,
680+
temporal_width_future=self.temporal_width_future,
617681
temporal_decoder_hidden=self.temporal_decoder_hidden,
618682
use_layer_norm=self.use_layer_norm,
619683
dropout=self.dropout,

darts/tests/models/forecasting/test_tide_model.py

+59-2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,54 @@ def test_future_and_past_covariate_handling(self):
118118
)
119119
model.fit(ts_time_index, verbose=False, epochs=1)
120120

121+
model = TiDEModel(
122+
input_chunk_length=1,
123+
output_chunk_length=1,
124+
add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
125+
**tfm_kwargs
126+
)
127+
model.fit(ts_time_index, verbose=False, epochs=1)
128+
129+
@pytest.mark.parametrize("temporal_widths", [(-1, 1), (1, -1)])
130+
def test_failing_future_and_past_temporal_widths(self, temporal_widths):
131+
# invalid temporal widths
132+
with pytest.raises(ValueError):
133+
TiDEModel(
134+
input_chunk_length=1,
135+
output_chunk_length=1,
136+
temporal_width_past=temporal_widths[0],
137+
temporal_width_future=temporal_widths[1],
138+
**tfm_kwargs
139+
)
140+
141+
@pytest.mark.parametrize(
142+
"temporal_widths",
143+
[
144+
(2, 2), # feature projection to same amount of features
145+
(1, 2), # past: feature reduction, future: same amount of features
146+
(2, 1), # past: same amount of features, future: feature reduction
147+
(3, 3), # feature expansion
148+
(0, 2), # bypass past feature projection
149+
(2, 0), # bypass future feature projection
150+
(0, 0), # bypass all feature projection
151+
],
152+
)
153+
def test_future_and_past_temporal_widths(self, temporal_widths):
154+
ts_time_index = tg.sine_timeseries(length=2, freq="h")
155+
156+
# build the model with the parametrized past and future temporal widths
157+
model = TiDEModel(
158+
input_chunk_length=1,
159+
output_chunk_length=1,
160+
temporal_width_past=temporal_widths[0],
161+
temporal_width_future=temporal_widths[1],
162+
add_encoders={"cyclic": {"future": "hour", "past": "hour"}},
163+
**tfm_kwargs
164+
)
165+
model.fit(ts_time_index, verbose=False, epochs=1)
166+
assert model.model.temporal_width_past == temporal_widths[0]
167+
assert model.model.temporal_width_future == temporal_widths[1]
168+
121169
def test_past_covariate_handling(self):
122170
ts_time_index = tg.sine_timeseries(length=2, freq="h")
123171

@@ -142,7 +190,12 @@ def test_future_and_past_covariate_as_timeseries_handling(self):
142190
use_reversible_instance_norm=enable_rin,
143191
**tfm_kwargs
144192
)
145-
model.fit(ts_time_index, ts_time_index, verbose=False, epochs=1)
193+
model.fit(
194+
ts_time_index,
195+
past_covariates=ts_time_index,
196+
verbose=False,
197+
epochs=1,
198+
)
146199

147200
# test with past_covariates and future_covariates timeseries
148201
model = TiDEModel(
@@ -153,7 +206,11 @@ def test_future_and_past_covariate_as_timeseries_handling(self):
153206
**tfm_kwargs
154207
)
155208
model.fit(
156-
ts_time_index, ts_time_index, ts_time_index, verbose=False, epochs=1
209+
ts_time_index,
210+
past_covariates=ts_time_index,
211+
future_covariates=ts_time_index,
212+
verbose=False,
213+
epochs=1,
157214
)
158215

159216
def test_static_covariates_support(self):

0 commit comments

Comments
 (0)