Skip to content

Commit c9914e3

Browse files
author
Rijk van der Meulen
committed
2 parents 7b1b7e7 + 8a90725 commit c9914e3

File tree

9 files changed

+208
-37
lines changed

9 files changed

+208
-37
lines changed

darts/dataprocessing/dtw/dtw.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
from typing import Callable, Union
33

44
import numpy as np
5+
import pandas as pd
6+
import xarray as xr
57

68
from darts import TimeSeries
79
from darts.logging import get_logger, raise_if, raise_if_not
10+
from darts.timeseries import DIMS
811

912
from .cost_matrix import CostMatrix
1013
from .window import CRWindow, NoWindow, Window
@@ -203,25 +206,34 @@ def warped(self) -> (TimeSeries, TimeSeries):
203206
Two new TimeSeries instances of the same length, indexed by pd.RangeIndex.
204207
"""
205208

206-
series1 = self.series1
207-
series2 = self.series2
208-
209-
xa1 = series1.data_array(copy=False)
210-
xa2 = series2.data_array(copy=False)
211-
209+
xa1 = self.series1.data_array(copy=False)
210+
xa2 = self.series2.data_array(copy=False)
212211
path = self.path()
213212

214-
warped_series1 = xa1[path[:, 0]]
215-
warped_series2 = xa2[path[:, 1]]
216-
217-
time_dim1 = series1._time_dim
218-
time_dim2 = series2._time_dim
213+
values1, values2 = xa1.values[path[:, 0]], xa2.values[path[:, 1]]
214+
215+
# We set a RangeIndex for both series:
216+
warped_series1 = xr.DataArray(
217+
data=values1,
218+
dims=xa1.dims,
219+
coords={
220+
self.series1._time_dim: pd.RangeIndex(values1.shape[0]),
221+
DIMS[1]: xa1.coords[DIMS[1]],
222+
},
223+
attrs=xa1.attrs,
224+
)
219225

220-
range_index = True
226+
warped_series2 = xr.DataArray(
227+
data=values2,
228+
dims=xa2.dims,
229+
coords={
230+
self.series2._time_dim: pd.RangeIndex(values2.shape[0]),
231+
DIMS[1]: xa2.coords[DIMS[1]],
232+
},
233+
attrs=xa2.attrs,
234+
)
221235

222-
if range_index:
223-
warped_series1 = warped_series1.reset_index(dims_or_levels=time_dim1)
224-
warped_series2 = warped_series2.reset_index(dims_or_levels=time_dim2)
236+
time_dim1, time_dim2 = self.series1._time_dim, self.series2._time_dim
225237

226238
# todo: prevent time information being lost after warping
227239
# Applying time index from series1 to series2 (take_dates = True) is disabled for consistency reasons
+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from darts.utils.torch import MonteCarloDropout
5+
6+
7+
class CustomFeedForwardEncoderLayer(nn.TransformerEncoderLayer):
8+
"""Overwrites the PyTorch TransformerEncoderLayer to use Darts' Position-wise Feed-Forward variants."""
9+
10+
def __init__(self, ffn: nn.Module, dropout: float, *args, **kwargs):
11+
"""
12+
Parameters
13+
----------
14+
ffn
15+
One of Darts' Position-wise Feed-Forward Network variants from darts.models.components.glu_variants
16+
dropout
17+
Fraction of neurons affected by Dropout (default=0.1).
18+
args
19+
positional arguments from torch.nn.TransformerEncoderLayer.
20+
kwargs
21+
keyword arguments from torch.nn.TransformerEncoderLayer. `activation` will have no effect.
22+
"""
23+
super().__init__(*args, **kwargs)
24+
self.ffn = ffn
25+
self.dropout = MonteCarloDropout(dropout)
26+
27+
# overwrite the feed forward block
28+
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
29+
x = self.ffn(x)
30+
return self.dropout(x)
31+
32+
33+
class CustomFeedForwardDecoderLayer(nn.TransformerDecoderLayer):
34+
"""Overwrites the PyTorch TransformerDecoderLayer to use Darts' custom Position Wise Feed Forward Layers."""
35+
36+
def __init__(self, ffn: nn.Module, dropout: float, *args, **kwargs):
37+
"""
38+
Parameters
39+
----------
40+
ffn
41+
One of Darts' Position-wise Feed-Forward Network variants from darts.models.components.glu_variants
42+
dropout
43+
Fraction of neurons affected by Dropout (default=0.1).
44+
args
45+
positional arguments from torch.nn.TransformerEncoderLayer.
46+
kwargs
47+
keyword arguments from torch.nn.TransformerEncoderLayer. `activation` will have no effect.
48+
"""
49+
super().__init__(*args, **kwargs)
50+
self.ffn = ffn
51+
self.dropout = MonteCarloDropout(dropout)
52+
53+
# overwrite the feed forward block
54+
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
55+
x = self.ffn(x)
56+
return self.dropout(x)

