Add PreLN to fsmt module #15747

Closed · wants to merge 2 commits
8 changes: 8 additions & 0 deletions src/transformers/models/fsmt/configuration_fsmt.py
@@ -69,6 +69,10 @@ class FSMTConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
encoder_pre_layernorm (`bool`, *optional*, defaults to `False`):
Use Pre-LayerNorm in the Transformer encoder.
decoder_pre_layernorm (`bool`, *optional*, defaults to `False`):
Use Pre-LayerNorm in the Transformer decoder.
activation_function (`str` or `Callable`, *optional*, defaults to `"relu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
@@ -140,10 +144,12 @@ def __init__(
encoder_layers=12,
encoder_attention_heads=16,
encoder_layerdrop=0.0,
encoder_pre_layernorm=False,
decoder_ffn_dim=4096,
decoder_layers=12,
decoder_attention_heads=16,
decoder_layerdrop=0.0,
decoder_pre_layernorm=False,
attention_dropout=0.0,
dropout=0.1,
activation_dropout=0.0,
@@ -171,10 +177,12 @@ def __init__(
self.encoder_layers = self.num_hidden_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.encoder_layerdrop = encoder_layerdrop
self.encoder_pre_layernorm = encoder_pre_layernorm
self.decoder_layerdrop = decoder_layerdrop
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.decoder_pre_layernorm = decoder_pre_layernorm
self.max_position_embeddings = max_position_embeddings
self.init_std = init_std # Normal(0, this parameter)
self.activation_function = activation_function
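
Taken together, the configuration changes only add two boolean switches that default to the original Post-LayerNorm behaviour. A minimal usage sketch, assuming a checkout of this PR's branch (the sizes below are made up for illustration, not taken from any released checkpoint):

```python
from transformers import FSMTConfig

# Hypothetical small config; the two *_pre_layernorm arguments only exist with
# this PR applied, and every size here is illustrative.
config = FSMTConfig(
    d_model=512,
    encoder_layers=6,
    decoder_layers=6,
    encoder_pre_layernorm=True,   # Pre-LN encoder blocks
    decoder_pre_layernorm=True,   # Pre-LN decoder blocks
)
print(config.encoder_pre_layernorm, config.decoder_pre_layernorm)  # True True
```
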
src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
@@ -188,10 +188,12 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
"encoder_ffn_dim": args["encoder_ffn_embed_dim"],
"encoder_layerdrop": args["encoder_layerdrop"],
"encoder_layers": args["encoder_layers"],
"encoder_pre_layernorm": args.get("encoder_normalize_before", False),
"decoder_attention_heads": args["decoder_attention_heads"],
"decoder_ffn_dim": args["decoder_ffn_embed_dim"],
"decoder_layerdrop": args["decoder_layerdrop"],
"decoder_layers": args["decoder_layers"],
"decoder_pre_layernorm": args.get("decoder_normalize_before", False),
"bos_token_id": 0,
"pad_token_id": 1,
"eos_token_id": 2,
40 changes: 35 additions & 5 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -395,6 +395,7 @@ def __init__(self, config: FSMTConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
self.pre_layernorm = config.encoder_pre_layernorm
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
@@ -418,6 +419,8 @@ def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
encoded output of shape *(seq_len, batch, embed_dim)*
"""
residual = x
if self.pre_layernorm:
x = self.self_attn_layer_norm(x)
x, attn_weights = self.self_attn(
query=x,
key=x,
@@ -427,15 +430,19 @@ def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
if not self.pre_layernorm:
x = self.self_attn_layer_norm(x)

residual = x
if self.pre_layernorm:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.final_layer_norm(x)
if not self.pre_layernorm:
x = self.final_layer_norm(x)
return x, attn_weights
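
The change in `EncoderLayer.forward` (and the matching decoder changes further down) is only about where `LayerNorm` sits relative to each residual branch: Post-LN normalizes after the residual add, Pre-LN normalizes the sublayer input and leaves the residual stream untouched. A self-contained sketch of the two orderings in plain PyTorch, independent of the FSMT classes (all names here are illustrative):

```python
import torch
from torch import nn


class ResidualSublayer(nn.Module):
    """Wraps a sublayer (attention or feed-forward) with a residual connection,
    switching between Post-LN (current FSMT) and Pre-LN (this PR) ordering."""

    def __init__(self, dim: int, inner: nn.Module, pre_layernorm: bool):
        super().__init__()
        self.inner = inner
        self.norm = nn.LayerNorm(dim)
        self.pre_layernorm = pre_layernorm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        if self.pre_layernorm:
            x = self.norm(x)        # Pre-LN: normalize the sublayer input
        x = self.inner(x)
        x = residual + x
        if not self.pre_layernorm:
            x = self.norm(x)        # Post-LN: normalize after the residual add
        return x


ffn = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(10, 2, 16)  # (seq_len, batch, embed_dim), as in FSMT
post_ln_out = ResidualSublayer(16, ffn, pre_layernorm=False)(x)
pre_ln_out = ResidualSublayer(16, ffn, pre_layernorm=True)(x)
```

The decoder layer below applies the same wrapping to each of its three sublayers (self-attention, cross-attention, feed-forward).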


@@ -461,6 +468,9 @@ def __init__(self, config: FSMTConfig, embed_tokens):
self.layers = nn.ModuleList(
[EncoderLayer(config) for _ in range(config.encoder_layers)]
) # type: List[EncoderLayer]
self.pre_layernorm = config.encoder_pre_layernorm
if self.pre_layernorm:
self.layer_norm = nn.LayerNorm(embed_dim)

def forward(
self,
@@ -530,6 +540,9 @@ def forward(
if output_attentions:
all_attentions = all_attentions + (attn,)

if self.pre_layernorm:
x = self.layer_norm(x)

# T x B x C -> B x T x C
x = x.transpose(0, 1)
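
Because a Pre-LN block never normalizes the output of its last residual add, the hidden states leaving the final layer are un-normalized; the extra `self.layer_norm` applied here (and its counterpart in `FSMTDecoder` below) closes that gap. A rough, runnable sketch of the stack-level pattern, with toy modules standing in for the real encoder layers:

```python
import torch
from torch import nn

dim = 16
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(3)])  # stand-ins for EncoderLayer
layer_norm = nn.LayerNorm(dim)
pre_layernorm = True

x = torch.randn(10, 2, dim)  # (seq_len, batch, embed_dim)
for layer in layers:
    x = x + layer(x)         # Pre-LN: residual stream accumulates un-normalized outputs
if pre_layernorm:
    x = layer_norm(x)        # one final LayerNorm on top of the stack
```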

@@ -555,6 +568,7 @@ def __init__(self, config: FSMTConfig):
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout

self.pre_layernorm = config.decoder_pre_layernorm
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = Attention(
self.embed_dim,
@@ -584,6 +598,8 @@ def forward(
if layer_state is None:
layer_state = {}

if self.pre_layernorm:
x = self.self_attn_layer_norm(x)
# Self Attention
x, self_attn_weights = self.self_attn(
query=x,
@@ -596,11 +612,14 @@ def forward(
)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
if not self.pre_layernorm:
x = self.self_attn_layer_norm(x)

# Cross attention
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if self.pre_layernorm:
x = self.encoder_attn_layer_norm(x)
x, cross_attn_weights = self.encoder_attn(
query=x,
key=encoder_hidden_states,
@@ -611,16 +630,20 @@ def forward(
)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.encoder_attn_layer_norm(x)
if not self.pre_layernorm:
x = self.encoder_attn_layer_norm(x)

# Fully Connected
residual = x
if self.pre_layernorm:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.final_layer_norm(x)
if not self.pre_layernorm:
x = self.final_layer_norm(x)
return (
x,
self_attn_weights,
@@ -653,6 +676,10 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
[DecoderLayer(config) for _ in range(config.decoder_layers)]
) # type: List[DecoderLayer]

self.pre_layernorm = config.decoder_pre_layernorm
if self.pre_layernorm:
self.layer_norm = nn.LayerNorm(embed_dim)

if is_deepspeed_zero3_enabled():
import deepspeed

@@ -772,6 +799,9 @@ def forward(
all_self_attns += (layer_self_attn,)
all_cross_attns += (layer_cross_attn,)

if self.pre_layernorm:
x = self.layer_norm(x)

# add hidden states from the last decoder layer
if output_hidden_states:
x = x.transpose(0, 1)