From bb596852083580ed97c23d1e090ec1f4455882cc Mon Sep 17 00:00:00 2001
From: jinmang2
Date: Mon, 21 Feb 2022 04:30:36 +0000
Subject: [PATCH 1/2] Add PreLN to fsmt module

---
 .../models/fsmt/configuration_fsmt.py         |  8 ++++
 ..._original_pytorch_checkpoint_to_pytorch.py |  2 +
 src/transformers/models/fsmt/modeling_fsmt.py | 41 +++++++++++++++++--
 3 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/fsmt/configuration_fsmt.py b/src/transformers/models/fsmt/configuration_fsmt.py
index 290b75af5ed3..bd9095f53e48 100644
--- a/src/transformers/models/fsmt/configuration_fsmt.py
+++ b/src/transformers/models/fsmt/configuration_fsmt.py
@@ -69,6 +69,10 @@ class FSMTConfig(PretrainedConfig):
             Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
         encoder_ffn_dim (`int`, *optional*, defaults to 4096):
             Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_normalize_before (`bool`, *optional*, defaults to False):
+            Use Pre-LayerNorm in the Transformer encoder.
+        decoder_normalize_before (`bool`, *optional*, defaults to False):
+            Use Pre-LayerNorm in the Transformer decoder.
         activation_function (`str` or `Callable`, *optional*, defaults to `"relu"`):
             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
             `"relu"`, `"silu"` and `"gelu_new"` are supported.
@@ -140,10 +144,12 @@ def __init__(
         encoder_layers=12,
         encoder_attention_heads=16,
         encoder_layerdrop=0.0,
+        encoder_normalize_before=False,
         decoder_ffn_dim=4096,
         decoder_layers=12,
         decoder_attention_heads=16,
         decoder_layerdrop=0.0,
+        decoder_normalize_before=False,
         attention_dropout=0.0,
         dropout=0.1,
         activation_dropout=0.0,
@@ -171,10 +177,12 @@ def __init__(
         self.encoder_layers = self.num_hidden_layers = encoder_layers
         self.encoder_attention_heads = encoder_attention_heads
         self.encoder_layerdrop = encoder_layerdrop
+        self.encoder_normalize_before = encoder_normalize_before
         self.decoder_layerdrop = decoder_layerdrop
         self.decoder_ffn_dim = decoder_ffn_dim
         self.decoder_layers = decoder_layers
         self.decoder_attention_heads = decoder_attention_heads
+        self.decoder_normalize_before = decoder_normalize_before
         self.max_position_embeddings = max_position_embeddings
         self.init_std = init_std  # Normal(0, this parameter)
         self.activation_function = activation_function
diff --git a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
index 2470492ac743..6e76490d1506 100755
--- a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -188,10 +188,12 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
         "encoder_ffn_dim": args["encoder_ffn_embed_dim"],
         "encoder_layerdrop": args["encoder_layerdrop"],
         "encoder_layers": args["encoder_layers"],
+        "encoder_normalize_before": args["encoder_normalize_before"],
         "decoder_attention_heads": args["decoder_attention_heads"],
         "decoder_ffn_dim": args["decoder_ffn_embed_dim"],
         "decoder_layerdrop": args["decoder_layerdrop"],
         "decoder_layers": args["decoder_layers"],
+        "decoder_normalize_before": args["decoder_normalize_before"],
         "bos_token_id": 0,
         "pad_token_id": 1,
         "eos_token_id": 2,
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index 2efc46e6d1bf..63f989436328 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -395,6 +395,7 @@ def __init__(self, config: FSMTConfig):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
+        self.normalize_before = config.encoder_normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -418,6 +419,8 @@ def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=Fa
             encoded output of shape *(seq_len, batch, embed_dim)*
         """
         residual = x
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
         x, attn_weights = self.self_attn(
             query=x,
             key=x,
@@ -427,15 +430,19 @@ def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=Fa
         )
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.self_attn_layer_norm(x)
+        if not self.normalize_before:
+            x = self.self_attn_layer_norm(x)
 
         residual = x
+        if self.normalize_before:
+            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.final_layer_norm(x)
+        if not self.normalize_before:
+            x = self.final_layer_norm(x)
         return x, attn_weights
 
 
@@ -461,6 +468,10 @@ def __init__(self, config: FSMTConfig, embed_tokens):
         self.layers = nn.ModuleList(
             [EncoderLayer(config) for _ in range(config.encoder_layers)]
         )  # type: List[EncoderLayer]
+        if config.encoder_normalize_before:
+            self.layer_norm = nn.LayerNorm(embed_dim)
+        else:
+            self.layer_norm = None
 
     def forward(
         self,
@@ -530,6 +541,9 @@ def forward(
             if output_attentions:
                 all_attentions = all_attentions + (attn,)
 
+        if self.layer_norm is not None:
+            x = self.layer_norm(x)
+
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
 
@@ -555,6 +569,7 @@ def __init__(self, config: FSMTConfig):
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
 
+        self.normalize_before = config.decoder_normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = Attention(
             self.embed_dim,
@@ -584,6 +599,8 @@ def forward(
         if layer_state is None:
             layer_state = {}
 
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
         # Self Attention
         x, self_attn_weights = self.self_attn(
             query=x,
@@ -597,10 +614,14 @@
         )
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
+        if not self.normalize_before:
+            x = self.self_attn_layer_norm(x)
 
         # Cross attention
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
+        if self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
         x, cross_attn_weights = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,
@@ -611,16 +632,20 @@
         )
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.encoder_attn_layer_norm(x)
+        if not self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
 
         # Fully Connected
         residual = x
+        if self.normalize_before:
+            x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.final_layer_norm(x)
+        if not self.normalize_before:
+            x = self.final_layer_norm(x)
         return (
             x,
             self_attn_weights,
@@ -653,6 +678,11 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
 
+        if config.decoder_normalize_before:
+            self.layer_norm = nn.LayerNorm(embed_dim)
+        else:
+            self.layer_norm = None
+
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
@@ -772,6 +802,9 @@ def forward(
                 all_self_attns += (layer_self_attn,)
                 all_cross_attns += (layer_cross_attn,)
 
+        if self.layer_norm is not None:
+            x = self.layer_norm(x)
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             x = x.transpose(0, 1)
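For reference, the change in PATCH 1/2 boils down to moving each `LayerNorm` from after the residual addition (Post-LN, FSMT's current behaviour) to the sub-layer input (Pre-LN), plus a final `LayerNorm` on the encoder/decoder output when the new flag is set. The toy module below is a minimal, self-contained sketch of that ordering only; it is not the actual FSMT code, and it uses the flag name from this first commit (`normalize_before`, renamed to `pre_layernorm` in PATCH 2/2).

```python
import torch
import torch.nn as nn


class ToyResidualBlock(nn.Module):
    """Illustrative residual sub-block; `pre_layernorm` toggles Pre-LN vs. Post-LN ordering."""

    def __init__(self, dim: int, pre_layernorm: bool = False):
        super().__init__()
        self.pre_layernorm = pre_layernorm
        self.layer_norm = nn.LayerNorm(dim)
        self.sub_layer = nn.Linear(dim, dim)  # stand-in for self-attention / feed-forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        if self.pre_layernorm:  # Pre-LN: normalize the sub-layer input
            x = self.layer_norm(x)
        x = self.sub_layer(x)
        x = residual + x
        if not self.pre_layernorm:  # Post-LN: normalize after the residual addition
            x = self.layer_norm(x)
        return x


if __name__ == "__main__":
    x = torch.randn(2, 5, 16)  # (batch, seq_len, dim)
    for flag in (False, True):
        out = ToyResidualBlock(16, pre_layernorm=flag)(x)
        print(flag, out.shape)  # both print torch.Size([2, 5, 16])
```

Because Pre-LN leaves the residual path unnormalized, the wrapping `FSMTEncoder`/`FSMTDecoder` also gain a final `layer_norm` over the stack's output when the option is enabled, which is what the hunks touching the encoder and decoder `forward` methods above add.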
From 298240385f570cae0616cf502b36b8bd909336d7 Mon Sep 17 00:00:00 2001
From: jinmang2
Date: Mon, 28 Feb 2022 01:17:29 +0000
Subject: [PATCH 2/2] fix names of pre-layernorm options and bugs

---
 .../models/fsmt/configuration_fsmt.py         | 12 +++---
 ..._original_pytorch_checkpoint_to_pytorch.py |  4 +-
 src/transformers/models/fsmt/modeling_fsmt.py | 39 +++++++++----------
 3 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/src/transformers/models/fsmt/configuration_fsmt.py b/src/transformers/models/fsmt/configuration_fsmt.py
index bd9095f53e48..bff04ffa94bd 100644
--- a/src/transformers/models/fsmt/configuration_fsmt.py
+++ b/src/transformers/models/fsmt/configuration_fsmt.py
@@ -69,9 +69,9 @@ class FSMTConfig(PretrainedConfig):
             Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
         encoder_ffn_dim (`int`, *optional*, defaults to 4096):
             Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
-        encoder_normalize_before (`bool`, *optional*, defaults to False):
+        encoder_pre_layernorm (`bool`, *optional*, defaults to False):
             Use Pre-LayerNorm in the Transformer encoder.
-        decoder_normalize_before (`bool`, *optional*, defaults to False):
+        decoder_pre_layernorm (`bool`, *optional*, defaults to False):
             Use Pre-LayerNorm in the Transformer decoder.
         activation_function (`str` or `Callable`, *optional*, defaults to `"relu"`):
             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
@@ -144,12 +144,12 @@ def __init__(
         encoder_layers=12,
         encoder_attention_heads=16,
         encoder_layerdrop=0.0,
-        encoder_normalize_before=False,
+        encoder_pre_layernorm=False,
         decoder_ffn_dim=4096,
         decoder_layers=12,
         decoder_attention_heads=16,
         decoder_layerdrop=0.0,
-        decoder_normalize_before=False,
+        decoder_pre_layernorm=False,
         attention_dropout=0.0,
         dropout=0.1,
         activation_dropout=0.0,
@@ -177,12 +177,12 @@ def __init__(
         self.encoder_layers = self.num_hidden_layers = encoder_layers
         self.encoder_attention_heads = encoder_attention_heads
         self.encoder_layerdrop = encoder_layerdrop
-        self.encoder_normalize_before = encoder_normalize_before
+        self.encoder_pre_layernorm = encoder_pre_layernorm
         self.decoder_layerdrop = decoder_layerdrop
         self.decoder_ffn_dim = decoder_ffn_dim
         self.decoder_layers = decoder_layers
         self.decoder_attention_heads = decoder_attention_heads
-        self.decoder_normalize_before = decoder_normalize_before
+        self.decoder_pre_layernorm = decoder_pre_layernorm
         self.max_position_embeddings = max_position_embeddings
         self.init_std = init_std  # Normal(0, this parameter)
         self.activation_function = activation_function
diff --git a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
index 6e76490d1506..0a15937efd41 100755
--- a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -188,12 +188,12 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
         "encoder_ffn_dim": args["encoder_ffn_embed_dim"],
         "encoder_layerdrop": args["encoder_layerdrop"],
         "encoder_layers": args["encoder_layers"],
-        "encoder_normalize_before": args["encoder_normalize_before"],
+        "encoder_pre_layernorm": args.get("encoder_normalize_before", False),
         "decoder_attention_heads": args["decoder_attention_heads"],
         "decoder_ffn_dim": args["decoder_ffn_embed_dim"],
         "decoder_layerdrop": args["decoder_layerdrop"],
         "decoder_layers": args["decoder_layers"],
-        "decoder_normalize_before": args["decoder_normalize_before"],
+        "decoder_pre_layernorm": args.get("decoder_normalize_before", False),
         "bos_token_id": 0,
         "pad_token_id": 1,
         "eos_token_id": 2,
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index 63f989436328..10ffa29c0bd0 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -395,7 +395,7 @@ def __init__(self, config: FSMTConfig):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
-        self.normalize_before = config.encoder_normalize_before
+        self.pre_layernorm = config.encoder_pre_layernorm
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -419,7 +419,7 @@ def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=Fa
             encoded output of shape *(seq_len, batch, embed_dim)*
         """
         residual = x
-        if self.normalize_before:
+        if self.pre_layernorm:
             x = self.self_attn_layer_norm(x)
         x, attn_weights = self.self_attn(
             query=x,
@@ -430,18 +430,18 @@ def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=Fa
         )
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
+        if not self.pre_layernorm:
             x = self.self_attn_layer_norm(x)
 
         residual = x
-        if self.normalize_before:
+        if self.pre_layernorm:
             x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
+        if not self.pre_layernorm:
             x = self.final_layer_norm(x)
         return x, attn_weights
 
@@ -468,10 +468,9 @@ def __init__(self, config: FSMTConfig, embed_tokens):
         self.layers = nn.ModuleList(
             [EncoderLayer(config) for _ in range(config.encoder_layers)]
         )  # type: List[EncoderLayer]
-        if config.encoder_normalize_before:
+        self.pre_layernorm = config.encoder_pre_layernorm
+        if self.pre_layernorm:
             self.layer_norm = nn.LayerNorm(embed_dim)
-        else:
-            self.layer_norm = None
 
     def forward(
         self,
@@ -541,7 +540,7 @@ def forward(
             if output_attentions:
                 all_attentions = all_attentions + (attn,)
 
-        if self.layer_norm is not None:
+        if self.pre_layernorm:
             x = self.layer_norm(x)
 
         # T x B x C -> B x T x C
@@ -569,7 +568,7 @@ def __init__(self, config: FSMTConfig):
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
 
-        self.normalize_before = config.decoder_normalize_before
+        self.pre_layernorm = config.decoder_pre_layernorm
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = Attention(
             self.embed_dim,
@@ -599,7 +598,7 @@ def forward(
         if layer_state is None:
             layer_state = {}
 
-        if self.normalize_before:
+        if self.pre_layernorm:
             x = self.self_attn_layer_norm(x)
         # Self Attention
         x, self_attn_weights = self.self_attn(
@@ -613,14 +612,13 @@
         )
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.self_attn_layer_norm(x)
-        if not self.normalize_before:
+        if not self.pre_layernorm:
             x = self.self_attn_layer_norm(x)
 
         # Cross attention
         residual = x
         assert self.encoder_attn.cache_key != self.self_attn.cache_key
-        if self.normalize_before:
+        if self.pre_layernorm:
             x = self.encoder_attn_layer_norm(x)
         x, cross_attn_weights = self.encoder_attn(
             query=x,
             key=encoder_hidden_states,
@@ -632,19 +630,19 @@
         )
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
+        if not self.pre_layernorm:
             x = self.encoder_attn_layer_norm(x)
 
         # Fully Connected
         residual = x
-        if self.normalize_before:
+        if self.pre_layernorm:
             x = self.final_layer_norm(x)
         x = self.activation_fn(self.fc1(x))
         x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        if not self.normalize_before:
+        if not self.pre_layernorm:
             x = self.final_layer_norm(x)
         return (
             x,
@@ -678,10 +676,9 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
         )  # type: List[DecoderLayer]
 
-        if config.decoder_normalize_before:
+        self.pre_layernorm = config.decoder_pre_layernorm
+        if self.pre_layernorm:
             self.layer_norm = nn.LayerNorm(embed_dim)
-        else:
-            self.layer_norm = None
 
         if is_deepspeed_zero3_enabled():
             import deepspeed
@@ -802,7 +799,7 @@ def forward(
                 all_self_attns += (layer_self_attn,)
                 all_cross_attns += (layer_cross_attn,)
 
-        if self.layer_norm is not None:
+        if self.pre_layernorm:
             x = self.layer_norm(x)
 
         # add hidden states from the last decoder layer
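After PATCH 2/2, the new switches are plain config flags. The snippet below is only an illustrative sketch of how they would be used once this series is applied: `encoder_pre_layernorm` / `decoder_pre_layernorm` do not exist in unpatched transformers, and the tiny model sizes are arbitrary.

```python
# Sketch only: assumes a transformers checkout with both patches above applied.
from transformers import FSMTConfig, FSMTForConditionalGeneration

config = FSMTConfig(
    langs=["en", "de"],
    src_vocab_size=1000,
    tgt_vocab_size=1000,
    d_model=64,
    encoder_layers=2,
    decoder_layers=2,
    encoder_ffn_dim=128,
    decoder_ffn_dim=128,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_pre_layernorm=True,  # new flag introduced by this series
    decoder_pre_layernorm=True,  # new flag introduced by this series
)
model = FSMTForConditionalGeneration(config)

# With pre-layernorm enabled, the encoder and decoder each carry a final LayerNorm.
print(model.model.encoder.layer_norm)
print(model.model.decoder.layer_norm)
```

The `args.get("encoder_normalize_before", False)` / `args.get("decoder_normalize_before", False)` lookups in the conversion script keep backward compatibility with fairseq checkpoints whose args do not define the flag, defaulting to the existing Post-LN behaviour.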