darts/models/forecasting/torch_forecasting_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,7 @@ def load_model(path: str) -> "TorchForecastingModel":
13551355
path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
13561356
if os.path.exists(path_ptl_ckpt):
13571357
model.model = model.model.__class__.load_from_checkpoint(path_ptl_ckpt)
1358-
model.trainer = model.model.trainer
1358+
model.trainer = None
13591359

13601360
return model
13611361

darts/models/forecasting/transformer_model.py

+66-8
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
import torch
1010
import torch.nn as nn
1111

12-
from darts.logging import get_logger, raise_if_not
12+
from darts.logging import get_logger, raise_if, raise_if_not
1313
from darts.models.components import glu_variants
1414
from darts.models.components.glu_variants import GLU_FFN
15+
from darts.models.components.transformer import (
16+
CustomFeedForwardDecoderLayer,
17+
CustomFeedForwardEncoderLayer,
18+
)
1519
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
1620
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
1721

@@ -22,6 +26,34 @@
2226
FFN = GLU_FFN + BUILT_IN
2327

2428

29+
def _generate_coder(
30+
d_model, dim_ff, dropout, nhead, num_layers, coder_cls, layer_cls, ffn_cls
31+
):
32+
"""Generates an Encoder or Decoder with one of Darts' Feed-forward Network variants.
33+
Parameters
34+
----------
35+
coder_cls
36+
Either `torch.nn.TransformerEncoder` or `...TransformerDecoder`
37+
layer_cls
38+
Either `darts.models.components.transformer.CustomFeedForwardEncoderLayer` or
39+
`...CustomFeedForwardDecoderLayer`
40+
ffn_cls
41+
One of Darts' Position-wise Feed-Forward Network variants `from darts.models.components.glu_variants`
42+
"""
43+
layer = layer_cls(
44+
ffn=ffn_cls(d_model=d_model, d_ff=dim_ff, dropout=dropout),
45+
dropout=dropout,
46+
d_model=d_model,
47+
nhead=nhead,
48+
dim_feedforward=dim_ff,
49+
)
50+
return coder_cls(
51+
layer,
52+
num_layers=num_layers,
53+
norm=nn.LayerNorm(d_model),
54+
)
55+
56+
2557
# This implementation of positional encoding is taken from the PyTorch documentation:
2658
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
2759
class _PositionalEncoding(nn.Module):
@@ -142,13 +174,39 @@ def __init__(
142174

143175
raise_if_not(activation in FFN, f"'{activation}' is not in {FFN}")
144176
if activation in GLU_FFN:
145-
# use glu variant feedforward layers
146-
self.activation = getattr(glu_variants, activation)(
147-
d_model=d_model, d_ff=dim_feedforward, dropout=dropout
177+
raise_if(
178+
custom_encoder is not None or custom_decoder is not None,
179+
"Cannot use `custom_encoder` or `custom_decoder` along with an `activation` from "
180+
f"{GLU_FFN}",
181+
logger=logger,
182+
)
183+
# use glu variant feed-forward layers
184+
ffn_cls = getattr(glu_variants, activation)
185+
186+
# custom feed-forward layers have activation built-in. reset activation
187+
activation = None
188+
189+
custom_encoder = _generate_coder(
190+
d_model,
191+
dim_feedforward,
192+
dropout,
193+
nhead,
194+
num_encoder_layers,
195+
nn.TransformerEncoder,
196+
CustomFeedForwardEncoderLayer,
197+
ffn_cls,
198+
)
199+
200+
custom_decoder = _generate_coder(
201+
d_model,
202+
dim_feedforward,
203+
dropout,
204+
nhead,
205+
num_decoder_layers,
206+
nn.TransformerDecoder,
207+
CustomFeedForwardDecoderLayer,
208+
ffn_cls,
148209
)
149-
else:
150-
# use nn.Transformer built in feedforward layers
151-
self.activation = activation
152210

153211
# Defining the Transformer module
154212
self.transformer = nn.Transformer(
@@ -158,7 +216,7 @@ def __init__(
158216
num_decoder_layers=num_decoder_layers,
159217
dim_feedforward=dim_feedforward,
160218
dropout=dropout,
161-
activation=self.activation,
219+
activation=activation,
162220
custom_encoder=custom_encoder,
163221
custom_decoder=custom_decoder,
164222
)

darts/tests/models/forecasting/test_transformer_model.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
logger = get_logger(__name__)
1212

1313
try:
14+
import torch.nn as nn
15+
16+
from darts.models.components.transformer import (
17+
CustomFeedForwardDecoderLayer,
18+
CustomFeedForwardEncoderLayer,
19+
)
1420
from darts.models.forecasting.transformer_model import (
1521
TransformerModel,
1622
_TransformerModule,
@@ -118,14 +124,28 @@ def test_activations(self):
118124
)
119125
model1.fit(self.series, epochs=1)
120126

121-
# internal activation function
127+
# internal activation function uses PyTorch TransformerEncoderLayer
122128
model2 = TransformerModel(
123129
input_chunk_length=1, output_chunk_length=1, activation="gelu"
124130
)
125131
model2.fit(self.series, epochs=1)
132+
assert isinstance(
133+
model2.model.transformer.encoder.layers[0], nn.TransformerEncoderLayer
134+
)
135+
assert isinstance(
136+
model2.model.transformer.decoder.layers[0], nn.TransformerDecoderLayer
137+
)
126138

127-
# glue variant FFN
139+
# glue variant FFN uses our custom _FeedForwardEncoderLayer
128140
model3 = TransformerModel(
129141
input_chunk_length=1, output_chunk_length=1, activation="SwiGLU"
130142
)
131143
model3.fit(self.series, epochs=1)
144+
assert isinstance(
145+
model3.model.transformer.encoder.layers[0],
146+
CustomFeedForwardEncoderLayer,
147+
)
148+
assert isinstance(
149+
model3.model.transformer.decoder.layers[0],
150+
CustomFeedForwardDecoderLayer,
151+
)

darts/tests/test_timeseries.py

+25
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,31 @@ def test_integer_indexing(self):
105105
list(indexed_ts.time_index) == list(pd.RangeIndex(2, 7, step=1))
106106
)
107107

108+
def test_univariate_component(self):
109+
series = TimeSeries.from_values(np.array([10, 20, 30])).with_columns_renamed(
110+
"0", "component"
111+
)
112+
mseries = concatenate([series] * 3, axis="component")
113+
mseries = mseries.with_hierarchy(
114+
{"component_1": ["component"], "component_2": ["component"]}
115+
)
116+
117+
static_cov = pd.DataFrame(
118+
{"dim0": [1, 2, 3], "dim1": [-2, -1, 0], "dim2": [0.0, 0.1, 0.2]}
119+
)
120+
121+
mseries = mseries.with_static_covariates(static_cov)
122+
123+
for univ_series in [
124+
mseries.univariate_component(1),
125+
mseries.univariate_component("component_1"),
126+
]:
127+
# hierarchy should be dropped
128+
self.assertIsNone(univ_series.hierarchy)
129+
130+
# only the right static covariate column should be retained
131+
self.assertEqual(univ_series.static_covariates.sum().sum(), 1.1)
132+
108133
def test_column_names(self):
109134
# test the column names resolution
110135
columns_before = [

darts/timeseries.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ def from_dataframe(
649649
else:
650650
raise_if_not(
651651
isinstance(df.index, VALID_INDEX_TYPES),
652-
"If time_col is not specified, the DataFrame must be indexed either with"
652+
"If time_col is not specified, the DataFrame must be indexed either with "
653653
"a DatetimeIndex, or with a RangeIndex.",
654654
logger,
655655
)
@@ -2702,6 +2702,9 @@ def univariate_component(self, index: Union[str, int]) -> "TimeSeries":
27022702
Retrieve one of the components of the series
27032703
and return it as new univariate ``TimeSeries`` instance.
27042704
2705+
This drops the hierarchy (if any), and retains only the relevant static
2706+
covariates column.
2707+
27052708
Parameters
27062709
----------
27072710
index
@@ -2713,11 +2716,8 @@ def univariate_component(self, index: Union[str, int]) -> "TimeSeries":
27132716
TimeSeries
27142717
A new univariate TimeSeries instance.
27152718
"""
2716-
if isinstance(index, int):
2717-
new_xa = self._xa.isel(component=index).expand_dims(DIMS[1], axis=1)
2718-
else:
2719-
new_xa = self._xa.sel(component=index).expand_dims(DIMS[1], axis=1)
2720-
return self.__class__(new_xa)
2719+
2720+
return self[index if isinstance(index, str) else self.components[index]]
27212721

27222722
def add_datetime_attribute(
27232723
self, attribute, one_hot: bool = False, cyclic: bool = False

docs/userguide/timeseries.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ In addition, some models can work on *multiple time series*, meaning that they c
3333

3434
* **Example of a multivariate series:** The blood pressure and heart rate of a single patient over time (one multivariate series with 2 components).
3535

36-
* **Example of multiple series:** The blood pressure and heart rate of multiple patients; potentially measured at different times for different patients (one univariate series per patient).
36+
* **Example of multiple series:** The blood pressure and heart rate of multiple patients; potentially measured at different times for different patients (one multivariate series with 2 components per patient).
3737

3838

3939
### Should I use a multivariate series or multiple series for my problem?
@@ -50,9 +50,9 @@ In Darts, probabilistic forecasts are represented by drawing Monte Carlo samples
5050
## Creating `TimeSeries`
5151
`TimeSeries` objects can be created using factory methods, for example:
5252

53-
* [TimeSeries.from_dataframe()](https://unit8co.github.io/darts/generated_api/darts.timeseries.html#darts.timeseries.TimeSeries.from_dataframe) can create `TimeSeries` from a Pandas Dataframe having one or several columns representing values (several columns would correspond to a multivariate series).
53+
* [TimeSeries.from_dataframe()](https://unit8co.github.io/darts/generated_api/darts.timeseries.html#darts.timeseries.TimeSeries.from_dataframe) can create `TimeSeries` from a Pandas Dataframe having one or several columns representing values (columns correspond to components, and several columns would correspond to a multivariate series).
5454

55-
* [TimeSeries.from_values()](https://unit8co.github.io/darts/generated_api/darts.timeseries.html#darts.timeseries.TimeSeries.from_values) can create `TimeSeries` from a 2-D or 3-D NumPy array. It will generate an integer-based time index (of type `pandas.RangeIndex`). 2-D corresponds to deterministic (potentially multivariate) series, and 3-D to stochastic series.
55+
* [TimeSeries.from_values()](https://unit8co.github.io/darts/generated_api/darts.timeseries.html#darts.timeseries.TimeSeries.from_values) can create `TimeSeries` from a 1-D, 2-D or 3-D NumPy array. It will generate an integer-based time index (of type `pandas.RangeIndex`). 1-D corresponds to univariate deterministic series, 2-D to multivariate deterministic series, and 3-D to multivariate stochastic series.
5656

5757
* [TimeSeries.from_times_and_values()](https://unit8co.github.io/darts/generated_api/darts.timeseries.html#darts.timeseries.TimeSeries.from_times_and_values) is similar to `TimeSeries.from_values()` but also accepts a time index.
5858

@@ -67,7 +67,7 @@ my_multivariate_series = concatenate([series1, series2, ...], axis=1)
6767
produces a multivariate series from some series that share the same time axis.
6868

6969
## Implementation
70-
Behind the scenes, `TimeSeries` is wrapping around a 3-dimensional `xarray.DataArray` object. The dimensions are *(time, component, sample)*, where the size of the *component* dimension is larger than 1 for multivariate series and the size of the *sample* dimension is larger than 1 for stochastic series. The `DataArray` is itself backed by a a 3-dimensional NumPy array, and it has a time index (either `pandas.DatetimeIndex` or `pandas.RangeIndex`) on the *time* dimension and another `pandas.Index` on the *component* (or "columns") dimension. `TimeSeries` is intended to be immutable.
70+
Behind the scenes, `TimeSeries` is wrapping around a 3-dimensional `xarray.DataArray` object. The dimensions are *(time, component, sample)*, where the size of the *component* dimension is larger than 1 for multivariate series and the size of the *sample* dimension is larger than 1 for stochastic series. The `DataArray` is itself backed by a 3-dimensional NumPy array, and it has a time index (either `pandas.DatetimeIndex` or `pandas.RangeIndex`) on the *time* dimension and another `pandas.Index` on the *component* (or "columns") dimension. `TimeSeries` is intended to be immutable and most operations return new `TimeSeries` objects.
7171

7272
## Exporting data from a `TimeSeries`
7373
`TimeSeries` objects offer a few ways to export the data, for example:

requirements/core.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ prophet>=1.1
1212
requests>=2.22.0
1313
scikit-learn>=1.0.1
1414
scipy>=1.3.2
15-
statsforecast>=0.5.2
15+
statsforecast==0.6.0
1616
statsmodels>=0.13.0
1717
tbats>=1.1.0
1818
tqdm>=4.60.0

0 commit comments

Comments
 (0)