8
8
import torch
9
9
import torch .nn as nn
10
10
11
- from darts .logging import get_logger
11
+ from darts .logging import get_logger , raise_log
12
12
from darts .models .forecasting .pl_forecasting_module import (
13
13
PLMixedCovariatesModule ,
14
14
io_processor ,
@@ -77,7 +77,8 @@ def __init__(
77
77
decoder_output_dim : int ,
78
78
hidden_size : int ,
79
79
temporal_decoder_hidden : int ,
80
- temporal_width : int ,
80
+ temporal_width_past : int ,
81
+ temporal_width_future : int ,
81
82
use_layer_norm : bool ,
82
83
dropout : float ,
83
84
** kwargs ,
@@ -106,7 +107,9 @@ def __init__(
106
107
The width of the hidden layers in the encoder/decoder Residual Blocks.
107
108
temporal_decoder_hidden
108
109
The width of the hidden layers in the temporal decoder.
109
- temporal_width
110
+ temporal_width_past
111
+ The width of the past covariate embedding space.
112
+ temporal_width_future
110
113
The width of the future covariate embedding space.
111
114
use_layer_norm
112
115
Whether to use layer normalization in the Residual Blocks.
@@ -131,6 +134,7 @@ def __init__(
131
134
132
135
self .input_dim = input_dim
133
136
self .output_dim = output_dim
137
+ self .past_cov_dim = input_dim - output_dim - future_cov_dim
134
138
self .future_cov_dim = future_cov_dim
135
139
self .static_cov_dim = static_cov_dim
136
140
self .nr_params = nr_params
@@ -141,28 +145,52 @@ def __init__(
141
145
self .temporal_decoder_hidden = temporal_decoder_hidden
142
146
self .use_layer_norm = use_layer_norm
143
147
self .dropout = dropout
144
- self .temporal_width = temporal_width
148
+ self .temporal_width_past = temporal_width_past
149
+ self .temporal_width_future = temporal_width_future
150
+
151
+ # past covariates handling: either feature projection, raw features, or no features
152
+ self .past_cov_projection = None
153
+ if self .past_cov_dim and temporal_width_past :
154
+ # residual block for past covariates feature projection
155
+ self .past_cov_projection = _ResidualBlock (
156
+ input_dim = self .past_cov_dim ,
157
+ output_dim = temporal_width_past ,
158
+ hidden_size = hidden_size ,
159
+ use_layer_norm = use_layer_norm ,
160
+ dropout = dropout ,
161
+ )
162
+ past_covariates_flat_dim = self .input_chunk_length * temporal_width_past
163
+ elif self .past_cov_dim :
164
+ # skip projection and use raw features
165
+ past_covariates_flat_dim = self .input_chunk_length * self .past_cov_dim
166
+ else :
167
+ past_covariates_flat_dim = 0
145
168
146
- # residual block for input feature projection
147
- # this is only needed when covariates are used
148
- if future_cov_dim :
149
- self .feature_projection = _ResidualBlock (
169
+ # future covariates handling: either feature projection, raw features, or no features
170
+ self .future_cov_projection = None
171
+ if future_cov_dim and self .temporal_width_future :
172
+ # residual block for future covariates feature projection
173
+ self .future_cov_projection = _ResidualBlock (
150
174
input_dim = future_cov_dim ,
151
- output_dim = temporal_width ,
175
+ output_dim = temporal_width_future ,
152
176
hidden_size = hidden_size ,
153
177
use_layer_norm = use_layer_norm ,
154
178
dropout = dropout ,
155
179
)
180
+ historical_future_covariates_flat_dim = (
181
+ self .input_chunk_length + self .output_chunk_length
182
+ ) * temporal_width_future
183
+ elif future_cov_dim :
184
+ # skip projection and use raw features
185
+ historical_future_covariates_flat_dim = (
186
+ self .input_chunk_length + self .output_chunk_length
187
+ ) * future_cov_dim
156
188
else :
157
- self . feature_projection = None
189
+ historical_future_covariates_flat_dim = 0
158
190
159
- # original paper doesn't specify how to use past covariates
160
- # we assume that they pass them raw to the encoder
161
- historical_future_covariates_flat_dim = (
162
- self .input_chunk_length + self .output_chunk_length
163
- ) * (self .temporal_width if future_cov_dim > 0 else 0 )
164
191
encoder_dim = (
165
- self .input_chunk_length * (input_dim - future_cov_dim )
192
+ self .input_chunk_length * output_dim
193
+ + past_covariates_flat_dim
166
194
+ historical_future_covariates_flat_dim
167
195
+ static_cov_dim
168
196
)
@@ -210,9 +238,14 @@ def __init__(
210
238
),
211
239
)
212
240
241
+ decoder_input_dim = decoder_output_dim * self .nr_params
242
+ if temporal_width_future and future_cov_dim :
243
+ decoder_input_dim += temporal_width_future
244
+ elif future_cov_dim :
245
+ decoder_input_dim += future_cov_dim
246
+
213
247
self .temporal_decoder = _ResidualBlock (
214
- input_dim = decoder_output_dim * self .nr_params
215
- + (temporal_width if future_cov_dim > 0 else 0 ),
248
+ input_dim = decoder_input_dim ,
216
249
output_dim = output_dim * self .nr_params ,
217
250
hidden_size = temporal_decoder_hidden ,
218
251
use_layer_norm = use_layer_norm ,
@@ -246,44 +279,49 @@ def forward(
246
279
247
280
x_lookback = x [:, :, : self .output_dim ]
248
281
249
- # future covariates need to be extracted from x and stacked with historical future covariates
250
- if self .future_cov_dim > 0 :
251
- x_dynamic_covariates = torch .cat (
282
+ # future covariates: feature projection or raw features
283
+ # the historical part of the future covariates is extracted from `x` and stacked along the time axis with `x_future_covariates`
284
+ if self .future_cov_dim :
285
+ x_dynamic_future_covariates = torch .cat (
252
286
[
253
- x_future_covariates ,
254
287
x [
255
288
:,
256
289
:,
257
290
None if self .future_cov_dim == 0 else - self .future_cov_dim :,
258
291
],
292
+ x_future_covariates ,
259
293
],
260
294
dim = 1 ,
261
295
)
262
-
263
- # project input features across all input time steps
264
- x_dynamic_covariates_proj = self .feature_projection (x_dynamic_covariates )
265
-
296
+ if self .temporal_width_future :
297
+ # project input features across all input and output time steps
298
+ x_dynamic_future_covariates = self .future_cov_projection (
299
+ x_dynamic_future_covariates
300
+ )
266
301
else :
267
- x_dynamic_covariates = None
268
- x_dynamic_covariates_proj = None
302
+ x_dynamic_future_covariates = None
269
303
270
- # extract past covariates, if they exist
271
- if self .input_dim - self .output_dim - self .future_cov_dim > 0 :
272
- x_past_covariates = x [
304
+ # past covariates: feature projection or raw features
305
+ # the past covariates are embedded in `x`
306
+ if self .past_cov_dim :
307
+ x_dynamic_past_covariates = x [
273
308
:,
274
309
:,
275
- self .output_dim : None
276
- if self .future_cov_dim == 0
277
- else - self .future_cov_dim :,
310
+ self .output_dim : self .output_dim + self .past_cov_dim ,
278
311
]
312
+ if self .temporal_width_past :
313
+ # project input features across all input time steps
314
+ x_dynamic_past_covariates = self .past_cov_projection (
315
+ x_dynamic_past_covariates
316
+ )
279
317
else :
280
- x_past_covariates = None
318
+ x_dynamic_past_covariates = None
281
319
282
320
# setup input to encoder
283
321
encoded = [
284
322
x_lookback ,
285
- x_past_covariates ,
286
- x_dynamic_covariates_proj ,
323
+ x_dynamic_past_covariates ,
324
+ x_dynamic_future_covariates ,
287
325
x_static_covariates ,
288
326
]
289
327
encoded = [t .flatten (start_dim = 1 ) for t in encoded if t is not None ]
@@ -299,7 +337,7 @@ def forward(
299
337
# stack the decoded output with the future covariates of the last `output_chunk_length` steps, then temporally decode
300
338
temporal_decoder_input = [
301
339
decoded ,
302
- x_dynamic_covariates_proj [:, - self .output_chunk_length :, :]
340
+ x_dynamic_future_covariates [:, - self .output_chunk_length :, :]
303
341
if self .future_cov_dim > 0
304
342
else None ,
305
343
]
@@ -331,7 +369,8 @@ def __init__(
331
369
num_decoder_layers : int = 1 ,
332
370
decoder_output_dim : int = 16 ,
333
371
hidden_size : int = 128 ,
334
- temporal_width : int = 4 ,
372
+ temporal_width_past : int = 4 ,
373
+ temporal_width_future : int = 4 ,
335
374
temporal_decoder_hidden : int = 32 ,
336
375
use_layer_norm : bool = False ,
337
376
dropout : float = 0.1 ,
@@ -369,8 +408,12 @@ def __init__(
369
408
The dimensionality of the output of the decoder.
370
409
hidden_size
371
410
The width of the layers in the residual blocks of the encoder and decoder.
372
- temporal_width
373
- The width of the layers in the future covariate projection residual block.
411
+ temporal_width_past
412
+ The width of the layers in the past covariate projection residual block. If `0`,
413
+ will bypass feature projection and use the raw feature data.
414
+ temporal_width_future
415
+ The width of the layers in the future covariate projection residual block. If `0`,
416
+ will bypass feature projection and use the raw feature data.
374
417
temporal_decoder_hidden
375
418
The width of the layers in the temporal decoder.
376
419
use_layer_norm
@@ -550,6 +593,13 @@ def encode_year(idx):
550
593
`TiDE example notebook <https://unit8co.github.io/darts/examples/18-TiDE-examples.html>`_ presents
551
594
techniques that can be used to improve the forecasts quality compared to this simple usage example.
552
595
"""
596
+ if temporal_width_past < 0 or temporal_width_future < 0 :
597
+ raise_log (
598
+ ValueError (
599
+ "`temporal_width_past` and `temporal_width_future` must be >= 0."
600
+ ),
601
+ logger = logger ,
602
+ )
553
603
super ().__init__ (** self ._extract_torch_model_params (** self .model_params ))
554
604
555
605
# extract pytorch lightning module kwargs
@@ -559,7 +609,8 @@ def encode_year(idx):
559
609
self .num_decoder_layers = num_decoder_layers
560
610
self .decoder_output_dim = decoder_output_dim
561
611
self .hidden_size = hidden_size
562
- self .temporal_width = temporal_width
612
+ self .temporal_width_past = temporal_width_past
613
+ self .temporal_width_future = temporal_width_future
563
614
self .temporal_decoder_hidden = temporal_decoder_hidden
564
615
565
616
self ._considers_static_covariates = use_static_covariates
@@ -603,6 +654,18 @@ def _create_model(
603
654
604
655
nr_params = 1 if self .likelihood is None else self .likelihood .num_parameters
605
656
657
+ past_cov_dim = input_dim - output_dim - future_cov_dim
658
+ if past_cov_dim and self .temporal_width_past >= past_cov_dim :
659
+ logger .warning (
660
+ f"number of `past_covariates` features is <= `temporal_width_past`, leading to feature expansion."
661
+ f"number of covariates: { past_cov_dim } , `temporal_width_past={ self .temporal_width_past } `."
662
+ )
663
+ if future_cov_dim and self .temporal_width_future >= future_cov_dim :
664
+ logger .warning (
665
+ f"number of `future_covariates` features is <= `temporal_width_future`, leading to feature expansion."
666
+ f"number of covariates: { future_cov_dim } , `temporal_width_future={ self .temporal_width_future } `."
667
+ )
668
+
606
669
return _TideModule (
607
670
input_dim = input_dim ,
608
671
output_dim = output_dim ,
@@ -613,7 +676,8 @@ def _create_model(
613
676
num_decoder_layers = self .num_decoder_layers ,
614
677
decoder_output_dim = self .decoder_output_dim ,
615
678
hidden_size = self .hidden_size ,
616
- temporal_width = self .temporal_width ,
679
+ temporal_width_past = self .temporal_width_past ,
680
+ temporal_width_future = self .temporal_width_future ,
617
681
temporal_decoder_hidden = self .temporal_decoder_hidden ,
618
682
use_layer_norm = self .use_layer_norm ,
619
683
dropout = self .dropout ,
0 commit comments