import torch
import torch.nn as nn

-from darts.logging import get_logger, raise_if_not
+from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.components import glu_variants
from darts.models.components.glu_variants import GLU_FFN
+from darts.models.components.transformer import (
+    CustomFeedForwardDecoderLayer,
+    CustomFeedForwardEncoderLayer,
+)
from darts.models.forecasting.pl_forecasting_module import PLPastCovariatesModule
from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel

FFN = GLU_FFN + BUILT_IN


+def _generate_coder(
+    d_model, dim_ff, dropout, nhead, num_layers, coder_cls, layer_cls, ffn_cls
+):
+    """Generates an Encoder or Decoder with one of Darts' Feed-forward Network variants.
+
+    Parameters
+    ----------
+    coder_cls
+        Either `torch.nn.TransformerEncoder` or `...TransformerDecoder`
+    layer_cls
+        Either `darts.models.components.transformer.CustomFeedForwardEncoderLayer` or
+        `...CustomFeedForwardDecoderLayer`
+    ffn_cls
+        One of Darts' Position-wise Feed-Forward Network variants from
+        `darts.models.components.glu_variants`
+    """
+    layer = layer_cls(
+        ffn=ffn_cls(d_model=d_model, d_ff=dim_ff, dropout=dropout),
+        dropout=dropout,
+        d_model=d_model,
+        nhead=nhead,
+        dim_feedforward=dim_ff,
+    )
+    return coder_cls(
+        layer,
+        num_layers=num_layers,
+        norm=nn.LayerNorm(d_model),
+    )
+
+
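For intuition about what `ffn_cls` stands for, here is a minimal, self-contained sketch of a SwiGLU-style position-wise feed-forward block, in the spirit of the GLU variants this PR wires in. It only mirrors the `(d_model, d_ff, dropout)` constructor and `forward(x)` interface that `_generate_coder` relies on; the class name `SwiGLUSketch`, the dimensions, and the example call are illustrative assumptions, not the actual darts implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUSketch(nn.Module):
    """Hypothetical stand-in for a GLU feed-forward variant (not the darts class)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff)   # gating branch, passed through Swish/SiLU
        self.value = nn.Linear(d_model, d_ff)  # linear branch
        self.out = nn.Linear(d_ff, d_model)    # projection back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (Swish(x W_gate) * x W_value) W_out, with dropout on the hidden units
        return self.out(self.dropout(F.silu(self.gate(x)) * self.value(x)))


# Building a standalone encoder from this sketch with the helper above
# (dimensions are arbitrary; CustomFeedForwardEncoderLayer comes from the darts import):
encoder = _generate_coder(
    d_model=64,
    dim_ff=256,
    dropout=0.1,
    nhead=4,
    num_layers=2,
    coder_cls=nn.TransformerEncoder,
    layer_cls=CustomFeedForwardEncoderLayer,
    ffn_cls=SwiGLUSketch,
)
out = encoder(torch.randn(10, 32, 64))  # (sequence, batch, d_model); batch_first defaults to False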
# This implementation of positional encoding is taken from the PyTorch documentation:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class _PositionalEncoding(nn.Module):
@@ -142,13 +174,39 @@ def __init__(
        raise_if_not(activation in FFN, f"'{activation}' is not in {FFN}")
        if activation in GLU_FFN:
-            # use glu variant feedforward layers
-            self.activation = getattr(glu_variants, activation)(
-                d_model=d_model, d_ff=dim_feedforward, dropout=dropout
+            raise_if(
+                custom_encoder is not None or custom_decoder is not None,
+                "Cannot use `custom_encoder` or `custom_decoder` along with an `activation` from "
+                f"{GLU_FFN}",
+                logger=logger,
+            )
+            # use glu variant feed-forward layers
+            ffn_cls = getattr(glu_variants, activation)
+
+            # custom feed-forward layers have activation built-in; reset activation
+            activation = None
+
+            custom_encoder = _generate_coder(
+                d_model,
+                dim_feedforward,
+                dropout,
+                nhead,
+                num_encoder_layers,
+                nn.TransformerEncoder,
+                CustomFeedForwardEncoderLayer,
+                ffn_cls,
+            )
+
+            custom_decoder = _generate_coder(
+                d_model,
+                dim_feedforward,
+                dropout,
+                nhead,
+                num_decoder_layers,
+                nn.TransformerDecoder,
+                CustomFeedForwardDecoderLayer,
+                ffn_cls,
            )
-        else:
-            # use nn.Transformer built in feedforward layers
-            self.activation = activation

        # Defining the Transformer module
        self.transformer = nn.Transformer(
@@ -158,7 +216,7 @@ def __init__(
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
-            activation=self.activation,
+            activation=activation,
            custom_encoder=custom_encoder,
            custom_decoder=custom_decoder,
        )
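With this change, a GLU variant is selected through `TransformerModel`'s existing `activation` argument, and the custom encoder/decoder above are built internally. A minimal usage sketch follows, assuming "SwiGLU" is one of the names listed in `GLU_FFN`; the series, chunk lengths, and hyperparameters are arbitrary choices for illustration.

from darts.models import TransformerModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=200)

model = TransformerModel(
    input_chunk_length=24,
    output_chunk_length=12,
    d_model=64,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    activation="SwiGLU",  # assumed to be in GLU_FFN; built-in names like "relu" keep nn.Transformer's default layers
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=12)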