Skip to content

Commit 72ea486

Browse files
Fix/dlinear and nlinear use_static_cov. with multivariate series (#2070)
* feat: added test to check that use_static_covariates covers all possible static covariates representations * fix: properly account for the two possible static covariates representation in multivariates series * fix: typo in the warning message * feat: added type hint, reordered docstring to match argument order * feat: added type hint, reordered docstring to match argument order (dlinear) * feat: updated changelog --------- Co-authored-by: Dennis Bader <[email protected]>
1 parent f196665 commit 72ea486

File tree

5 files changed

+77
-37
lines changed

5 files changed

+77
-37
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
3131
- Fixed a bug when loading a `TorchForecastingModel` that was trained with a precision other than `float64`. [#2046](https://github.com/unit8co/darts/pull/2046) by [Freddie Hsin-Fu Huang](https://github.com/Hsinfu).
3232
- Fixed broken links in the `Transfer learning` example notebook with publicly hosted version of the three datasets. [#2067](https://github.com/unit8co/darts/pull/2067) by [Antoine Madrona](https://github.com/madtoinou).
3333
- Fixed a bug when using `NLinearModel` on multivariate series with covariates and `normalize=True`. [#2072](https://github.com/unit8co/darts/pull/2072) by [Antoine Madrona](https://github.com/madtoinou).
34+
- Fixed a bug when using `DLinearModel` and `NLinearModel` on multivariate series with "components-shared" static covariates and `use_static_covariates=True`. [#2070](https://github.com/unit8co/darts/pull/2070) by [Antoine Madrona](https://github.com/madtoinou).
3435

3536
### For developers of the library:
3637

darts/models/forecasting/dlinear.py

+17-19
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ class _DLinearModule(PLMixedCovariatesModule):
6868

6969
def __init__(
7070
self,
71-
input_dim,
72-
output_dim,
73-
future_cov_dim,
74-
static_cov_dim,
75-
nr_params,
76-
shared_weights,
77-
kernel_size,
78-
const_init,
71+
input_dim: int,
72+
output_dim: int,
73+
future_cov_dim: int,
74+
static_cov_dim: int,
75+
nr_params: int,
76+
shared_weights: bool,
77+
kernel_size: int,
78+
const_init: bool,
7979
**kwargs,
8080
):
8181
"""PyTorch module implementing the DLinear architecture.
@@ -89,7 +89,7 @@ def __init__(
8989
future_cov_dim
9090
Number of components in the future covariates
9191
static_cov_dim
92-
Dimensionality of the static covariates
92+
Dimensionality of the static covariates (either component-specific or shared)
9393
nr_params
9494
The number of parameters of the likelihood (or 1 if no likelihood is used).
9595
shared_weights
@@ -113,8 +113,6 @@ def __init__(
113113
Tensor containing the output of the NBEATS module.
114114
"""
115115

116-
# TODO: could we support future covariates with a simple extension?
117-
118116
super().__init__(**kwargs)
119117
self.input_dim = input_dim
120118
self.output_dim = output_dim
@@ -142,9 +140,6 @@ def _create_linear_layer(in_dim, out_dim):
142140
layer_in_dim = self.input_chunk_length * self.input_dim
143141
layer_out_dim = self.output_chunk_length * self.output_dim * self.nr_params
144142

145-
# for static cov, we take the number of components of the target, times static cov dim
146-
layer_in_dim_static_cov = self.output_dim * self.static_cov_dim
147-
148143
self.linear_seasonal = _create_linear_layer(layer_in_dim, layer_out_dim)
149144
self.linear_trend = _create_linear_layer(layer_in_dim, layer_out_dim)
150145

@@ -155,7 +150,7 @@ def _create_linear_layer(in_dim, out_dim):
155150
)
156151
if self.static_cov_dim != 0:
157152
self.linear_static_cov = _create_linear_layer(
158-
layer_in_dim_static_cov, layer_out_dim
153+
self.static_cov_dim, layer_out_dim
159154
)
160155

161156
@io_processor
@@ -477,8 +472,8 @@ def _create_model(
477472
raise_if(
478473
self.shared_weights
479474
and (train_sample[1] is not None or train_sample[2] is not None),
480-
"Covariates have been provided, but the model has been built with shared_weights=True."
481-
+ "Please set shared_weights=False to use covariates.",
475+
"Covariates have been provided, but the model has been built with shared_weights=True. "
476+
"Please set shared_weights=False to use covariates.",
482477
)
483478

484479
input_dim = train_sample[0].shape[1] + sum(
@@ -488,8 +483,11 @@ def _create_model(
488483
)
489484
future_cov_dim = train_sample[3].shape[1] if train_sample[3] is not None else 0
490485

491-
# dimension is (component, static_dim), we extract static_dim
492-
static_cov_dim = train_sample[4].shape[1] if train_sample[4] is not None else 0
486+
if train_sample[4] is None:
487+
static_cov_dim = 0
488+
else:
489+
# account for component-specific or shared static covariates representation
490+
static_cov_dim = train_sample[4].shape[0] * train_sample[4].shape[1]
493491

494492
output_dim = train_sample[-1].shape[1]
495493
nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

darts/models/forecasting/nlinear.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ class _NLinearModule(PLMixedCovariatesModule):
2323

2424
def __init__(
2525
self,
26-
input_dim,
27-
output_dim,
28-
future_cov_dim,
29-
static_cov_dim,
30-
nr_params,
31-
shared_weights,
32-
const_init,
33-
normalize,
26+
input_dim: int,
27+
output_dim: int,
28+
future_cov_dim: int,
29+
static_cov_dim: int,
30+
nr_params: int,
31+
shared_weights: bool,
32+
const_init: bool,
33+
normalize: bool,
3434
**kwargs,
3535
):
3636
"""PyTorch module implementing the N-HiTS architecture.
@@ -44,16 +44,16 @@ def __init__(
4444
future_cov_dim
4545
Number of components in the future covariates
4646
static_cov_dim
47-
Dimensionality of the static covariates
47+
Dimensionality of the static covariates (either component-specific or shared)
4848
nr_params
4949
The number of parameters of the likelihood (or 1 if no likelihood is used).
5050
shared_weights
5151
Whether to use shared weights for the components of the series.
5252
** Ignores covariates when True. **
53-
normalize
54-
Whether to apply the "normalization" described in the paper.
5553
const_init
5654
Whether to initialize the weights to 1/in_len
55+
normalize
56+
Whether to apply the "normalization" described in the paper.
5757
5858
**kwargs
5959
all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.
@@ -94,9 +94,6 @@ def _create_linear_layer(in_dim, out_dim):
9494
layer_in_dim = self.input_chunk_length * self.input_dim
9595
layer_out_dim = self.output_chunk_length * self.output_dim * self.nr_params
9696

97-
# for static cov, we take the number of components of the target, times static cov dim
98-
layer_in_dim_static_cov = self.output_dim * self.static_cov_dim
99-
10097
self.layer = _create_linear_layer(layer_in_dim, layer_out_dim)
10198

10299
if self.future_cov_dim != 0:
@@ -106,7 +103,7 @@ def _create_linear_layer(in_dim, out_dim):
106103
)
107104
if self.static_cov_dim != 0:
108105
self.linear_static_cov = _create_linear_layer(
109-
layer_in_dim_static_cov, layer_out_dim
106+
self.static_cov_dim, layer_out_dim
110107
)
111108

112109
@io_processor
@@ -438,8 +435,11 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
438435
)
439436
future_cov_dim = train_sample[3].shape[1] if train_sample[3] is not None else 0
440437

441-
# dimension is (component, static_dim), we extract static_dim
442-
static_cov_dim = train_sample[4].shape[1] if train_sample[4] is not None else 0
438+
if train_sample[4] is None:
439+
static_cov_dim = 0
440+
else:
441+
# account for component-specific or shared static covariates representation
442+
static_cov_dim = train_sample[4].shape[0] * train_sample[4].shape[1]
443443

444444
output_dim = train_sample[-1].shape[1]
445445

darts/tests/models/forecasting/test_global_forecasting_models.py

+41
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from copy import deepcopy
3+
from itertools import product
34
from unittest.mock import ANY, patch
45

56
import numpy as np
@@ -206,6 +207,15 @@ class TestGlobalForecastingModels:
206207
target = sine_1_ts + sine_2_ts + linear_ts + sine_3_ts
207208
target_past, target_future = target.split_after(split_ratio)
208209

210+
# various ts with different static covariates representations
211+
ts_w_static_cov = tg.linear_timeseries(length=80).with_static_covariates(
212+
pd.Series([1, 2])
213+
)
214+
ts_shared_static_cov = ts_w_static_cov.stack(tg.sine_timeseries(length=80))
215+
ts_comps_static_cov = ts_shared_static_cov.with_static_covariates(
216+
pd.DataFrame([[0, 1], [2, 3]], columns=["st1", "st2"])
217+
)
218+
209219
@pytest.mark.parametrize("config", models_cls_kwargs_errs)
210220
def test_save_model_parameters(self, config):
211221
# model creation parameters were saved before. check if re-created model has same params as original
@@ -450,6 +460,37 @@ def test_future_covariates(self):
450460
with pytest.raises(ValueError):
451461
model.predict(n=161, future_covariates=self.covariates)
452462

463+
@pytest.mark.parametrize(
464+
"model_cls,ts",
465+
product(
466+
[TFTModel, DLinearModel, NLinearModel, TiDEModel],
467+
[ts_w_static_cov, ts_shared_static_cov, ts_comps_static_cov],
468+
),
469+
)
470+
def test_use_static_covariates(self, model_cls, ts):
471+
"""
472+
Check that both static covariates representations are supported (component-specific and shared)
473+
for both uni- and multivariate series when fitting the model.
474+
Also check that the static covariates are present in the forecasted series
475+
"""
476+
model = model_cls(
477+
input_chunk_length=IN_LEN,
478+
output_chunk_length=OUT_LEN,
479+
random_state=0,
480+
use_static_covariates=True,
481+
n_epochs=1,
482+
**tfm_kwargs,
483+
)
484+
# must provide mandatory future_covariates to TFTModel
485+
model.fit(
486+
series=ts,
487+
future_covariates=self.sine_1_ts
488+
if model.supports_future_covariates
489+
else None,
490+
)
491+
pred = model.predict(OUT_LEN)
492+
assert pred.static_covariates.equals(ts.static_covariates)
493+
453494
def test_batch_predictions(self):
454495
# predicting multiple time series at once needs to work for arbitrary batch sizes
455496
# univariate case

darts/tests/models/forecasting/test_torch_forecasting_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767
TORCH_AVAILABLE = True
6868
except ImportError:
69-
logger.warning("Torch not available. RNN tests will be skipped.")
69+
logger.warning("Torch not available. Tests will be skipped.")
7070
TORCH_AVAILABLE = False
7171

7272
if TORCH_AVAILABLE:

0 commit comments

Comments
 (0)