From e3ff165aa54c07c0371deb09671e3c7dd5666a99 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Fri, 23 Apr 2021 18:58:06 +0200 Subject: [PATCH] Fix cross-attention head mask for Torch encoder-decoder models (#10605) * Fix cross-attention head mask for Torch BART models * Fix head masking for cross-attention module for the following models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart, Pegasus * Enable test_headmasking for M2M_100 model * Fix cross_head_mask for FSMT, LED and T5 * This commit fixes `head_mask` for cross-attention modules in the following models: FSMT, LED, T5 * It also contains some smaller changes in the docs so that it is perfectly clear that the shape of `cross_head_mask` is the same as that of `decoder_head_mask` * Update template * Fix template for BartForCausalLM * Fix cross_head_mask for Speech2Text models * Fix cross_head_mask in templates * Fix args order in BartForCausalLM template * Fix doc in BART templates * Make more explicit naming * `cross_head_mask` -> `cross_attn_head_mask` * `cross_layer_head_mask` -> `cross_attn_layer_head_mask` * Fix doc * make style quality * Fix speech2text docstring --- src/transformers/models/bart/modeling_bart.py | 84 +++++++------ .../models/blenderbot/modeling_blenderbot.py | 83 +++++++------ .../modeling_blenderbot_small.py | 79 +++++++----- src/transformers/models/fsmt/modeling_fsmt.py | 41 ++++--- src/transformers/models/led/modeling_led.py | 114 ++++++------------ .../models/m2m_100/modeling_m2m_100.py | 83 ++++++++++--- .../models/marian/modeling_marian.py | 80 ++++++------ .../models/mbart/modeling_mbart.py | 84 +++++++------ .../models/pegasus/modeling_pegasus.py | 80 ++++++------ .../speech_to_text/modeling_speech_to_text.py | 65 ++++++---- src/transformers/models/t5/modeling_t5.py | 31 +++-- ...ng_{{cookiecutter.lowercase_modelname}}.py | 79 ++++++------ tests/test_modeling_bart.py | 4 + tests/test_modeling_blenderbot.py | 4 + tests/test_modeling_blenderbot_small.py | 4 + tests/test_modeling_common.py | 9 +- tests/test_modeling_fsmt.py | 3 + tests/test_modeling_led.py | 4 + tests/test_modeling_m2m_100.py | 16 ++- tests/test_modeling_marian.py | 4 + tests/test_modeling_mbart.py | 4 + tests/test_modeling_pegasus.py | 4 + tests/test_modeling_speech_to_text.py | 17 ++- 23 files changed, 587 insertions(+), 389 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 9ab73683f6bb..523fda70b540 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -296,7 +296,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail.
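As a usage-level summary before the per-model hunks: after this patch, `head_mask` (encoder self-attention), `decoder_head_mask` (decoder self-attention), and the new `cross_attn_head_mask` (decoder cross-attention) are three independent arguments. A minimal sketch, assuming the post-patch `transformers` API (tiny randomly initialized config so it runs offline; all sizes illustrative):

import torch
from transformers import BartConfig, BartModel

# Tiny random BART; the head-mask shapes below follow the docstrings added in this patch.
config = BartConfig(
    vocab_size=100, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartModel(config)

input_ids = torch.randint(0, 100, (1, 8))
decoder_input_ids = torch.randint(0, 100, (1, 6))

# One mask per module type, each with one row per layer and one column per head.
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask[0, 0] = 0.0  # disable only head 0 of layer 0's cross-attention

outputs = model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,  # new argument introduced by this patch
    output_attentions=True,
)
# outputs.cross_attentions[i]: (batch, decoder_attention_heads, tgt_len, src_len)

Before the change, BART's decoder reused the encoder's `head_mask` for its cross-attention (the `encoder_head_mask=head_mask` call removed in the hunks below), so the two could not be controlled independently.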
@@ -368,7 +368,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -382,9 +382,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -419,7 +419,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -598,18 +598,25 @@ def __init_subclass__(self): If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more information on the default strategy. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -710,11 +717,11 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? 
<../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -875,7 +882,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -912,18 +919,18 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -993,11 +1000,12 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
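# The zip-loop above validates both masks with one code path: each must carry one row per
# decoder layer. A standalone restatement of that contract (hypothetical helper variables,
# not part of the patch; sizes illustrative):
import torch

decoder_layers, decoder_attention_heads = 2, 4
head_mask = torch.ones(decoder_layers, decoder_attention_heads)             # decoder self-attention
cross_attn_head_mask = torch.ones(decoder_layers, decoder_attention_heads)  # decoder cross-attention

for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
    if attn_mask is not None:
        assert attn_mask.size()[0] == decoder_layers, f"`{mask_name}` needs one row per decoder layer"

# In the layer loop that follows, row `idx` is sliced out and handed to the decoder layer
# as `layer_head_mask` / `cross_attn_layer_head_mask`, each of shape (decoder_attention_heads,):
for idx in range(decoder_layers):
    assert cross_attn_head_mask[idx].shape == (decoder_attention_heads,)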
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -1031,7 +1039,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1042,7 +1050,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1123,6 +1133,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1172,7 +1183,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1248,6 +1259,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1282,6 +1294,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1386,6 +1399,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, inputs_embeds=None, decoder_inputs_embeds=None, @@ -1416,6 +1430,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1496,6 +1511,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, start_positions=None, end_positions=None, @@ -1527,6 +1543,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1633,7 +1650,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -1666,18 +1683,17 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. 
Mask values selected in ``[0, 1]``: - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -1734,7 +1750,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index d54461b107ce..dc18993c52fe 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -298,7 +298,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -371,7 +371,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -385,9 +385,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. 
+ cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -423,7 +423,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -554,18 +554,25 @@ def dummy_inputs(self): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -666,11 +673,11 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -834,7 +841,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -871,18 +878,19 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? 
<../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, + 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -951,11 +959,12 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
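# The decoder also runs standalone as a causal LM with cross-attention over precomputed
# encoder states, where `cross_attn_head_mask` likewise replaces the old `encoder_head_mask`
# (see the BlenderbotForCausalLM hunks below). A minimal sketch, assuming the post-patch
# `transformers` API (tiny random config; all sizes illustrative):
import torch
from transformers import BlenderbotConfig, BlenderbotForCausalLM

config = BlenderbotConfig(
    vocab_size=100, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BlenderbotForCausalLM(config)  # the wrapper flips config.is_decoder internally

input_ids = torch.randint(0, 100, (1, 5))
encoder_hidden_states = torch.randn(1, 7, config.d_model)  # e.g. produced by a separate encoder

cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask[:, 0] = 0.0  # silence head 0 of every layer's cross-attention

outputs = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
)
# outputs.cross_attentions[i][:, 0] is all zeros for every layer i; self-attention is untouched.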
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -989,7 +998,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1000,7 +1009,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1090,6 +1101,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1147,7 +1159,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1241,6 +1253,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1275,6 +1288,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1395,7 +1409,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -1428,18 +1442,17 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. 
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -1496,7 +1509,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index aae0eb0c7e11..167f9a9bac7b 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -296,7 +296,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -369,7 +369,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -383,9 +383,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -420,7 +420,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -555,18 +555,25 @@ def dummy_inputs(self): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. 
- head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -667,11 +674,11 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -834,7 +841,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -871,18 +878,18 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. 
+ - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -953,10 +960,12 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -990,7 +999,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1001,7 +1010,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1077,6 +1088,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1134,7 +1146,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1216,6 +1228,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1250,6 +1263,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1370,7 +1384,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -1403,18 +1417,17 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input.
This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -1471,7 +1484,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 8747babe4a63..8979cc7e549a 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -248,17 +248,25 @@ decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a @@ -573,7 +581,7 @@ def forward( layer_state=None, causal_mask=None, layer_head_mask=None, - encoder_layer_head_mask=None, + cross_attn_layer_head_mask=None, decoder_padding_mask=None, output_attentions=False, ): @@ -604,7 +612,7 @@ def forward( key=encoder_hidden_states, key_padding_mask=encoder_attn_mask, layer_state=layer_state, # mutates layer state - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, output_attentions=output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) @@ -666,7 +674,7 @@ def forward( decoder_padding_mask, decoder_causal_mask, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=False, output_attentions=False, @@ -690,12 +698,11 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. Returns: BaseModelOutputWithPast or tuple: @@ -732,10 +739,11 @@ def forward( next_decoder_cache = [] # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
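# FSMT now takes the same argument; a usage sketch, assuming the post-patch API and the
# public WMT19 checkpoint (downloads weights; mask pattern and start token are illustrative):
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

model = FSMTForConditionalGeneration.from_pretrained("facebook/wmt19-en-de")
tokenizer = FSMTTokenizer.from_pretrained("facebook/wmt19-en-de")

inputs = tokenizer("Machine learning is great!", return_tensors="pt")
cross_attn_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)
cross_attn_head_mask[-1] = 0.0  # silence every cross-attention head in the last decoder layer

outputs = model(
    **inputs,
    decoder_input_ids=torch.tensor([[model.config.pad_token_id]]),
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
)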
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -756,7 +764,7 @@ def forward( layer_state=layer_state, causal_mask=decoder_causal_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), output_attentions=output_attentions, ) @@ -1009,6 +1017,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs: Optional[Tuple] = None, past_key_values=None, use_cache=None, @@ -1065,7 +1074,7 @@ def forward( decoder_padding_mask, decoder_causal_mask=causal_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, @@ -1143,6 +1152,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, labels=None, @@ -1173,6 +1183,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index eecfcc27f60f..b245c3250b50 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -901,7 +901,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. """ residual = hidden_states attn_outputs = self.self_attn( @@ -968,7 +968,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -982,9 +982,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(decoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function.
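The same decoupling reaches LED layer by layer through `cross_attn_layer_head_mask`, documented in the hunk above. A layer-level sketch, assuming the post-patch API (it pulls the internal `LEDDecoderLayer` class purely for illustration; tiny random config, sizes illustrative):

import torch
from transformers import LEDConfig
from transformers.models.led.modeling_led import LEDDecoderLayer

config = LEDConfig(
    vocab_size=100, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    attention_window=4,
)
layer = LEDDecoderLayer(config)

hidden_states = torch.randn(1, 5, config.d_model)          # decoder states
encoder_hidden_states = torch.randn(1, 7, config.d_model)  # states the layer cross-attends to

outputs = layer(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    layer_head_mask=torch.ones(config.decoder_attention_heads),              # keep all self-attention heads
    cross_attn_layer_head_mask=torch.zeros(config.decoder_attention_heads),  # mask all cross-attention heads
    output_attentions=True,
)
hidden, self_attn_weights, cross_attn_weights = outputs[:3]
# cross_attn_weights is all zeros while self_attn_weights is not: the two masks act independently.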
@@ -1018,7 +1018,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -1199,17 +1199,6 @@ class LEDSeq2SeqModelOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. - - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. """ last_hidden_state: torch.FloatTensor = None @@ -1221,8 +1210,6 @@ class LEDSeq2SeqModelOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None - head_mask: Optional[torch.FloatTensor] = None - decoder_head_mask: Optional[torch.FloatTensor] = None @dataclass @@ -1278,17 +1265,6 @@ class LEDSeq2SeqLMOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. - - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. """ loss: Optional[torch.FloatTensor] = None @@ -1301,8 +1277,6 @@ class LEDSeq2SeqLMOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None - head_mask: Optional[torch.FloatTensor] = None - decoder_head_mask: Optional[torch.FloatTensor] = None @dataclass @@ -1358,17 +1332,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. 
- - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. """ loss: Optional[torch.FloatTensor] = None @@ -1381,8 +1344,6 @@ class LEDSeq2SeqSequenceClassifierOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None - head_mask: Optional[torch.FloatTensor] = None - decoder_head_mask: Optional[torch.FloatTensor] = None @dataclass @@ -1440,17 +1401,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): Global attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Those are the attention weights from every token with global attention to every token in the sequence. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. - - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. """ loss: Optional[torch.FloatTensor] = None @@ -1464,8 +1414,6 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_global_attentions: Optional[Tuple[torch.FloatTensor]] = None - head_mask: Optional[torch.FloatTensor] = None - decoder_head_mask: Optional[torch.FloatTensor] = None LED_START_DOCSTRING = r""" @@ -1547,17 +1495,24 @@ class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput): - 0 for local attention (a sliding window attention), - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -1730,7 +1685,7 @@ def forward( - 0 for local attention (a sliding window attention), - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, @@ -1914,7 +1869,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -1961,18 +1916,17 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the heas is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -2052,11 +2006,12 @@ def forward( all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
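# End-to-end LED usage with the new argument; LED additionally takes `global_attention_mask`
# for its long-document encoder. A sketch, assuming the post-patch API (tiny random config;
# sizes illustrative, input length kept a multiple of the attention window):
import torch
from transformers import LEDConfig, LEDForConditionalGeneration

config = LEDConfig(
    vocab_size=100, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    attention_window=4,
)
model = LEDForConditionalGeneration(config)

input_ids = torch.randint(0, 100, (1, 8))
decoder_input_ids = torch.randint(0, 100, (1, 4))
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1  # give the first token global attention

cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
cross_attn_head_mask[0] = 0.0  # switch off all cross-attention heads of decoder layer 0

outputs = model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    global_attention_mask=global_attention_mask,
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,
)
# outputs.cross_attentions[0] is all zeros; layer 1 is unaffected.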
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -2090,7 +2045,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -2100,7 +2055,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -2180,6 +2137,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, global_attention_mask=None, past_key_values=None, @@ -2224,7 +2182,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -2306,6 +2264,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, global_attention_mask=None, past_key_values=None, @@ -2358,6 +2317,7 @@ def forward( global_attention_mask=global_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -2463,6 +2423,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, global_attention_mask=None, inputs_embeds=None, @@ -2495,6 +2456,7 @@ def forward( global_attention_mask=global_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -2571,6 +2533,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, global_attention_mask=None, start_positions=None, @@ -2604,6 +2567,7 @@ def forward( global_attention_mask=global_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 4e9ddb6e6b37..7fccb32c4d43 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -367,7 +367,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -440,7 +440,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -454,9 +454,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -492,7 +492,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -603,6 +603,24 @@ def _init_weights(self, module): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
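For readers tracing what the renamed `cross_attn_layer_head_mask` ultimately does, the following is a rough sketch of the per-head multiplication these attention modules apply to their attention weights (shapes follow the docstring above; this is a simplification, not the full attention implementation):

    import torch

    bsz, num_heads, tgt_len, src_len = 2, 16, 5, 7
    attn_weights = torch.softmax(torch.randn(bsz * num_heads, tgt_len, src_len), dim=-1)

    cross_attn_layer_head_mask = torch.ones(num_heads)  # (decoder_attention_heads,)
    cross_attn_layer_head_mask[0] = 0.0                 # nullify cross-attention head 0

    # broadcast the 1-D head mask over the batch and sequence dimensions
    attn_weights = cross_attn_layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
        bsz, num_heads, tgt_len, src_len
    )
    assert attn_weights[:, 0].abs().sum() == 0  # head 0 now contributes nothing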
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -704,6 +722,12 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -841,7 +865,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -878,6 +902,19 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -955,11 +992,12 @@ def forward( all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
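With the M2M_100 changes above, all three masks can be passed through the model's top-level forward. A minimal usage sketch with a tiny, randomly initialized config (the hyperparameters are illustrative, chosen only to keep the example fast):

    import torch
    from transformers import M2M100Config, M2M100ForConditionalGeneration

    config = M2M100Config(
        vocab_size=100, d_model=16, encoder_layers=2, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=32, decoder_ffn_dim=32,
    )
    model = M2M100ForConditionalGeneration(config).eval()

    outputs = model(
        input_ids=torch.randint(0, 100, (1, 6)),
        decoder_input_ids=torch.randint(0, 100, (1, 4)),
        head_mask=torch.ones(config.encoder_layers, config.encoder_attention_heads),
        decoder_head_mask=torch.ones(config.decoder_layers, config.decoder_attention_heads),
        cross_attn_head_mask=torch.ones(config.decoder_layers, config.decoder_attention_heads),
    )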
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -993,7 +1031,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1004,7 +1042,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1085,6 +1125,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1126,7 +1167,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1201,6 +1242,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1249,6 +1291,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1281,7 +1324,14 @@ def forward( ) def prepare_inputs_for_generation( - self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, ): # cut decoder_input_ids if past is used if past is not None: @@ -1293,6 +1343,7 @@ def prepare_inputs_for_generation( "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 5535b85aee35..b393697e6218 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -313,7 +313,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. 
@@ -386,7 +386,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -400,9 +400,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -437,7 +437,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -567,18 +567,25 @@ def dummy_inputs(self): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -678,11 +685,11 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? 
<../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -842,7 +849,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -879,18 +886,18 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -959,11 +966,12 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
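Downstream of this check, the decoder loop that follows slices each 2-D mask once per layer; the sketch below isolates that convention with toy sizes:

    import torch

    decoder_layers, decoder_attention_heads = 3, 4
    cross_attn_head_mask = torch.ones(decoder_layers, decoder_attention_heads)

    for idx in range(decoder_layers):
        cross_attn_layer_head_mask = (
            cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
        )
        # each decoder layer receives a 1-D slice of shape (decoder_attention_heads,)
        assert cross_attn_layer_head_mask.shape == (decoder_attention_heads,)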
for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -997,7 +1005,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1008,7 +1016,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1084,6 +1094,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1142,7 +1153,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1229,6 +1240,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1264,6 +1276,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1391,7 +1404,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -1424,18 +1437,17 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. 
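The `MarianForCausalLM` docstring above means the standalone decoder also accepts `cross_attn_head_mask` when it cross-attends into externally supplied hidden states. A usage sketch with a tiny random config (all sizes illustrative, assuming the post-patch signature):

    import torch
    from transformers import MarianConfig, MarianForCausalLM

    config = MarianConfig(
        vocab_size=100, d_model=16, decoder_layers=2,
        decoder_attention_heads=4, decoder_ffn_dim=32,
    )
    model = MarianForCausalLM(config).eval()

    out = model(
        input_ids=torch.randint(0, 100, (1, 4)),
        encoder_hidden_states=torch.randn(1, 6, config.d_model),  # states to cross-attend
        cross_attn_head_mask=torch.ones(config.decoder_layers, config.decoder_attention_heads),
    )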
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -1492,7 +1504,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 51ad5500036b..e9719795d37c 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -303,7 +303,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -375,7 +375,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -389,9 +389,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -427,7 +427,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -595,18 +595,25 @@ def dummy_inputs(self): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. 
- head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -708,11 +715,11 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -877,7 +884,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -914,18 +921,18 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. 
+ - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -995,11 +1002,12 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -1033,7 +1041,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1044,7 +1052,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1127,6 +1137,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1173,7 +1184,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1254,6 +1265,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1287,6 +1299,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1384,6 +1397,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, inputs_embeds=None, decoder_inputs_embeds=None, @@ -1414,6 +1428,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, +
cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1495,6 +1510,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, start_positions=None, end_positions=None, @@ -1526,6 +1542,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1634,7 +1651,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -1667,18 +1684,17 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -1735,7 +1751,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index b21fac086318..096150e08ee7 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -313,7 +313,7 @@ def forward( attention_mask (:obj:`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. 
output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -386,7 +386,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -400,9 +400,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - `(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size `(config.encoder_attention_heads,)`. + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -438,7 +438,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -566,18 +566,25 @@ def dummy_inputs(self): decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will also be used by default. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, + 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
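Because the self-attention and cross-attention masks are now separate arguments, the two attention types can be ablated independently, which the old `encoder_head_mask` wiring could not express. A small sketch of building the tensors (sizes illustrative):

    import torch

    decoder_layers, decoder_attention_heads = 4, 8

    # keep every decoder self-attention head active ...
    decoder_head_mask = torch.ones(decoder_layers, decoder_attention_heads)

    # ... but silence cross-attention heads 0 and 3 in every decoder layer
    cross_attn_head_mask = torch.ones(decoder_layers, decoder_attention_heads)
    cross_attn_head_mask[:, [0, 3]] = 0.0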
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -679,11 +686,11 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded @@ -848,7 +855,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -885,18 +892,18 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -965,11 +972,12 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -1003,7 +1011,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1014,7 +1022,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1092,6 +1102,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1150,7 +1161,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1232,6 +1243,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, @@ -1267,6 +1279,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -1390,7 +1403,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, labels=None, @@ -1423,18 +1436,17 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads.
Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -1491,7 +1503,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, - encoder_head_mask=encoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 51de32aeb093..4b61f5ae7dd3 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -451,7 +451,7 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - encoder_layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -465,9 +465,9 @@ def forward( encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size - :obj:`(config.encoder_attention_heads,)`. - encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of - size :obj:`(config.encoder_attention_heads,)`. + :obj:`(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -503,7 +503,7 @@ def forward( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) @@ -623,19 +623,29 @@ def _get_subsampled_encoder_attn_mask(self, attention_mask): :obj:`past_key_values`). decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will - also be used by default. - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + also be used by default. 
+ + If you want to change padding behavior, you should read + :func:`modeling_speech_to_text._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the + paper `__ for more information on the default strategy. + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, @@ -728,11 +738,11 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under @@ -884,7 +894,7 @@ def forward( encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, use_cache=None, @@ -921,18 +931,18 @@ def forward( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ - head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**. - encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention - on hidden heads. Mask values selected in ``[0, 1]``: + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - - 0 indicates the heas is **masked**. + - 0 indicates the head is **masked**.
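A quick way to confirm that `cross_attn_head_mask` now reaches the cross-attention modules (rather than the encoder's `head_mask`, as the old wiring had it) is to inspect the returned `cross_attentions`. The sketch below uses a tiny, randomly initialized BART-style model purely for illustration, in the spirit of the `test_headmasking` updates listed in this patch:

    import torch
    from transformers import BartConfig, BartForConditionalGeneration

    config = BartConfig(
        vocab_size=100, d_model=16, encoder_layers=2, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=32, decoder_ffn_dim=32,
    )
    model = BartForConditionalGeneration(config).eval()

    cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
    cross_attn_head_mask[0, 1] = 0.0  # silence head 1 of decoder layer 0 only

    out = model(
        input_ids=torch.randint(0, 100, (1, 6)),
        decoder_input_ids=torch.randint(0, 100, (1, 4)),
        cross_attn_head_mask=cross_attn_head_mask,
        output_attentions=True,
    )
    # cross_attentions[layer] has shape (batch, num_heads, tgt_len, src_len);
    # the masked head's weights are exactly zero, the other heads are untouched
    assert torch.all(out.cross_attentions[0][:, 1] == 0)
    assert not torch.all(out.cross_attentions[1][:, 1] == 0)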
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up @@ -1001,12 +1011,12 @@ def forward( all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - assert head_mask.size()[0] == ( - len(self.layers) - ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." - + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}." for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: @@ -1039,7 +1049,7 @@ def custom_forward(*inputs): encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, - encoder_head_mask[idx] if encoder_head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, ) else: @@ -1050,7 +1060,9 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1127,6 +1139,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, decoder_inputs_embeds=None, @@ -1166,7 +1179,7 @@ def forward( encoder_hidden_states=encoder_outputs[0], encoder_attention_mask=attention_mask, head_mask=decoder_head_mask, - encoder_head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, @@ -1240,6 +1253,7 @@ def forward( decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, + cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, decoder_inputs_embeds=None, @@ -1296,6 +1310,7 @@ def forward( decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 9fe45d4e5dfd..013f291c5ba0 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -607,7 +607,7 @@ def forward( encoder_attention_mask=None, encoder_decoder_position_bias=None, layer_head_mask=None, - encoder_layer_head_mask=None, +
cross_attn_layer_head_mask=None, past_key_value=None, use_cache=False, output_attentions=False, @@ -661,7 +661,7 @@ def forward( key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, - layer_head_mask=encoder_layer_head_mask, + layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, query_length=query_length, use_cache=use_cache, @@ -846,7 +846,7 @@ def forward( encoder_attention_mask=None, inputs_embeds=None, head_mask=None, - encoder_head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -913,7 +913,7 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) - encoder_head_mask = self.get_head_mask(encoder_head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -925,7 +925,7 @@ def forward( for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): layer_head_mask = head_mask[i] - encoder_layer_head_mask = encoder_head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) @@ -942,8 +942,8 @@ def forward( encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) if layer_head_mask is not None: layer_head_mask = layer_head_mask.to(hidden_states.device) - if encoder_layer_head_mask is not None: - encoder_layer_head_mask = encoder_layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -955,7 +955,7 @@ def forward( encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, - encoder_layer_head_mask=encoder_layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, @@ -1082,12 +1082,19 @@ def forward( - 0 indicates the head is **masked**. decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): - Mask to nullify selected heads of the self-attention modules. in the decoder Mask values selected in ``[0, + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0, 1]``: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
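T5 documents both a `(num_heads,)` and a `(num_layers, num_heads)` shape because `get_head_mask` broadcasts the 1-D form across layers before the per-layer indexing shown above. The following is an illustrative re-implementation of that expansion under the same convention, not the actual helper from `modeling_utils`:

    import torch

    def expand_head_mask(head_mask, num_layers):
        # hypothetical helper mimicking how a head mask is broadcast across layers
        if head_mask is None:
            return [None] * num_layers          # the layer loop then receives None
        if head_mask.dim() == 1:                # (num_heads,) -> repeated for every layer
            head_mask = head_mask.unsqueeze(0).expand(num_layers, -1)
        return list(head_mask)                  # one (num_heads,) row per layer

    per_layer = expand_head_mask(torch.ones(8), num_layers=6)
    assert len(per_layer) == 6 and per_layer[0].shape == (8,)
    assert expand_head_mask(None, num_layers=6) == [None] * 6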
@@ -1263,6 +1270,7 @@ def forward(
         decoder_attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
+        cross_attn_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
         inputs_embeds=None,
@@ -1338,7 +1346,7 @@ def forward(
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
             head_mask=decoder_head_mask,
-            encoder_head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1451,6 +1459,7 @@ def forward(
         decoder_attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
+        cross_attn_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
         inputs_embeds=None,
@@ -1551,7 +1560,7 @@ def forward(
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
             head_mask=decoder_head_mask,
-            encoder_head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index 005328b06d6c..1e6f833a21f0 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -1041,10 +1041,11 @@ def forward(
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
-        head_mask=None,
         inputs_embeds=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
         past_key_values=None,
         labels=None,
         use_cache=None,
@@ -1876,7 +1877,7 @@ def forward(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
-        encoder_layer_head_mask: Optional[torch.Tensor] = None,
+        cross_layer_head_mask: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
@@ -1890,9 +1891,9 @@ def forward(
             encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(config.encoder_attention_heads,)`.
-            encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
-                size `(config.encoder_attention_heads,)`.
+                `(encoder_attention_heads,)`.
+            cross_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
             past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (:obj:`bool`, `optional`):
                 Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
@@ -1927,7 +1928,7 @@ def forward(
             hidden_states=hidden_states,
             key_value_states=encoder_hidden_states,
             attention_mask=encoder_attention_mask,
-            layer_head_mask=encoder_layer_head_mask,
+            layer_head_mask=cross_layer_head_mask,
             past_key_value=cross_attn_past_key_value,
             output_attentions=output_attentions,
         )
@@ -2070,18 +2071,24 @@ def dummy_inputs(self):
             If you want to change padding behavior, you should read
             :func:`modeling_{{cookiecutter.lowercase_modelname}}._prepare_decoder_inputs` and modify to your needs. See
             diagram 1 in `the paper `__ for more information on the default strategy.
-        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
             Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0,
             1]``:
 
             - 1 indicates the head is **not masked**,
-            - 0 indicates the heas is **masked**.
+            - 0 indicates the head is **masked**.
 
-        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
             Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0,
             1]``:
 
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
+        cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
             :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -2211,10 +2218,11 @@ def forward(
             - 0 for tokens that are **masked**.
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
             Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
+
             - 1 indicates the head is **not masked**,
-            - 0 indicates the heas is **masked**.
+            - 0 indicates the head is **masked**.
 
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
@@ -2377,7 +2385,7 @@ def forward(
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         head_mask=None,
-        encoder_head_mask=None,
+        cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
         use_cache=None,
@@ -2414,18 +2422,17 @@ def forward(
                 - 0 for tokens that are **masked**.
 
                 `What are attention masks? <../glossary.html#attention-mask>`__
-            head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
                 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 
                 - 1 indicates the head is **not masked**,
-                - 0 indicates the heas is **masked**.
+                - 0 indicates the head is **masked**.
-            encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-                on hidden heads. Mask values selected in ``[0, 1]``:
+            cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
 
                 - 1 indicates the head is **not masked**,
-                - 0 indicates the heas is **masked**.
+                - 0 indicates the head is **masked**.
 
             past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -2493,12 +2500,12 @@ def forward(
         all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
         next_decoder_cache = () if use_cache else None
 
-        # check if head_mask has a correct number of layers specified if desired
-        if head_mask is not None:
-            assert head_mask.size()[0] == (
-                len(self.layers)
-            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
-
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                assert attn_mask.size()[0] == (
+                    len(self.layers)
+                ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
@@ -2529,7 +2536,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states,
                     encoder_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
-                    encoder_head_mask[idx] if encoder_head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
                 )
             else:
@@ -2540,7 +2547,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_attention_mask,
                     layer_head_mask=(head_mask[idx] if head_mask is not None else None),
-                    encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
+                    cross_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                     past_key_value=past_key_value,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
@@ -2621,6 +2628,7 @@ def forward(
         decoder_attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
+        cross_attn_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
         inputs_embeds=None,
@@ -2662,7 +2670,7 @@ def forward(
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=attention_mask,
             head_mask=decoder_head_mask,
-            encoder_head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
@@ -2743,6 +2751,7 @@ def forward(
         decoder_attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
+        cross_attn_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
         inputs_embeds=None,
@@ -2791,6 +2800,7 @@ def forward(
             decoder_attention_mask=decoder_attention_mask,
             head_mask=head_mask,
             decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
@@ -3124,7 +3134,7 @@ def forward(
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         head_mask=None,
-        encoder_head_mask=None,
+        cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
         labels=None,
@@ -3157,18 +3167,17 @@ def forward(
             encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is
                 used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
-            head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
+            head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
                 Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
 
                 - 1 indicates the head is **not masked**,
-                - 0 indicates the heas is **masked**.
+                - 0 indicates the head is **masked**.
 
-            encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
-                on hidden heads. Mask values selected in ``[0, 1]``:
+            cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
 
                 - 1 indicates the head is **not masked**,
-                - 0 indicates the heas is **masked**.
+                - 0 indicates the head is **masked**.
 
             past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
@@ -3225,7 +3234,7 @@ def forward(
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             head_mask=head_mask,
-            encoder_head_mask=encoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 33ccbfaa80e1..b8847efdc900 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -55,6 +55,7 @@ def prepare_bart_inputs_dict(
     decoder_attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -64,6 +65,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
@@ -71,6 +74,7 @@
         "decoder_attention_mask": attention_mask,
         "head_mask": head_mask,
         "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
 
diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py
index bff8f1ee0e19..dfaa3cdc0a01 100644
--- a/tests/test_modeling_blenderbot.py
+++ b/tests/test_modeling_blenderbot.py
@@ -45,6 +45,7 @@ def prepare_blenderbot_inputs_dict(
     decoder_attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -54,6 +55,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
@@ -61,6 +64,7 @@
         "decoder_attention_mask": attention_mask,
         "head_mask": head_mask,
         "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
 
diff --git a/tests/test_modeling_blenderbot_small.py b/tests/test_modeling_blenderbot_small.py
index 0eea63c8c6d9..f5dc8c42076a 100644
--- a/tests/test_modeling_blenderbot_small.py
+++ b/tests/test_modeling_blenderbot_small.py
@@ -50,6 +50,7 @@ def prepare_blenderbot_small_inputs_dict(
     decoder_attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -59,6 +60,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
@@ -66,6 +69,7 @@
         "decoder_attention_mask": attention_mask,
         "head_mask": head_mask,
         "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
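Each decoder above now validates both masks with the same loop. A standalone sketch of that pattern (the `check_head_masks` helper is hypothetical, written only to mirror the patched code; note the error message reports the offending mask's own first dimension):

    import torch

    def check_head_masks(num_layers, head_mask=None, cross_attn_head_mask=None):
        # Every provided mask must carry one row of head values per decoder layer.
        for attn_mask, mask_name in zip(
            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
        ):
            if attn_mask is not None:
                assert attn_mask.size()[0] == num_layers, (
                    f"The `{mask_name}` should be specified for {num_layers} layers, "
                    f"but it is for {attn_mask.size()[0]}."
                )

    check_head_masks(6, head_mask=torch.ones(6, 8), cross_attn_head_mask=torch.ones(6, 8))  # passes
    check_head_masks(6, cross_attn_head_mask=torch.ones(2, 8))  # raises AssertionError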
"decoder_input_ids": decoder_input_ids, @@ -66,6 +69,7 @@ def prepare_blenderbot_small_inputs_dict( "decoder_attention_mask": attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, } diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 22b0d6609a3b..b82a8c56641d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -225,8 +225,8 @@ def test_forward_signature(self): "decoder_attention_mask", ] expected_arg_names.extend( - ["head_mask", "decoder_head_mask", "encoder_outputs"] - if "head_mask" and "decoder_head_mask" in arg_names + ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] + if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names else ["encoder_outputs"] ) self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) @@ -492,6 +492,8 @@ def test_headmasking(self): arg_names = [*signature.parameters.keys()] if "decoder_head_mask" in arg_names: # necessary diferentiation because of T5 model inputs["decoder_head_mask"] = head_mask + if "cross_attn_head_mask" in arg_names: + inputs["cross_attn_head_mask"] = head_mask outputs = model(**inputs, return_dict=True) # Test that we can get a gradient back for importance score computation @@ -523,6 +525,7 @@ def check_attentions_validity(attentions): if model.config.is_encoder_decoder: check_attentions_validity(outputs.encoder_attentions) check_attentions_validity(outputs.decoder_attentions) + check_attentions_validity(outputs.cross_attentions) else: check_attentions_validity(outputs.attentions) @@ -1093,7 +1096,7 @@ def test_multi_gpu_data_parallel_forward(self): # some params shouldn't be scattered by nn.DataParallel # so just remove them if they are present. 
-            blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
+            blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
             for k in blacklist_non_batched_params:
                 inputs_dict.pop(k, None)
 
diff --git a/tests/test_modeling_fsmt.py b/tests/test_modeling_fsmt.py
index 708ef1dc948e..4942fe7317cb 100644
--- a/tests/test_modeling_fsmt.py
+++ b/tests/test_modeling_fsmt.py
@@ -113,6 +113,7 @@ def prepare_fsmt_inputs_dict(
     attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -120,6 +121,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
diff --git a/tests/test_modeling_led.py b/tests/test_modeling_led.py
index caffe199bb2b..e507922762f1 100644
--- a/tests/test_modeling_led.py
+++ b/tests/test_modeling_led.py
@@ -52,6 +52,7 @@ def prepare_led_inputs_dict(
     decoder_attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -61,6 +62,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
@@ -68,6 +71,7 @@
         "decoder_attention_mask": decoder_attention_mask,
         "head_mask": head_mask,
         "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
 
diff --git a/tests/test_modeling_m2m_100.py b/tests/test_modeling_m2m_100.py
index db5aff1eb2a2..e39876e4ee7c 100644
--- a/tests/test_modeling_m2m_100.py
+++ b/tests/test_modeling_m2m_100.py
@@ -41,16 +41,28 @@ def prepare_m2m_100_inputs_dict(
     decoder_input_ids,
     attention_mask=None,
     decoder_attention_mask=None,
+    head_mask=None,
+    decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
         "decoder_attention_mask": attention_mask,
+        "head_mask": head_mask,
+        "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
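With head masking now enabled for M2M-100, `test_headmasking` can also assert that a zeroed cross-attention head yields all-zero attention weights. A self-contained sketch of that property using a randomly initialized toy config (all sizes are arbitrary):

    import torch
    from transformers import M2M100Config, M2M100Model

    config = M2M100Config(
        vocab_size=100, d_model=16,
        encoder_layers=2, decoder_layers=2,
        encoder_attention_heads=4, decoder_attention_heads=4,
        encoder_ffn_dim=32, decoder_ffn_dim=32,
    )
    model = M2M100Model(config).eval()

    input_ids = torch.tensor([[5, 6, 7, 2]])       # arbitrary ids ending with eos (id 2)
    decoder_input_ids = torch.tensor([[2, 5, 6]])  # starts with the decoder start token (id 2)

    cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
    cross_attn_head_mask[0, 0] = 0.0  # mask head 0 of decoder layer 0

    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        cross_attn_head_mask=cross_attn_head_mask,
        output_attentions=True,
    )
    # cross_attentions[layer] has shape (batch, num_heads, tgt_len, src_len);
    # the masked head's weights come back exactly zero.
    assert outputs.cross_attentions[0][:, 0].abs().sum().item() == 0.0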
@@ -142,9 +154,10 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["input_ids"]
         attention_mask = inputs_dict["attention_mask"]
+        head_mask = inputs_dict["head_mask"]
 
         # first forward pass
-        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
@@ -217,7 +230,6 @@ class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
     all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
     test_pruning = False
-    test_head_masking = False
     test_missing_keys = False
 
     def setUp(self):
diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py
index 191a48af8bca..7b6cb153065b 100644
--- a/tests/test_modeling_marian.py
+++ b/tests/test_modeling_marian.py
@@ -60,6 +60,7 @@ def prepare_marian_inputs_dict(
     decoder_attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -69,6 +70,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
@@ -76,6 +79,7 @@
         "decoder_attention_mask": attention_mask,
         "head_mask": head_mask,
         "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
 
diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py
index 9428acb479a5..e5baa4f30a7c 100644
--- a/tests/test_modeling_mbart.py
+++ b/tests/test_modeling_mbart.py
@@ -52,6 +52,7 @@ def prepare_mbart_inputs_dict(
     decoder_attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -61,6 +62,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
@@ -68,6 +71,7 @@
         "decoder_attention_mask": attention_mask,
         "head_mask": head_mask,
         "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
 
diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py
index c0418d8c8954..4106793332d6 100644
--- a/tests/test_modeling_pegasus.py
+++ b/tests/test_modeling_pegasus.py
@@ -42,6 +42,7 @@ def prepare_pegasus_inputs_dict(
     decoder_attention_mask=None,
     head_mask=None,
     decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_ids.ne(config.pad_token_id)
@@ -51,6 +52,8 @@
         head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
     if decoder_head_mask is None:
         decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         "input_ids": input_ids,
         "decoder_input_ids": decoder_input_ids,
@@ -58,6 +61,7 @@
         "decoder_attention_mask": attention_mask,
         "head_mask": head_mask,
         "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
 
diff --git a/tests/test_modeling_speech_to_text.py b/tests/test_modeling_speech_to_text.py
index c5b7db53c854..102a33f4a38f 100644
--- a/tests/test_modeling_speech_to_text.py
+++ b/tests/test_modeling_speech_to_text.py
@@ -55,17 +55,29 @@ def prepare_speech_to_text_inputs_dict(
     decoder_input_ids,
     attention_mask=None,
     decoder_attention_mask=None,
+    head_mask=None,
+    decoder_head_mask=None,
+    cross_attn_head_mask=None,
 ):
     if attention_mask is None:
         attention_mask = input_features.ne(0)
     if decoder_attention_mask is None:
         decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
+    if head_mask is None:
+        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
+    if decoder_head_mask is None:
+        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
+    if cross_attn_head_mask is None:
+        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
     return {
         # "input_ids": input_features,
         "input_features": input_features,
         "decoder_input_ids": decoder_input_ids,
         "attention_mask": attention_mask,
         "decoder_attention_mask": attention_mask,
+        "head_mask": head_mask,
+        "decoder_head_mask": decoder_head_mask,
+        "cross_attn_head_mask": cross_attn_head_mask,
     }
 
@@ -247,7 +259,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
     all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
     is_encoder_decoder = True
     test_pruning = False
-    test_head_masking = False
     test_missing_keys = False
     test_torchscript = True
 
@@ -316,8 +327,8 @@ def test_forward_signature(self):
                 "decoder_attention_mask",
             ]
             expected_arg_names.extend(
-                ["head_mask", "decoder_head_mask", "encoder_outputs"]
-                if "head_mask" and "decoder_head_mask" in arg_names
+                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
+                if all(mask in arg_names for mask in ["head_mask", "decoder_head_mask", "cross_attn_head_mask"])
                 else ["encoder_outputs"]
             )
             self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
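For reference, an end-to-end sketch of the new argument on a pretrained BART checkpoint (illustrative only; the checkpoint and input text are arbitrary, and `decoder_input_ids` are derived from `input_ids` automatically when omitted):

    import torch
    from transformers import BartForConditionalGeneration, BartTokenizer

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").eval()
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    inputs = tokenizer("UN Chief says there is no military solution in Syria", return_tensors="pt")

    # Shape (decoder_layers, decoder_attention_heads); zero entries disable heads.
    cross_attn_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)
    cross_attn_head_mask[0] = 0.0  # silence every cross-attention head in decoder layer 0

    outputs = model(**inputs, cross_attn_head_mask=cross_attn_head_mask, output_attentions=True)
    # Layer 0's cross-attention weights are fully masked, so this sums to zero.
    print(outputs.cross_attentions[0].abs().sum())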