
Commit 934b204

Merge branch 'master' into feat/apply-arima-to-new-ts
2 parents f4dda3d + eb18103

File tree

5 files changed: +146 -12

darts/models/components/transformer.py

+56
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+from darts.utils.torch import MonteCarloDropout
+
+
+class CustomFeedForwardEncoderLayer(nn.TransformerEncoderLayer):
+    """Overwrites the PyTorch TransformerEncoderLayer to use Darts' Position-wise Feed-Forward variants."""
+
+    def __init__(self, ffn: nn.Module, dropout: float, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        ffn
+            One of Darts' Position-wise Feed-Forward Network variants from darts.models.components.glu_variants
+        dropout
+            Fraction of neurons affected by Dropout (default=0.1).
+        args
+            positional arguments from torch.nn.TransformerEncoderLayer.
+        kwargs
+            keyword arguments from torch.nn.TransformerEncoderLayer. `activation` will have no effect.
+        """
+        super().__init__(*args, **kwargs)
+        self.ffn = ffn
+        self.dropout = MonteCarloDropout(dropout)
+
+    # overwrite the feed forward block
+    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.ffn(x)
+        return self.dropout(x)
+
+
+class CustomFeedForwardDecoderLayer(nn.TransformerDecoderLayer):
+    """Overwrites the PyTorch TransformerDecoderLayer to use Darts' custom Position-wise Feed-Forward variants."""
+
+    def __init__(self, ffn: nn.Module, dropout: float, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        ffn
+            One of Darts' Position-wise Feed-Forward Network variants from darts.models.components.glu_variants
+        dropout
+            Fraction of neurons affected by Dropout (default=0.1).
+        args
+            positional arguments from torch.nn.TransformerDecoderLayer.
+        kwargs
+            keyword arguments from torch.nn.TransformerDecoderLayer. `activation` will have no effect.
+        """
+        super().__init__(*args, **kwargs)
+        self.ffn = ffn
+        self.dropout = MonteCarloDropout(dropout)
+
+    # overwrite the feed forward block
+    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.ffn(x)
+        return self.dropout(x)
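
The new layers plug one of Darts' feed-forward variants into PyTorch's standard encoder/decoder layers while keeping Monte Carlo dropout in the feed-forward block. A minimal standalone sketch of how such a layer could be wired up, assuming illustrative hyperparameter values (inside Darts this wiring is done by the `_generate_coder` helper added to transformer_model.py below):

import torch
import torch.nn as nn

from darts.models.components.glu_variants import SwiGLU
from darts.models.components.transformer import CustomFeedForwardEncoderLayer

d_model, nhead, dim_ff, dropout = 64, 4, 256, 0.1

# One encoder layer whose feed-forward block is a SwiGLU FFN instead of the
# built-in Linear -> activation -> Linear block.
layer = CustomFeedForwardEncoderLayer(
    ffn=SwiGLU(d_model=d_model, d_ff=dim_ff, dropout=dropout),
    dropout=dropout,
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_ff,
)
encoder = nn.TransformerEncoder(layer, num_layers=2, norm=nn.LayerNorm(d_model))

out = encoder(torch.randn(10, 32, d_model))  # (sequence length, batch, d_model)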

darts/models/forecasting/torch_forecasting_model.py

+1 -1
@@ -1355,7 +1355,7 @@ def load_model(path: str) -> "TorchForecastingModel":
         path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
         if os.path.exists(path_ptl_ckpt):
             model.model = model.model.__class__.load_from_checkpoint(path_ptl_ckpt)
-            model.trainer = model.model.trainer
+            model.trainer = None
 
         return model
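
For context, `load_model` restores a model persisted with its `save_model` counterpart; after this change the restored object no longer carries the stale PyTorch Lightning trainer of the saved run, and a fresh trainer is created on the next fit or predict call. A rough usage sketch, with an illustrative dataset and file name:

from darts.datasets import AirPassengersDataset
from darts.models import TransformerModel

series = AirPassengersDataset().load()

model = TransformerModel(input_chunk_length=12, output_chunk_length=6)
model.fit(series, epochs=1)
model.save_model("transformer_model.pth.tar")  # counterpart of load_model in this Darts version

restored = TransformerModel.load_model("transformer_model.pth.tar")
# With this fix the restored model no longer points at the saved Lightning trainer;
# a new one is set up lazily on the next fit()/predict() call.
assert restored.trainer is None
restored.predict(n=6)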

darts/models/forecasting/transformer_model.py

+66 -8
@@ -9,9 +9,13 @@
 import torch
 import torch.nn as nn
 
-from darts.logging import get_logger, raise_if_not
+from darts.logging import get_logger, raise_if, raise_if_not
 from darts.models.components import glu_variants
 from darts.models.components.glu_variants import GLU_FFN
+from darts.models.components.transformer import (
+    CustomFeedForwardDecoderLayer,
+    CustomFeedForwardEncoderLayer,
+)
 from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
 from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel
 
@@ -22,6 +26,34 @@
 FFN = GLU_FFN + BUILT_IN
 
 
+def _generate_coder(
+    d_model, dim_ff, dropout, nhead, num_layers, coder_cls, layer_cls, ffn_cls
+):
+    """Generates an Encoder or Decoder with one of Darts' Feed-forward Network variants.
+    Parameters
+    ----------
+    coder_cls
+        Either `torch.nn.TransformerEncoder` or `...TransformerDecoder`
+    layer_cls
+        Either `darts.models.components.transformer.CustomFeedForwardEncoderLayer` or
+        `...CustomFeedForwardDecoderLayer`
+    ffn_cls
+        One of Darts' Position-wise Feed-Forward Network variants from `darts.models.components.glu_variants`
+    """
+    layer = layer_cls(
+        ffn=ffn_cls(d_model=d_model, d_ff=dim_ff, dropout=dropout),
+        dropout=dropout,
+        d_model=d_model,
+        nhead=nhead,
+        dim_feedforward=dim_ff,
+    )
+    return coder_cls(
+        layer,
+        num_layers=num_layers,
+        norm=nn.LayerNorm(d_model),
+    )
+
+
 # This implementation of positional encoding is taken from the PyTorch documentation:
 # https://pytorch.org/tutorials/beginner/transformer_tutorial.html
 class _PositionalEncoding(nn.Module):
@@ -142,13 +174,39 @@ def __init__(
 
         raise_if_not(activation in FFN, f"'{activation}' is not in {FFN}")
         if activation in GLU_FFN:
-            # use glu variant feedforward layers
-            self.activation = getattr(glu_variants, activation)(
-                d_model=d_model, d_ff=dim_feedforward, dropout=dropout
+            raise_if(
+                custom_encoder is not None or custom_decoder is not None,
+                "Cannot use `custom_encoder` or `custom_decoder` along with an `activation` from "
+                f"{GLU_FFN}",
+                logger=logger,
+            )
+            # use glu variant feed-forward layers
+            ffn_cls = getattr(glu_variants, activation)
+
+            # custom feed-forward layers have activation built-in. reset activation
+            activation = None
+
+            custom_encoder = _generate_coder(
+                d_model,
+                dim_feedforward,
+                dropout,
+                nhead,
+                num_encoder_layers,
+                nn.TransformerEncoder,
+                CustomFeedForwardEncoderLayer,
+                ffn_cls,
+            )
+
+            custom_decoder = _generate_coder(
+                d_model,
+                dim_feedforward,
+                dropout,
+                nhead,
+                num_decoder_layers,
+                nn.TransformerDecoder,
+                CustomFeedForwardDecoderLayer,
+                ffn_cls,
             )
-        else:
-            # use nn.Transformer built in feedforward layers
-            self.activation = activation
 
         # Defining the Transformer module
         self.transformer = nn.Transformer(
@@ -158,7 +216,7 @@ def __init__(
             num_decoder_layers=num_decoder_layers,
             dim_feedforward=dim_feedforward,
             dropout=dropout,
-            activation=self.activation,
+            activation=activation,
             custom_encoder=custom_encoder,
             custom_decoder=custom_decoder,
         )
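
From the user's side the feature is driven entirely through the existing `activation` argument: passing one of the GLU variants now builds the encoder and decoder stacks from the custom layers above, built-in activations keep using the stock `nn.Transformer` layers, and combining a GLU variant with `custom_encoder`/`custom_decoder` raises. A small sketch of both paths, assuming an illustrative dataset and hyperparameters:

from darts.datasets import AirPassengersDataset
from darts.models import TransformerModel

series = AirPassengersDataset().load()

# GLU-variant activation: encoder/decoder are assembled via _generate_coder
# from CustomFeedForwardEncoderLayer / CustomFeedForwardDecoderLayer.
model = TransformerModel(
    input_chunk_length=12,
    output_chunk_length=6,
    activation="SwiGLU",
)
model.fit(series, epochs=1)
forecast = model.predict(n=6)

# Built-in activation: the plain nn.Transformer layers are used instead.
baseline = TransformerModel(
    input_chunk_length=12, output_chunk_length=6, activation="gelu"
)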

darts/tests/models/forecasting/test_transformer_model.py

+22 -2
@@ -11,6 +11,12 @@
 logger = get_logger(__name__)
 
 try:
+    import torch.nn as nn
+
+    from darts.models.components.transformer import (
+        CustomFeedForwardDecoderLayer,
+        CustomFeedForwardEncoderLayer,
+    )
     from darts.models.forecasting.transformer_model import (
         TransformerModel,
         _TransformerModule,
@@ -118,14 +124,28 @@ def test_activations(self):
                 )
                 model1.fit(self.series, epochs=1)
 
-            # internal activation function
+            # internal activation function uses PyTorch TransformerEncoderLayer
             model2 = TransformerModel(
                 input_chunk_length=1, output_chunk_length=1, activation="gelu"
             )
             model2.fit(self.series, epochs=1)
+            assert isinstance(
+                model2.model.transformer.encoder.layers[0], nn.TransformerEncoderLayer
+            )
+            assert isinstance(
+                model2.model.transformer.decoder.layers[0], nn.TransformerDecoderLayer
+            )
 
-            # glue variant FFN
+            # glu variant FFN uses our custom CustomFeedForwardEncoderLayer
             model3 = TransformerModel(
                 input_chunk_length=1, output_chunk_length=1, activation="SwiGLU"
             )
             model3.fit(self.series, epochs=1)
+            assert isinstance(
+                model3.model.transformer.encoder.layers[0],
+                CustomFeedForwardEncoderLayer,
+            )
+            assert isinstance(
+                model3.model.transformer.decoder.layers[0],
+                CustomFeedForwardDecoderLayer,
+            )
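
Outside the test suite, the same structural check can be done interactively on a fitted model; this mirrors the assertions above, with an illustrative dataset:

from darts.datasets import AirPassengersDataset
from darts.models import TransformerModel
from darts.models.components.transformer import CustomFeedForwardEncoderLayer

series = AirPassengersDataset().load()

model = TransformerModel(input_chunk_length=1, output_chunk_length=1, activation="SwiGLU")
model.fit(series, epochs=1)

# model.model is the Lightning module; .transformer is the underlying nn.Transformer.
first_encoder_layer = model.model.transformer.encoder.layers[0]
print(type(first_encoder_layer))  # -> CustomFeedForwardEncoderLayer
assert isinstance(first_encoder_layer, CustomFeedForwardEncoderLayer)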

darts/timeseries.py

+1 -1
@@ -649,7 +649,7 @@ def from_dataframe(
         else:
             raise_if_not(
                 isinstance(df.index, VALID_INDEX_TYPES),
-                "If time_col is not specified, the DataFrame must be indexed either with"
+                "If time_col is not specified, the DataFrame must be indexed either with "
                 "a DatetimeIndex, or with a RangeIndex.",
                 logger,
             )
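
The one-character fix above restores the missing space in the concatenated error message of `TimeSeries.from_dataframe`, which darts' `raise_if_not` raises as a ValueError. A quick sketch of when that message is triggered, using an illustrative DataFrame:

import pandas as pd

from darts import TimeSeries

# An index that is neither a DatetimeIndex nor a RangeIndex, and no `time_col` given:
df = pd.DataFrame({"value": range(5)}, index=list("abcde"))

try:
    TimeSeries.from_dataframe(df)
except ValueError as err:
    print(err)
    # "If time_col is not specified, the DataFrame must be indexed either with a DatetimeIndex, or with a RangeIndex."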

0 commit comments
