
Commit 2fc99b9

gdevos010 (Greg DeVos) and hrzn authored
New alternatives to layer norm (#1114)
* layer norm variants
* fixed default
* license, changelog and test
* powerNorm
* arxiv link
* typos
* typo
* adding PowerNorm to docs
* PR comments
* PR comments
* PR comments
* PR comments
* removed ScaleNorm. merged files
* added LayerNormNoBias and removed powernorm
* custom norm_type
* new norm layers for transformer model
* fix test
* Update darts/models/forecasting/transformer_model.py
  Co-authored-by: Julien Herzen <[email protected]>
* Update darts/models/forecasting/tft_model.py
  Co-authored-by: Julien Herzen <[email protected]>
* layer norm test

Co-authored-by: Greg DeVos <[email protected]>
Co-authored-by: Julien Herzen <[email protected]>
Co-authored-by: Julien Herzen <[email protected]>
1 parent 4a522a0 commit 2fc99b9

File tree: 9 files changed, +300 -35 lines changed


CHANGELOG.md  +3 -3
@@ -1,11 +1,11 @@
 
 # Changelog
 
-Darts is still in an early development phase and we cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "&#x1F534;".
+Darts is still in an early development phase, and we cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "&#x1F534;".
 
 ## [Unreleased](https://github.com/unit8co/darts/tree/master)
-
 - Added support for retraining model(s) every `n` iteration and on custom condition in `historical_forecasts` method of `ForecastingModel` abstract class. Addressed issues [#135](https://github.com/unit8co/darts/issues/135) and [#623](https://github.com/unit8co/darts/issues/623) by [Francesco Bruzzesi](https://github.com/fbruzzesi).
+- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/issues/1113) by [Greg DeVos](https://github.com/gdevos010).
 
 [Full Changelog](https://github.com/unit8co/darts/compare/0.21.0...master)
 
@@ -49,7 +49,7 @@ Darts is still in an early development phase and we cannot always guarantee back
 - Added support for static covariates in `TimeSeries` class. [#966](https://github.com/unit8co/darts/pull/966) by [Dennis Bader](https://github.com/dennisbader).
 - Added support for static covariates in TFT model. [#966](https://github.com/unit8co/darts/pull/966) by [Dennis Bader](https://github.com/dennisbader).
 - Support for storing hierarchy of components in `TimeSeries` (in view of hierarchical reconciliation) [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
-- New Reconciliation transformers for forececast reconciliation: bottom up, top down and MinT. [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
+- New Reconciliation transformers for forecast reconciliation: bottom up, top down and MinT. [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
 - Added support for Monte Carlo Dropout, as a way to capture model uncertainty with torch models at inference time. [#1013](https://github.com/unit8co/darts/pull/1013) by [Julien Herzen](https://github.com/hrzn).
 - New datasets: ETT and Electricity. [#617](https://github.com/unit8co/darts/pull/617)
 by [Greg DeVos](https://github.com/gdevos010)
darts/models/components/layer_norm_variants.py  +56 -0 (new file)

@@ -0,0 +1,56 @@
"""
MIT License

Copyright (c) 2020 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """An alternative to layer normalization, without mean centering and the learned bias [1]

    References
    ----------
    .. [1] Zhang, Biao, and Rico Sennrich. "Root mean square layer normalization." Advances in Neural Information
           Processing Systems 32 (2019).
    """

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # root-mean-square of x over the last dimension: ||x||_2 * dim**-0.5
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class LayerNormNoBias(nn.LayerNorm):
    # standard LayerNorm without the learned elementwise affine parameters (no weight, no bias)
    def __init__(self, input_size, **kwargs):
        super().__init__(input_size, elementwise_affine=False, **kwargs)


class LayerNorm(nn.LayerNorm):
    # plain nn.LayerNorm, exposed here so it can be selected by name like the other variants
    def __init__(self, input_size, **kwargs) -> None:
        super().__init__(input_size, **kwargs)
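
As a quick illustration (an editorial sketch, not part of the commit), the three variants can be exercised directly on a dummy activation tensor; the import path darts.models.components.layer_norm_variants follows from the import added to tft_model.py further below:

import torch

from darts.models.components.layer_norm_variants import (
    LayerNorm,
    LayerNormNoBias,
    RMSNorm,
)

x = torch.randn(32, 24, 16)  # (batch, time steps, features)

for norm_cls in (LayerNorm, LayerNormNoBias, RMSNorm):
    norm = norm_cls(16)  # normalizes over the last (feature) dimension
    y = norm(x)          # output shape is unchanged: (32, 24, 16)
    n_params = sum(p.numel() for p in norm.parameters())
    print(norm_cls.__name__, tuple(y.shape), n_params)

# LayerNorm learns a weight and a bias (32 parameters for dim=16),
# LayerNormNoBias learns nothing, and RMSNorm learns only the gain vector g (16 parameters).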

darts/models/forecasting/nhits.py  +5 -5
@@ -158,7 +158,7 @@ def __init__(
 
         self.layers = nn.Sequential(*layers)
 
-        # Fully connected layer producing forecast/backcast expansion coeffcients (waveform generator parameters).
+        # Fully connected layer producing forecast/backcast expansion coefficients (waveform generator parameters).
         # The coefficients are emitted for each parameter of the likelihood for the forecast.
         self.backcast_linear_layer = nn.Linear(
             in_features=layer_width, out_features=n_theta_backcast
@@ -413,7 +413,7 @@ def __init__(
         self.stacks = nn.ModuleList(self.stacks_list)
 
         # setting the last backcast "branch" to be not trainable (without next block/stack, it doesn't need to be
-        # backpropagated). Removing this lines would cause logtensorboard to crash, since no gradient is stored
+        # backpropagated). Removing this line would cause logtensorboard to crash, since no gradient is stored
         # on this params (the last block backcast is not part of the final output of the net).
         self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)
 
@@ -476,7 +476,7 @@ def __init__(
 
        N-HiTS is similar to N-BEATS (implemented in :class:`NBEATSModel`),
        but attempts to provide better performance at lower computational cost by introducing
-       multi-rate sampling of the inputs and mulit-scale interpolation of the outputs.
+       multi-rate sampling of the inputs and multi-scale interpolation of the outputs.
 
        Similar to :class:`NBEATSModel`, in addition to the univariate version presented in the paper,
        this implementation also supports multivariate series (and covariates) by flattening the model inputs
@@ -489,7 +489,7 @@ def __init__(
            This parameter can be a tuple of tuples, of size (num_stacks x num_blocks), specifying the kernel
            size for each block in each stack. If left to ``None``, some default values will be used based on
            ``input_chunk_length``.
-           Similarly, the multi-scale interpolation is controled by ``n_freq_downsample``, which gives the
+           Similarly, the multi-scale interpolation is controlled by ``n_freq_downsample``, which gives the
            downsampling factors to be used in each block of each stack. If left to ``None``, some default
            values will be used based on the ``output_chunk_length``.
 
@@ -545,7 +545,7 @@ def __init__(
            The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``.
        optimizer_kwargs
            Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}``
-           for specifying a learning rate). Otherwise the default values of the selected ``optimizer_cls``
+           for specifying a learning rate). Otherwise, the default values of the selected ``optimizer_cls``
            will be used. Default: ``None``.
        lr_scheduler_cls
            Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds

darts/models/forecasting/tft_model.py  +40 -6
@@ -12,8 +12,8 @@
 from torch.nn import LSTM as _LSTM
 
 from darts import TimeSeries
-from darts.logging import get_logger, raise_if, raise_if_not
-from darts.models.components import glu_variants
+from darts.logging import get_logger, raise_if, raise_if_not, raise_log
+from darts.models.components import glu_variants, layer_norm_variants
 from darts.models.components.glu_variants import GLU_FFN
 from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule
 from darts.models.forecasting.tft_submodels import (
@@ -55,6 +55,7 @@ def __init__(
         categorical_embedding_sizes: Dict[str, Tuple[int, int]],
         dropout: float,
         add_relative_index: bool,
+        norm_type: Union[str, nn.Module],
         **kwargs,
     ):
 
@@ -102,6 +103,8 @@
        likelihood
            The likelihood model to be used for probabilistic forecasts. By default, the TFT uses
            a ``QuantileRegression`` likelihood.
+       norm_type: str | nn.Module
+           The type of LayerNorm variant to use.
        **kwargs
            all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.
        """
@@ -121,6 +124,16 @@
         self.dropout = dropout
         self.add_relative_index = add_relative_index
 
+        if isinstance(norm_type, str):
+            try:
+                self.layer_norm = getattr(layer_norm_variants, norm_type)
+            except AttributeError:
+                raise_log(
+                    AttributeError("please provide a valid layer norm type"),
+                )
+        else:
+            self.layer_norm = norm_type
+
         # initialize last batch size to check if new mask needs to be generated
         self.batch_size_last = -1
         self.attention_mask = None
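
For readers of the diff, the string lookup above amounts to fetching a class from the layer_norm_variants module, while a custom class is passed through unchanged; a minimal sketch (not code from the commit) of what it resolves to:

from torch import nn

from darts.models.components import layer_norm_variants

# "RMSNorm" -> the RMSNorm class; an unknown name raises AttributeError
norm_cls = getattr(layer_norm_variants, "RMSNorm")
assert issubclass(norm_cls, nn.Module)

# a custom nn.Module class is kept as-is and instantiated later
# by the TFT sub-modules with their own layer sizes
custom_cls = nn.LayerNorm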
@@ -173,6 +186,7 @@
             prescalers=self.prescalers_linear,
             single_variable_grns={},
             context_size=None,  # no context for static variables
+            layer_norm=self.layer_norm,
         )
 
         # variable selection for encoder and decoder
@@ -192,6 +206,7 @@
             context_size=self.hidden_size,
             prescalers=self.prescalers_linear,
             single_variable_grns={},
+            layer_norm=self.layer_norm,
         )
 
         self.decoder_vsn = _VariableSelectionNetwork(
@@ -202,6 +217,7 @@
             context_size=self.hidden_size,
             prescalers=self.prescalers_linear,
             single_variable_grns={},
+            layer_norm=self.layer_norm,
         )
 
         # static encoders
@@ -211,6 +227,7 @@
             hidden_size=self.hidden_size,
             output_size=self.hidden_size,
             dropout=self.dropout,
+            layer_norm=self.layer_norm,
         )
 
         # for hidden state of the lstm
@@ -219,6 +236,7 @@
             hidden_size=self.hidden_size,
             output_size=self.hidden_size,
             dropout=self.dropout,
+            layer_norm=self.layer_norm,
         )
 
         # for cell state of the lstm
@@ -227,6 +245,7 @@
             hidden_size=self.hidden_size,
             output_size=self.hidden_size,
             dropout=self.dropout,
+            layer_norm=self.layer_norm,
         )
 
         # for post lstm static enrichment
@@ -235,6 +254,7 @@
             hidden_size=self.hidden_size,
             output_size=self.hidden_size,
             dropout=self.dropout,
+            layer_norm=self.layer_norm,
         )
 
         # lstm encoder (history) and decoder (future) for local processing
@@ -255,7 +275,9 @@
         )
 
         # post lstm GateAddNorm
-        self.post_lstm_gan = _GateAddNorm(input_size=self.hidden_size, dropout=dropout)
+        self.post_lstm_gan = _GateAddNorm(
+            input_size=self.hidden_size, dropout=dropout, layer_norm=self.layer_norm
+        )
 
         # static enrichment and processing past LSTM
         self.static_enrichment_grn = _GatedResidualNetwork(
@@ -264,6 +286,7 @@
             output_size=self.hidden_size,
             dropout=self.dropout,
             context_size=self.hidden_size,
+            layer_norm=self.layer_norm,
         )
 
         # attention for long-range processing
@@ -272,14 +295,17 @@
             n_head=self.num_attention_heads,
             dropout=self.dropout,
         )
-        self.post_attn_gan = _GateAddNorm(self.hidden_size, dropout=self.dropout)
+        self.post_attn_gan = _GateAddNorm(
+            self.hidden_size, dropout=self.dropout, layer_norm=self.layer_norm
+        )
 
         if self.feed_forward == "GatedResidualNetwork":
             self.feed_forward_block = _GatedResidualNetwork(
                 self.hidden_size,
                 self.hidden_size,
                 self.hidden_size,
                 dropout=self.dropout,
+                layer_norm=self.layer_norm,
             )
         else:
             raise_if_not(
@@ -293,7 +319,9 @@
             )
 
         # output processing -> no dropout at this late stage
-        self.pre_output_gan = _GateAddNorm(self.hidden_size, dropout=None)
+        self.pre_output_gan = _GateAddNorm(
+            self.hidden_size, dropout=None, layer_norm=self.layer_norm
+        )
 
         self.output_layer = nn.Linear(self.hidden_size, self.n_targets * self.loss_size)
 
@@ -637,6 +665,7 @@
         add_relative_index: bool = False,
         loss_fn: Optional[nn.Module] = None,
         likelihood: Optional[Likelihood] = None,
+        norm_type: Union[str, nn.Module] = "LayerNorm",
         **kwargs,
     ):
         """Temporal Fusion Transformers (TFT) for Interpretable Time Series Forecasting.
@@ -699,13 +728,16 @@
            This allows to use the TFTModel without having to pass future_covariates to :func:`fit()` and
            :func:`train()`. It gives a value to the position of each step from input and output chunk relative
            to the prediction point. The values are normalized with ``input_chunk_length``.
-       loss_fn
+       loss_fn: nn.Module
            PyTorch loss function used for training. By default, the TFT model is probabilistic and uses a
            ``likelihood`` instead (``QuantileRegression``). To make the model deterministic, you can set the `
            `likelihood`` to None and give a ``loss_fn`` argument.
        likelihood
            The likelihood model to be used for probabilistic forecasts. By default, the TFT uses
            a ``QuantileRegression`` likelihood.
+       norm_type: str | nn.Module
+           The type of LayerNorm variant to use. Default: ``LayerNorm``. Available options are
+           ["LayerNorm", "RMSNorm", "LayerNormNoBias"], or provide a custom nn.Module.
        **kwargs
            Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
            Darts' :class:`TorchForecastingModel`.
@@ -862,6 +894,7 @@
         )
         self.add_relative_index = add_relative_index
         self.output_dim: Optional[Tuple[int, int]] = None
+        self.norm_type = norm_type
 
     def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
         """
@@ -1049,6 +1082,7 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
             hidden_continuous_size=self.hidden_continuous_size,
             categorical_embedding_sizes=self.categorical_embedding_sizes,
             add_relative_index=self.add_relative_index,
+            norm_type=self.norm_type,
             **self.pl_module_params,
         )
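
Putting the new user-facing option together, a hedged usage sketch (not part of the commit; the dataset and the other TFTModel arguments are only illustrative):

from torch import nn

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel

series = AirPassengersDataset().load()

# select a variant by name: "LayerNorm" (default), "RMSNorm" or "LayerNormNoBias"
model = TFTModel(
    input_chunk_length=24,
    output_chunk_length=12,
    add_relative_index=True,  # so no future covariates are needed for this toy example
    norm_type="RMSNorm",
    n_epochs=1,
)
model.fit(series)

# or pass a custom normalization class directly
model = TFTModel(
    input_chunk_length=24,
    output_chunk_length=12,
    add_relative_index=True,
    norm_type=nn.LayerNorm,
)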

0 commit comments
