
New alternatives to layer norm #1114

Merged: 29 commits, merged on Aug 31, 2022.
Changes shown below are from 21 of the 29 commits.

Commits:
c9f262b
layer norm vairants
Jul 28, 2022
04739da
fixed default
Jul 29, 2022
8dbdb5d
license, changelog and test
gdevos010 Aug 2, 2022
41d1055
Merge branch 'master' into altLayerNorm
gdevos010 Aug 5, 2022
45596cb
powerNorm
gdevos010 Aug 6, 2022
7d103ea
arxiv link
gdevos010 Aug 6, 2022
c9b5a1f
typos
gdevos010 Aug 6, 2022
7b8bd77
typo
gdevos010 Aug 7, 2022
16f946a
adding PowerNorm to docs
gdevos010 Aug 7, 2022
c68a4a6
PR comments
gdevos010 Aug 7, 2022
99a739b
PR comments
gdevos010 Aug 7, 2022
c465102
PR comments
gdevos010 Aug 7, 2022
89314d0
Merge branch 'master' into altLayerNorm
gdevos010 Aug 7, 2022
b0d6246
Merge branch 'master' into altLayerNorm
hrzn Aug 7, 2022
0e7d381
PR comments
gdevos010 Aug 7, 2022
13c58c1
Merge branch 'master' into altLayerNorm
gdevos010 Aug 7, 2022
0892842
removed ScaleNorm. merged files
Aug 8, 2022
c4b82bd
Merge branch 'master' into altLayerNorm
gdevos010 Aug 8, 2022
52831de
Merge branch 'master' into altLayerNorm
gdevos010 Aug 19, 2022
ef9e9e6
Merge branch 'master' into altLayerNorm
hrzn Aug 22, 2022
0000e2d
added LayerNormNoBias and removed powernorm
Aug 22, 2022
9dc066f
Merge branch 'master' into altLayerNorm
gdevos010 Aug 25, 2022
5354977
custom norm_type
Aug 25, 2022
476d666
new norm layers for transformer model
Aug 25, 2022
e5b2943
fix test
Aug 25, 2022
98b0771
Update darts/models/forecasting/transformer_model.py
gdevos010 Aug 27, 2022
f3094e9
Update darts/models/forecasting/tft_model.py
gdevos010 Aug 27, 2022
4518995
layer norm test
Aug 31, 2022
2692eac
Merge branch 'master' into altLayerNorm
gdevos010 Aug 31, 2022
6 changes: 3 additions & 3 deletions CHANGELOG.md
@@ -1,10 +1,10 @@

# Changelog

Darts is still in an early development phase and we cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "🔴".
Darts is still in an early development phase, and we cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "🔴".

## [Unreleased](https://github.com/unit8co/darts/tree/master)

- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/issues/1113) by [Greg DeVos](https://github.com/gdevos010).
[Full Changelog](https://github.com/unit8co/darts/compare/0.21.0...master)


@@ -47,7 +47,7 @@ Darts is still in an early development phase and we cannot always guarantee back
- Added support for static covariates in `TimeSeries` class. [#966](https://github.com/unit8co/darts/pull/966) by [Dennis Bader](https://github.com/dennisbader).
- Added support for static covariates in TFT model. [#966](https://github.com/unit8co/darts/pull/966) by [Dennis Bader](https://github.com/dennisbader).
- Support for storing hierarchy of components in `TimeSeries` (in view of hierarchical reconciliation) [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
- New Reconciliation transformers for forececast reconciliation: bottom up, top down and MinT. [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
- New Reconciliation transformers for forecast reconciliation: bottom up, top down and MinT. [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
- Added support for Monte Carlo Dropout, as a way to capture model uncertainty with torch models at inference time. [#1013](https://github.com/unit8co/darts/pull/1013) by [Julien Herzen](https://github.com/hrzn).
- New datasets: ETT and Electricity. [#617](https://github.com/unit8co/darts/pull/617)
by [Greg DeVos](https://github.com/gdevos010)
56 changes: 56 additions & 0 deletions darts/models/components/layer_norm_variants.py
@@ -0,0 +1,56 @@
"""
MIT License

Copyright (c) 2020 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
"""An alternate to layer normalization, without mean centering and the learned bias [1]

References
----------
.. [1] Zhang, Biao, and Rico Sennrich. "Root mean square layer normalization." Advances in Neural Information
Processing Systems 32 (2019).
"""

def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))

def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g


class LayerNormNoBias(nn.LayerNorm):
def __init__(self, input_size, **kwargs):
super().__init__(input_size, elementwise_affine=False, **kwargs)


class LayerNorm(nn.LayerNorm):
def __init__(self, input_size, **kwargs) -> None:
super().__init__(input_size, **kwargs)
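
Not part of the diff: a minimal usage sketch of the two new layers, assuming the module path darts.models.components.layer_norm_variants introduced above. RMSNorm divides each feature vector by its root mean square (the dim**-0.5 factor turns the L2 norm into an RMS) and applies a learned gain, with no mean centering and no bias; LayerNormNoBias is a plain nn.LayerNorm with elementwise_affine=False, i.e. without the learned weight and bias.

import torch

from darts.models.components.layer_norm_variants import LayerNormNoBias, RMSNorm

x = torch.randn(2, 16, 64)  # (batch, time steps, features)

rms = RMSNorm(dim=64)
print(rms(x).shape)  # torch.Size([2, 16, 64]); each feature vector rescaled by its RMS and the gain g

ln_nb = LayerNormNoBias(64)
print(ln_nb(x).shape)  # torch.Size([2, 16, 64]); normalized, but with no learnable affine parameters
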
10 changes: 5 additions & 5 deletions darts/models/forecasting/nhits.py
@@ -158,7 +158,7 @@ def __init__(

self.layers = nn.Sequential(*layers)

# Fully connected layer producing forecast/backcast expansion coeffcients (waveform generator parameters).
# Fully connected layer producing forecast/backcast expansion coefficients (waveform generator parameters).
# The coefficients are emitted for each parameter of the likelihood for the forecast.
self.backcast_linear_layer = nn.Linear(
in_features=layer_width, out_features=n_theta_backcast
@@ -413,7 +413,7 @@ def __init__(
self.stacks = nn.ModuleList(self.stacks_list)

# setting the last backcast "branch" to be not trainable (without next block/stack, it doesn't need to be
# backpropagated). Removing this lines would cause logtensorboard to crash, since no gradient is stored
# backpropagated). Removing this line would cause logtensorboard to crash, since no gradient is stored
# on this params (the last block backcast is not part of the final output of the net).
self.stacks_list[-1].blocks[-1].backcast_linear_layer.requires_grad_(False)

@@ -476,7 +476,7 @@ def __init__(
N-HiTS is similar to N-BEATS (implemented in :class:`NBEATSModel`),
but attempts to provide better performance at lower computational cost by introducing
multi-rate sampling of the inputs and mulit-scale interpolation of the outputs.
multi-rate sampling of the inputs and multi-scale interpolation of the outputs.
Similar to :class:`NBEATSModel`, in addition to the univariate version presented in the paper,
this implementation also supports multivariate series (and covariates) by flattening the model inputs
@@ -489,7 +489,7 @@ def __init__(
This parameter can be a tuple of tuples, of size (num_stacks x num_blocks), specifying the kernel
size for each block in each stack. If left to ``None``, some default values will be used based on
``input_chunk_length``.
Similarly, the multi-scale interpolation is controled by ``n_freq_downsample``, which gives the
Similarly, the multi-scale interpolation is controlled by ``n_freq_downsample``, which gives the
downsampling factors to be used in each block of each stack. If left to ``None``, some default
values will be used based on the ``output_chunk_length``.
@@ -545,7 +545,7 @@ def __init__(
The PyTorch optimizer class to be used. Default: ``torch.optim.Adam``.
optimizer_kwargs
Optionally, some keyword arguments for the PyTorch optimizer (e.g., ``{'lr': 1e-3}``
for specifying a learning rate). Otherwise the default values of the selected ``optimizer_cls``
for specifying a learning rate). Otherwise, the default values of the selected ``optimizer_cls``
will be used. Default: ``None``.
lr_scheduler_cls
Optionally, the PyTorch learning rate scheduler class to be used. Specifying ``None`` corresponds
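
Not part of the diff, but to make the ``pooling_kernel_sizes`` / ``n_freq_downsample`` description above concrete: a sketch of how the (num_stacks x num_blocks) tuple-of-tuples arguments would be shaped, assuming those N-HiTS parameter names; the values are illustrative, not the library defaults.

from darts.models import NHiTSModel

# 3 stacks with 1 block each -> one kernel size / downsampling factor per block
model = NHiTSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    num_stacks=3,
    num_blocks=1,
    pooling_kernel_sizes=((8,), (4,), (1,)),  # multi-rate sampling of the inputs
    n_freq_downsample=((4,), (2,), (1,)),     # multi-scale interpolation of the outputs
)
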
41 changes: 36 additions & 5 deletions darts/models/forecasting/tft_model.py
@@ -12,8 +12,8 @@
from torch.nn import LSTM as _LSTM

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not
from darts.models.components import glu_variants
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.components import glu_variants, layer_norm_variants
from darts.models.components.glu_variants import GLU_FFN
from darts.models.forecasting.pl_forecasting_module import PLMixedCovariatesModule
from darts.models.forecasting.tft_submodels import (
@@ -55,6 +55,7 @@ def __init__(
categorical_embedding_sizes: Dict[str, Tuple[int, int]],
dropout: float,
add_relative_index: bool,
norm_type: str,
**kwargs,
):

@@ -102,6 +103,8 @@ def __init__(
likelihood
The likelihood model to be used for probabilistic forecasts. By default, the TFT uses
a ``QuantileRegression`` likelihood.
norm_type: str
The type of LayerNorm variant to use.
**kwargs
all parameters required for :class:`darts.model.forecasting_models.PLForecastingModule` base class.
"""
@@ -121,6 +124,13 @@ def __init__(
self.dropout = dropout
self.add_relative_index = add_relative_index

try:
self.layer_norm = getattr(layer_norm_variants, norm_type)
except AttributeError:
raise_log(
AttributeError("please provide a valid layer norm type"),
)

# initialize last batch size to check if new mask needs to be generated
self.batch_size_last = -1
self.attention_mask = None
@@ -173,6 +183,7 @@ def __init__(
prescalers=self.prescalers_linear,
single_variable_grns={},
context_size=None, # no context for static variables
layer_norm=self.layer_norm,
)

# variable selection for encoder and decoder
@@ -192,6 +203,7 @@
context_size=self.hidden_size,
prescalers=self.prescalers_linear,
single_variable_grns={},
layer_norm=self.layer_norm,
)

self.decoder_vsn = _VariableSelectionNetwork(
@@ -202,6 +214,7 @@
context_size=self.hidden_size,
prescalers=self.prescalers_linear,
single_variable_grns={},
layer_norm=self.layer_norm,
)

# static encoders
@@ -211,6 +224,7 @@
hidden_size=self.hidden_size,
output_size=self.hidden_size,
dropout=self.dropout,
layer_norm=self.layer_norm,
)

# for hidden state of the lstm
@@ -219,6 +233,7 @@
hidden_size=self.hidden_size,
output_size=self.hidden_size,
dropout=self.dropout,
layer_norm=self.layer_norm,
)

# for cell state of the lstm
@@ -227,6 +242,7 @@
hidden_size=self.hidden_size,
output_size=self.hidden_size,
dropout=self.dropout,
layer_norm=self.layer_norm,
)

# for post lstm static enrichment
@@ -235,6 +251,7 @@
hidden_size=self.hidden_size,
output_size=self.hidden_size,
dropout=self.dropout,
layer_norm=self.layer_norm,
)

# lstm encoder (history) and decoder (future) for local processing
@@ -255,7 +272,9 @@
)

# post lstm GateAddNorm
self.post_lstm_gan = _GateAddNorm(input_size=self.hidden_size, dropout=dropout)
self.post_lstm_gan = _GateAddNorm(
input_size=self.hidden_size, dropout=dropout, layer_norm=self.layer_norm
)

# static enrichment and processing past LSTM
self.static_enrichment_grn = _GatedResidualNetwork(
Expand All @@ -264,6 +283,7 @@ def __init__(
output_size=self.hidden_size,
dropout=self.dropout,
context_size=self.hidden_size,
layer_norm=self.layer_norm,
)

# attention for long-range processing
@@ -272,14 +292,17 @@
n_head=self.num_attention_heads,
dropout=self.dropout,
)
self.post_attn_gan = _GateAddNorm(self.hidden_size, dropout=self.dropout)
self.post_attn_gan = _GateAddNorm(
self.hidden_size, dropout=self.dropout, layer_norm=self.layer_norm
)

if self.feed_forward == "GatedResidualNetwork":
self.feed_forward_block = _GatedResidualNetwork(
self.hidden_size,
self.hidden_size,
self.hidden_size,
dropout=self.dropout,
layer_norm=self.layer_norm,
)
else:
raise_if_not(
@@ -293,7 +316,9 @@
)

# output processing -> no dropout at this late stage
self.pre_output_gan = _GateAddNorm(self.hidden_size, dropout=None)
self.pre_output_gan = _GateAddNorm(
self.hidden_size, dropout=None, layer_norm=self.layer_norm
)

self.output_layer = nn.Linear(self.hidden_size, self.n_targets * self.loss_size)

@@ -637,6 +662,7 @@ def __init__(
add_relative_index: bool = False,
loss_fn: Optional[nn.Module] = None,
likelihood: Optional[Likelihood] = None,
norm_type: str = "LayerNorm",
Review comment (Contributor):
How about renaming to layer_norm_variant? To make it clear it touches the layer norm part of the transformer.

**kwargs,
):
"""Temporal Fusion Transformers (TFT) for Interpretable Time Series Forecasting.
@@ -706,6 +732,9 @@ def __init__(
likelihood
The likelihood model to be used for probabilistic forecasts. By default, the TFT uses
a ``QuantileRegression`` likelihood.
norm_type: str
The type of LayerNorm variant to use. Default: ``LayerNorm``. Options available are
["LayerNorm", "RMSNorm", "LayerNormNoBias"]
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
@@ -862,6 +891,7 @@ def __init__(
)
self.add_relative_index = add_relative_index
self.output_dim: Optional[Tuple[int, int]] = None
self.norm_type = norm_type

def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Module:
"""
Expand Down Expand Up @@ -1049,6 +1079,7 @@ def _create_model(self, train_sample: MixedCovariatesTrainTensorType) -> nn.Modu
hidden_continuous_size=self.hidden_continuous_size,
categorical_embedding_sizes=self.categorical_embedding_sizes,
add_relative_index=self.add_relative_index,
norm_type=self.norm_type,
**self.pl_module_params,
)

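For context (not part of the diff): a sketch of how the new ``norm_type`` option would be selected through the public TFTModel API. The toy series and hyperparameters below are illustrative only.

import numpy as np

from darts import TimeSeries
from darts.models import TFTModel

# toy deterministic series, just to exercise the API
series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 120)).astype(np.float32))

model = TFTModel(
    input_chunk_length=24,
    output_chunk_length=12,
    add_relative_index=True,   # use a relative index instead of explicit future covariates
    norm_type="RMSNorm",       # "LayerNorm" (default), "RMSNorm" or "LayerNormNoBias"
    n_epochs=1,
)
model.fit(series)
forecast = model.predict(n=12)

Internally, the PL module resolves the string with getattr(layer_norm_variants, norm_type) and raises an AttributeError for unknown names.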