diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index cdb7e53ef961..ad72b06bd0b3 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -164,6 +164,7 @@ def call( key_value_states: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, training=False, ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -230,6 +231,17 @@ def call( attn_weights = tf.nn.softmax(attn_weights, axis=-1) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + attn_probs = self.dropout(attn_weights, training=training) attn_output = tf.matmul(attn_probs, value_states) @@ -266,16 +278,18 @@ def __init__(self, config: BartConfig, **kwargs): self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): """ Args: hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` """ residual = hidden_states hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) tf.debugging.assert_equal( shape_list(hidden_states), @@ -331,6 +345,8 @@ def call( attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + encoder_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -342,6 +358,10 @@ def call( encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
+ layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size + `(encoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -354,6 +374,7 @@ def call( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -370,6 +391,7 @@ def call( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -527,6 +549,18 @@ def serving(self, inputs): the right for denoising pre-training following the paper. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -593,6 +627,7 @@ def call( input_ids=None, inputs_embeds=None, attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -617,6 +652,12 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 
This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -635,6 +676,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -670,8 +712,15 @@ def call( encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]: encoder_states = encoder_states + (hidden_states,) @@ -680,7 +729,11 @@ def call( if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, attention_mask) + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + ) if inputs["output_attentions"]: all_attentions += (attn,) @@ -737,6 +790,8 @@ def call( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -774,6 +829,19 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. 
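# --- Illustration (not part of the patch): per-layer slicing of `head_mask` ---
# A standalone sketch of the pattern the encoder/decoder loops above introduce:
# the 2-D head_mask is first validated against the layer count, then each layer
# receives one 1-D row as its `layer_head_mask`. Sizes below are hypothetical.
import tensorflow as tf

num_layers, num_heads = 6, 12                    # hypothetical encoder geometry
head_mask = tf.ones((num_layers, num_heads))     # 1.0 keeps a head, 0.0 prunes it

# Mirrors the tf.debugging.assert_equal guard added in the diff.
tf.debugging.assert_equal(
    tf.shape(head_mask)[0],
    num_layers,
    message=f"The head_mask should be specified for {num_layers} layers.",
)

for idx in range(num_layers):
    layer_head_mask = head_mask[idx]             # shape (num_heads,), passed to layer idx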
@@ -802,6 +870,8 @@ def call( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -858,6 +928,13 @@ def call( all_self_attns = () if inputs["output_attentions"] else None present_key_values = () if inputs["use_cache"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -875,6 +952,10 @@ def call( attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -945,6 +1026,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -963,6 +1046,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -993,6 +1078,7 @@ def call( inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1015,6 +1101,8 @@ def call( attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1067,6 +1155,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1085,6 +1175,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1102,6 +1194,8 @@ def call( attention_mask=inputs["attention_mask"], decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], 
inputs_embeds=inputs["inputs_embeds"], @@ -1179,6 +1273,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1207,6 +1303,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1233,6 +1331,8 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1277,7 +1377,15 @@ def serving_output(self, output): encoder_attentions=enc_attns, ) - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + head_mask=None, + use_cache=None, + **kwargs, + ) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" @@ -1309,6 +1417,7 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 350fa139589c..669c211c5601 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -167,6 +167,7 @@ def call( key_value_states: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, training=False, ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -233,6 +234,17 @@ def call( attn_weights = tf.nn.softmax(attn_weights, axis=-1) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + attn_probs = self.dropout(attn_weights, training=training) attn_output = tf.matmul(attn_probs, value_states) @@ -270,17 +282,19 @@ def __init__(self, config: BlenderbotConfig, **kwargs): self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + def call(self, hidden_states: 
tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): """ Args: hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) tf.debugging.assert_equal( shape_list(hidden_states), @@ -336,6 +350,8 @@ def call( attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + encoder_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -347,6 +363,10 @@ def call( encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size + `(encoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -360,6 +380,7 @@ def call( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -376,6 +397,7 @@ def call( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -524,6 +546,18 @@ def serving(self, inputs): :obj:`past_key_values`). decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -590,6 +624,7 @@ def call( input_ids=None, inputs_embeds=None, attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -614,6 +649,12 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -632,6 +673,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -666,8 +708,15 @@ def call( encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]: encoder_states = encoder_states + (hidden_states,) @@ -676,7 +725,11 @@ def call( if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, attention_mask) + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + ) if inputs["output_attentions"]: all_attentions += (attn,) @@ -735,6 +788,8 @@ def call( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -772,6 +827,19 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. 
+ past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -800,6 +868,8 @@ def call( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -855,6 +925,14 @@ def call( all_hidden_states = () all_self_attns = () present_key_values = () + + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -871,6 +949,10 @@ def call( attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -943,6 +1025,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -961,6 +1045,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -983,6 +1069,7 @@ def call( inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1005,6 +1092,8 @@ def call( attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1070,6 +1159,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1088,6 +1179,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1105,6 +1198,8 @@ def call( attention_mask=inputs["attention_mask"], 
decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1196,6 +1291,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1224,6 +1321,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1249,6 +1348,8 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1295,7 +1396,15 @@ def serving_output(self, output): ) # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + head_mask=None, + use_cache=None, + **kwargs, + ) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" @@ -1327,6 +1436,7 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 99ce57641357..710697f58adf 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -166,6 +166,7 @@ def call( key_value_states: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, training=False, ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -232,6 +233,17 @@ def call( attn_weights = tf.nn.softmax(attn_weights, axis=-1) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + attn_probs = self.dropout(attn_weights, training=training) 
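# --- Illustration (not part of the patch): how `layer_head_mask` broadcasts ---
# A standalone numeric sketch, with hypothetical sizes, of the reshape-and-multiply
# applied to the softmaxed attention weights above: the (num_heads,) mask is
# reshaped to (1, num_heads, 1, 1) and broadcast over batch and sequence dims.
import tensorflow as tf

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = tf.nn.softmax(
    tf.random.normal((bsz * num_heads, tgt_len, src_len)), axis=-1
)
layer_head_mask = tf.constant([1.0, 0.0, 1.0, 1.0])   # zero out the second head

masked = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
    attn_weights, (bsz, num_heads, tgt_len, src_len)
)
masked = tf.reshape(masked, (bsz * num_heads, tgt_len, src_len))
# Every row belonging to head index 1 is now all zeros, so that head
# contributes nothing to the subsequent tf.matmul with the value states.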
attn_output = tf.matmul(attn_probs, value_states) @@ -269,16 +281,18 @@ def __init__(self, config: BlenderbotSmallConfig, **kwargs): self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): """ Args: hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` """ residual = hidden_states hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) tf.debugging.assert_equal( shape_list(hidden_states), @@ -335,6 +349,8 @@ def call( attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + encoder_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -346,6 +362,10 @@ def call( encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size + `(encoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -358,6 +378,7 @@ def call( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -374,6 +395,7 @@ def call( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -529,6 +551,18 @@ def serving(self, inputs): :obj:`past_key_values`). decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. 
+ + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -595,6 +629,7 @@ def call( input_ids=None, inputs_embeds=None, attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -619,6 +654,12 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -637,6 +678,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -672,8 +714,15 @@ def call( encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]: encoder_states = encoder_states + (hidden_states,) @@ -682,7 +731,11 @@ def call( if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, attention_mask) + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + ) if inputs["output_attentions"]: all_attentions += (attn,) @@ -740,6 +793,8 @@ def call( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -777,6 +832,19 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. 
+ + encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -805,6 +873,8 @@ def call( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -859,6 +929,13 @@ def call( all_self_attns = () present_key_values = () + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -875,6 +952,10 @@ def call( attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -945,6 +1026,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -963,6 +1046,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -985,6 +1070,7 @@ def call( inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1007,6 +1093,8 @@ def call( attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1059,6 +1147,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1077,6 +1167,8 @@ def call( 
attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1094,6 +1186,8 @@ def call( attention_mask=inputs["attention_mask"], decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1172,6 +1266,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1200,6 +1296,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1225,6 +1323,8 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1271,7 +1371,15 @@ def serving_output(self, output): ) # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + head_mask=None, + use_cache=None, + **kwargs, + ) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" @@ -1303,6 +1411,7 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index d402a3503cb3..78fac03a5d39 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -196,6 +196,7 @@ def call( key_value_states: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, training=False, ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -262,6 +263,17 @@ def call( attn_weights = tf.nn.softmax(attn_weights, axis=-1) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = 
tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + attn_probs = self.dropout(attn_weights, training=training) attn_output = tf.matmul(attn_probs, value_states) @@ -299,16 +311,18 @@ def __init__(self, config: MarianConfig, **kwargs): self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): """ Args: hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` """ residual = hidden_states hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) tf.debugging.assert_equal( shape_list(hidden_states), @@ -365,6 +379,8 @@ def call( attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + encoder_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -376,6 +392,10 @@ def call( encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size + `(encoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -388,6 +408,7 @@ def call( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -404,6 +425,7 @@ def call( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -548,6 +570,18 @@ def serving(self, inputs): :obj:`past_key_values`). decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. 
Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -612,6 +646,7 @@ def call( input_ids=None, inputs_embeds=None, attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -636,6 +671,12 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -654,6 +695,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -688,8 +730,15 @@ def call( encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]: encoder_states = encoder_states + (hidden_states,) @@ -698,7 +747,11 @@ def call( if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, attention_mask) + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + ) if inputs["output_attentions"]: all_attentions += (attn,) @@ -753,6 +806,8 @@ def call( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -790,6 +845,19 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. 
+ + encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the heas is **masked**. + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -818,6 +886,8 @@ def call( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -872,6 +942,14 @@ def call( all_hidden_states = () all_self_attns = () present_key_values = () + + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -888,6 +966,10 @@ def call( attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -958,6 +1040,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -976,6 +1060,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1001,6 +1087,7 @@ def call( inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1023,6 +1110,8 @@ def call( attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1075,6 +1164,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1092,6 +1183,8 @@ def 
call( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, decoder_attention_mask=decoder_attention_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, @@ -1110,6 +1203,8 @@ def call( attention_mask=inputs["attention_mask"], decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1188,6 +1283,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1216,6 +1313,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1242,6 +1341,8 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1288,7 +1389,15 @@ def serving_output(self, output): ) # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + head_mask=None, + use_cache=None, + **kwargs, + ) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" @@ -1320,6 +1429,7 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index e5634303e58d..6f77248f70a5 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -170,6 +170,7 @@ def call( key_value_states: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, training=False, ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -236,6 +237,17 @@ def call( attn_weights = tf.nn.softmax(attn_weights, axis=-1) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = 
tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + attn_probs = self.dropout(attn_weights, training=training) attn_output = tf.matmul(attn_probs, value_states) @@ -272,17 +284,19 @@ def __init__(self, config: MBartConfig, **kwargs): self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): """ Args: hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) tf.debugging.assert_equal( shape_list(hidden_states), @@ -337,6 +351,8 @@ def call( attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + encoder_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -348,6 +364,10 @@ def call( encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size + `(encoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -361,6 +381,7 @@ def call( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -377,6 +398,7 @@ def call( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -505,6 +527,18 @@ def serving(self, inputs): the right for denoising pre-training following the paper. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. It is not recommended to set this for most use cases. 
+ head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -601,6 +635,7 @@ def call( input_ids=None, inputs_embeds=None, attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -625,6 +660,12 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -643,6 +684,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -678,8 +720,15 @@ def call( encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]: encoder_states = encoder_states + (hidden_states,) @@ -688,7 +737,11 @@ def call( if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, attention_mask) + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + ) if inputs["output_attentions"]: all_attentions += (attn,) @@ -748,6 +801,8 @@ def call( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -785,6 +840,19 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules.
Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -813,6 +881,8 @@ def call( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -868,6 +938,14 @@ def call( all_hidden_states = () all_self_attns = () present_key_values = () + + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -884,6 +962,10 @@ def call( attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -956,6 +1038,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -974,6 +1058,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1002,6 +1088,7 @@ def call( inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1024,6 +1111,8 @@ def call( attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1076,6 +1165,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1094,6 +1185,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1111,6 +1204,8 @@ def call( attention_mask=inputs["attention_mask"], decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1189,6 +1284,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1217,6 +1314,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1241,6 +1340,8 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1287,7 +1388,15 @@ def serving_output(self, output): ) # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + head_mask=None, + use_cache=None, + **kwargs, + ) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" @@ -1319,6 +1428,7 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index f3871a5d3676..b36e36c43713 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -197,6 +197,7 @@ def call( key_value_states: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None, attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, training=False, ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -263,6 +264,17 @@ def call( attn_weights = tf.nn.softmax(attn_weights, axis=-1) + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + 
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + attn_probs = self.dropout(attn_weights, training=training) attn_output = tf.matmul(attn_probs, value_states) @@ -300,17 +312,19 @@ def __init__(self, config: PegasusConfig, **kwargs): self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2") self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") - def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False): + def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False): """ Args: hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` attention_mask (:obj:`tf.Tensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)` """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, self_attn_weights, _ = self.self_attn( - hidden_states=hidden_states, attention_mask=attention_mask + hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) tf.debugging.assert_equal( shape_list(hidden_states), @@ -366,6 +380,8 @@ def call( attention_mask: Optional[tf.Tensor] = None, encoder_hidden_states: Optional[tf.Tensor] = None, encoder_attention_mask: Optional[tf.Tensor] = None, + layer_head_mask: Optional[tf.Tensor] = None, + encoder_layer_head_mask: Optional[tf.Tensor] = None, past_key_value: Optional[Tuple[tf.Tensor]] = None, training=False, ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]: @@ -377,6 +393,10 @@ def call( encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size + `(decoder_attention_heads,)` + encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size + `(encoder_attention_heads,)` past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states """ residual = hidden_states @@ -390,6 +410,7 @@ def call( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, + layer_head_mask=layer_head_mask, ) hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -406,6 +427,7 @@ def call( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, + layer_head_mask=encoder_layer_head_mask, past_key_value=cross_attn_past_key_value, ) hidden_states = self.dropout(hidden_states, training=training) @@ -553,6 +575,18 @@ def serving(self, inputs): the right for denoising pre-training following the paper. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): will be made by default and ignore pad tokens. 
It is not recommended to set this for most use cases. + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (:obj:`tf.FloatTensor`, `optional`): hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of @@ -618,6 +652,7 @@ def call( input_ids=None, inputs_embeds=None, attention_mask=None, + head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -642,6 +677,12 @@ def call( - 0 for tokens that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices @@ -660,6 +701,7 @@ def call( config=self.config, input_ids=input_ids, attention_mask=attention_mask, + head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -694,8 +736,15 @@ def call( encoder_states = () if inputs["output_hidden_states"] else None all_attentions = () if inputs["output_attentions"] else None + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) # encoder layers - for encoder_layer in self.layers: + for idx, encoder_layer in enumerate(self.layers): if inputs["output_hidden_states"]: encoder_states = encoder_states + (hidden_states,) @@ -704,7 +753,11 @@ def call( if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer continue - hidden_states, attn = encoder_layer(hidden_states, attention_mask) + hidden_states, attn = encoder_layer( + hidden_states, + attention_mask, + inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + ) if inputs["output_attentions"]: all_attentions += (attn,) @@ -762,6 +815,8 @@ def call( attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + head_mask=None, + encoder_head_mask=None, past_key_values=None, use_cache=None, output_attentions=None, @@ -799,6 +854,19 @@ def call( - 0 for tokens that are **masked**. `What are attention masks?
<../glossary.html#attention-mask>`__ + head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the encoder to avoid performing cross-attention + on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. @@ -827,6 +895,8 @@ def call( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_head_mask=encoder_head_mask, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, @@ -881,6 +951,14 @@ def call( all_hidden_states = () all_self_attns = () present_key_values = () + + # check if head_mask has a correct number of layers specified if desired + if inputs["head_mask"] is not None: + tf.debugging.assert_equal( + shape_list(inputs["head_mask"])[0], + len(self.layers), + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: @@ -897,6 +975,10 @@ def call( attention_mask=combined_attention_mask, encoder_hidden_states=inputs["encoder_hidden_states"], encoder_attention_mask=inputs["encoder_attention_mask"], + layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + encoder_layer_head_mask=inputs["encoder_head_mask"][idx] + if inputs["encoder_head_mask"] is not None + else None, past_key_value=past_key_value, ) @@ -969,6 +1051,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -987,6 +1071,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1012,6 +1098,7 @@ def call( inputs["encoder_outputs"] = self.encoder( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], + head_mask=inputs["head_mask"], inputs_embeds=inputs["inputs_embeds"], output_attentions=inputs["output_attentions"], output_hidden_states=inputs["output_hidden_states"], @@ -1034,6 +1121,8 @@ def call( attention_mask=inputs["decoder_attention_mask"], encoder_hidden_states=inputs["encoder_outputs"][0], encoder_attention_mask=inputs["attention_mask"], + head_mask=inputs["decoder_head_mask"], + encoder_head_mask=inputs["head_mask"], past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"], use_cache=inputs["use_cache"], @@ -1086,6 +1175,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None, past_key_values=None, inputs_embeds=None, @@ -1104,6 +1195,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1121,6 +1214,8 @@ def call( attention_mask=inputs["attention_mask"], decoder_input_ids=inputs["decoder_input_ids"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], encoder_outputs=inputs["encoder_outputs"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], @@ -1199,6 +1294,8 @@ def call( attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, encoder_outputs: Optional[TFBaseModelOutput] = None, past_key_values=None, inputs_embeds=None, @@ -1227,6 +1324,8 @@ def call( attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, @@ -1253,6 +1352,8 @@ def call( decoder_input_ids=inputs["decoder_input_ids"], encoder_outputs=inputs["encoder_outputs"], decoder_attention_mask=inputs["decoder_attention_mask"], + head_mask=inputs["head_mask"], + decoder_head_mask=inputs["decoder_head_mask"], past_key_values=inputs["past_key_values"], inputs_embeds=inputs["inputs_embeds"], decoder_inputs_embeds=inputs["decoder_inputs_embeds"], @@ -1299,7 +1400,15 @@ def serving_output(self, output): ) # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + head_mask=None, + use_cache=None, + **kwargs, + ) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" @@ -1331,6 +1440,7 @@ def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/tests/test_modeling_tf_albert.py b/tests/test_modeling_tf_albert.py index 5902c2b01fb3..d037738081b4 100644 --- a/tests/test_modeling_tf_albert.py +++ b/tests/test_modeling_tf_albert.py @@ -240,6 +240,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False def setUp(self): self.model_tester = TFAlbertModelTester(self) diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index 0f05bf76bdfe..6a0110d4a6d6 100644 --- a/tests/test_modeling_tf_bart.py +++ 
b/tests/test_modeling_tf_bart.py @@ -108,10 +108,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): input_ids = input_ids[:1, :] attention_mask = inputs_dict["attention_mask"][:1, :] + head_mask = inputs_dict["head_mask"] self.batch_size = 1 # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() past_key_values = past_key_values[1] @@ -144,6 +145,8 @@ def prepare_bart_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -155,11 +158,17 @@ def prepare_bart_inputs_dict( ], axis=-1, ) + if head_mask is None: + head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) + if decoder_head_mask is None: + decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -169,6 +178,7 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = True def setUp(self): self.model_tester = TFBartModelTester(self) diff --git a/tests/test_modeling_tf_bert.py b/tests/test_modeling_tf_bert.py index 7289a5d47332..6043f7926b0a 100644 --- a/tests/test_modeling_tf_bert.py +++ b/tests/test_modeling_tf_bert.py @@ -273,6 +273,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False # special case for ForPreTraining model def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py index 5ec061deceaa..c38cd66efb2b 100644 --- a/tests/test_modeling_tf_blenderbot.py +++ b/tests/test_modeling_tf_blenderbot.py @@ -107,10 +107,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): input_ids = input_ids[:1, :] attention_mask = inputs_dict["attention_mask"][:1, :] + head_mask = inputs_dict["head_mask"] self.batch_size = 1 # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() past_key_values = past_key_values[1] @@ -143,6 +144,8 @@ def prepare_blenderbot_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -154,11 +157,17 @@ def prepare_blenderbot_inputs_dict( ], axis=-1, ) + if head_mask is None: + head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) + if decoder_head_mask is None: + decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, +
"head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -168,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = True def setUp(self): self.model_tester = TFBlenderbotModelTester(self) diff --git a/tests/test_modeling_tf_blenderbot_small.py b/tests/test_modeling_tf_blenderbot_small.py index 8b657f9cb54d..8b322a7b7c87 100644 --- a/tests/test_modeling_tf_blenderbot_small.py +++ b/tests/test_modeling_tf_blenderbot_small.py @@ -107,10 +107,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): input_ids = input_ids[:1, :] attention_mask = inputs_dict["attention_mask"][:1, :] + head_mask = inputs_dict["head_mask"] self.batch_size = 1 # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() past_key_values = past_key_values[1] @@ -143,6 +144,8 @@ def prepare_blenderbot_small_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -154,11 +157,17 @@ def prepare_blenderbot_small_inputs_dict( ], axis=-1, ) + if head_mask is None: + head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) + if decoder_head_mask is None: + decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -170,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = True def setUp(self): self.model_tester = TFBlenderbotSmallModelTester(self) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index b41209f7db82..f41c4b146972 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -440,6 +440,11 @@ def test_pt_tf_model_equivalence(self): def test_train_pipeline_custom_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # head_mask and decoder_head_mask has different shapes than other input args + if "head_mask" in inputs_dict: + del inputs_dict["head_mask"] + if "decoder_head_mask" in inputs_dict: + del inputs_dict["decoder_head_mask"] tf_main_layer_classes = set( module_member for model_class in self.all_model_classes @@ -620,6 +625,75 @@ def check_encoder_attentions_output(outputs): self.assertEqual(model.config.output_hidden_states, True) check_encoder_attentions_output(outputs) + def test_headmasking(self): + if not self.test_head_masking: + return + + random.Random().seed(42) + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + random.Random().seed() + + inputs_dict["output_attentions"] = True + config.output_hidden_states = True + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + for model_class in 
self.all_model_classes: + model = model_class(config=configs_no_init) + + # Prepare head_mask + def prepare_layer_head_mask(i, attention_heads, num_hidden_layers): + if i == 0: + return tf.concat( + (tf.zeros(1, dtype=tf.float32), tf.ones(attention_heads - 1, dtype=tf.float32)), 0 + ) + elif i == num_hidden_layers - 1: + return tf.concat( + (tf.zeros(attention_heads - 1, dtype=tf.float32), tf.ones(1, dtype=tf.float32)), 0 + ) + else: + return tf.ones(attention_heads, dtype=tf.float32) + + head_mask = tf.stack( + [ + prepare_layer_head_mask(i, config.num_attention_heads, config.num_hidden_layers) + for i in range(config.num_hidden_layers) + ], + 0, + ) + + inputs = self._prepare_for_class(inputs_dict, model_class).copy() + inputs["head_mask"] = head_mask + if model.config.is_encoder_decoder: + signature = inspect.signature(model.call) + arg_names = [*signature.parameters.keys()] + if "decoder_head_mask" in arg_names: # necessary differentiation because of the T5 model + inputs["decoder_head_mask"] = head_mask + + outputs = model(**inputs, return_dict=True) + + def check_attentions_validity(attentions): + # Remove Nan + for t in attentions: + self.assertLess( + (tf.math.reduce_sum(tf.cast(tf.math.is_nan(t), tf.float32))).numpy(), (tf.size(t) / 4).numpy() + ) # Check we don't have more than 25% nans (arbitrary) + + attentions = [ + tf.where(tf.math.is_nan(t), 0.0, t) for t in attentions + ] # remove them (the test is less complete) + + self.assertAlmostEqual(tf.math.reduce_sum(attentions[0][..., 0, :, :]).numpy(), 0.0) + self.assertNotEqual(tf.math.reduce_sum(attentions[0][..., -1, :, :]).numpy(), 0.0) + if len(attentions) > 2: # encoder-decoder models have only 2 layers in each module + self.assertNotEqual(tf.math.reduce_sum(attentions[1][..., 0, :, :]).numpy(), 0.0) + self.assertAlmostEqual(tf.math.reduce_sum(attentions[-1][..., -2, :, :]).numpy(), 0.0) + self.assertNotEqual(tf.math.reduce_sum(attentions[-1][..., -1, :, :]).numpy(), 0.0) + + if model.config.is_encoder_decoder: + check_attentions_validity(outputs.encoder_attentions) + check_attentions_validity(outputs.decoder_attentions) + else: + check_attentions_validity(outputs.attentions) + def test_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py index e2bcb69709f7..5664c968ce03 100644 --- a/tests/test_modeling_tf_ctrl.py +++ b/tests/test_modeling_tf_ctrl.py @@ -173,6 +173,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else () all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else () + test_head_masking = False def setUp(self): self.model_tester = TFCTRLModelTester(self) diff --git a/tests/test_modeling_tf_distilbert.py b/tests/test_modeling_tf_distilbert.py index 9d676173ba26..3c1b755ccc23 100644 --- a/tests/test_modeling_tf_distilbert.py +++ b/tests/test_modeling_tf_distilbert.py @@ -183,6 +183,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else None ) + test_head_masking = False def setUp(self): self.model_tester = TFDistilBertModelTester(self) diff --git a/tests/test_modeling_tf_electra.py b/tests/test_modeling_tf_electra.py index 30eb576e7328..6cdd0e0c6d8d 100644 --- a/tests/test_modeling_tf_electra.py +++ b/tests/test_modeling_tf_electra.py @@ -205,6 +205,7 @@ class TFElectraModelTest(TFModelTesterMixin,
unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False def setUp(self): self.model_tester = TFElectraModelTester(self) diff --git a/tests/test_modeling_tf_flaubert.py b/tests/test_modeling_tf_flaubert.py index 56eddaea6947..640dd9cfe863 100644 --- a/tests/test_modeling_tf_flaubert.py +++ b/tests/test_modeling_tf_flaubert.py @@ -291,6 +291,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = ( (TFFlaubertWithLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + test_head_masking = False def setUp(self): self.model_tester = TFFlaubertModelTester(self) diff --git a/tests/test_modeling_tf_funnel.py b/tests/test_modeling_tf_funnel.py index ab96f3143309..a530bd59a0e2 100644 --- a/tests/test_modeling_tf_funnel.py +++ b/tests/test_modeling_tf_funnel.py @@ -338,6 +338,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False def setUp(self): self.model_tester = TFFunnelModelTester(self) @@ -376,6 +377,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = ( (TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else () ) + test_head_masking = False def setUp(self): self.model_tester = TFFunnelModelTester(self, base=True) diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index cfdedae723e8..321f87b5e9e3 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -332,6 +332,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): else () ) all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else () + test_head_masking = False def setUp(self): self.model_tester = TFGPT2ModelTester(self) diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index 40a57f243323..c4f611be6a51 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -187,6 +187,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = False def setUp(self): self.model_tester = TFLEDModelTester(self) diff --git a/tests/test_modeling_tf_longformer.py b/tests/test_modeling_tf_longformer.py index c76b338e105b..b56002ed050c 100644 --- a/tests/test_modeling_tf_longformer.py +++ b/tests/test_modeling_tf_longformer.py @@ -297,6 +297,7 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False def setUp(self): self.model_tester = TFLongformerModelTester(self) diff --git a/tests/test_modeling_tf_lxmert.py b/tests/test_modeling_tf_lxmert.py index eceff0c48efc..496fb1df2671 100644 --- a/tests/test_modeling_tf_lxmert.py +++ b/tests/test_modeling_tf_lxmert.py @@ -361,6 +361,7 @@ def create_and_check_lxmert_for_pretraining( class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else () + test_head_masking = False def setUp(self): self.model_tester = TFLxmertModelTester(self) diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py index 569e43b25d13..917c5f5fb3dd 100644 --- a/tests/test_modeling_tf_marian.py +++ b/tests/test_modeling_tf_marian.py @@ -109,10 +109,11 @@ def 
check_decoder_model_past_large_inputs(self, config, inputs_dict): input_ids = input_ids[:1, :] attention_mask = inputs_dict["attention_mask"][:1, :] + head_mask = inputs_dict["head_mask"] self.batch_size = 1 # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() past_key_values = past_key_values[1] @@ -145,6 +146,8 @@ def prepare_marian_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -156,11 +159,17 @@ def prepare_marian_inputs_dict( ], axis=-1, ) + if head_mask is None: + head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) + if decoder_head_mask is None: + decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -170,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = True def setUp(self): self.model_tester = TFMarianModelTester(self) diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py index ffc98ee66317..a3b6cd416b2b 100644 --- a/tests/test_modeling_tf_mbart.py +++ b/tests/test_modeling_tf_mbart.py @@ -106,10 +106,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): input_ids = input_ids[:1, :] attention_mask = inputs_dict["attention_mask"][:1, :] + head_mask = inputs_dict["head_mask"] self.batch_size = 1 # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() past_key_values = past_key_values[1] @@ -147,6 +148,8 @@ def prepare_mbart_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -158,11 +161,17 @@ def prepare_mbart_inputs_dict( ], axis=-1, ) + if head_mask is None: + head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) + if decoder_head_mask is None: + decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -172,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = True def setUp(self): self.model_tester = TFMBartModelTester(self) diff --git a/tests/test_modeling_tf_mobilebert.py b/tests/test_modeling_tf_mobilebert.py index 8995e6cf9bb7..cf918f9b30d4 100644 ---
a/tests/test_modeling_tf_mobilebert.py +++ b/tests/test_modeling_tf_mobilebert.py @@ -55,6 +55,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False class TFMobileBertModelTester(object): def __init__( diff --git a/tests/test_modeling_tf_mpnet.py b/tests/test_modeling_tf_mpnet.py index d51b4e30b454..5aa66b527957 100644 --- a/tests/test_modeling_tf_mpnet.py +++ b/tests/test_modeling_tf_mpnet.py @@ -198,6 +198,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False def setUp(self): self.model_tester = TFMPNetModelTester(self) diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py index a7ffe33c5225..49689a0e4d46 100644 --- a/tests/test_modeling_tf_openai.py +++ b/tests/test_modeling_tf_openai.py @@ -202,6 +202,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = ( (TFOpenAIGPTLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly + test_head_masking = False def setUp(self): self.model_tester = TFOpenAIGPTModelTester(self) diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py index 995e8abb8ee2..63dc33604405 100644 --- a/tests/test_modeling_tf_pegasus.py +++ b/tests/test_modeling_tf_pegasus.py @@ -107,10 +107,11 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): input_ids = input_ids[:1, :] attention_mask = inputs_dict["attention_mask"][:1, :] + head_mask = inputs_dict["head_mask"] self.batch_size = 1 # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) output, past_key_values = outputs.to_tuple() past_key_values = past_key_values[1] @@ -143,6 +144,8 @@ def prepare_pegasus_inputs_dict( decoder_input_ids, attention_mask=None, decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, ): if attention_mask is None: attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8) @@ -154,11 +157,17 @@ def prepare_pegasus_inputs_dict( ], axis=-1, ) + if head_mask is None: + head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads)) + if decoder_head_mask is None: + decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads)) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, } @@ -168,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () is_encoder_decoder = True test_pruning = False + test_head_masking = True def setUp(self): self.model_tester = TFPegasusModelTester(self) diff --git a/tests/test_modeling_tf_roberta.py b/tests/test_modeling_tf_roberta.py index d628606503e7..66cb128c8ac5 100644 --- a/tests/test_modeling_tf_roberta.py +++ b/tests/test_modeling_tf_roberta.py @@ -185,6 +185,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): if is_tf_available() else () ) + test_head_masking = False def setUp(self): self.model_tester = TFRobertaModelTester(self) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 
f4d0f2ce19d4..f4a235f8feec 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -248,6 +248,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): is_encoder_decoder = True all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else () all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else () + test_head_masking = False def setUp(self): self.model_tester = TFT5ModelTester(self) @@ -417,6 +418,7 @@ def prepare_config_and_inputs_for_common(self): class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): is_encoder_decoder = False all_model_classes = (TFT5EncoderModel,) if is_tf_available() else () + test_head_masking = False def setUp(self): self.model_tester = TFT5EncoderOnlyModelTester(self) diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index 14d3e5a3acf7..64cc315c37f9 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -163,6 +163,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = () if is_tf_available() else () # TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented test_resize_embeddings = False + test_head_masking = False def setUp(self): self.model_tester = TFTransfoXLModelTester(self) diff --git a/tests/test_modeling_tf_xlm.py b/tests/test_modeling_tf_xlm.py index dad476394500..c6a2d6c94a5c 100644 --- a/tests/test_modeling_tf_xlm.py +++ b/tests/test_modeling_tf_xlm.py @@ -293,6 +293,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = ( (TFXLMWithLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + test_head_masking = False def setUp(self): self.model_tester = TFXLMModelTester(self) diff --git a/tests/test_modeling_tf_xlnet.py b/tests/test_modeling_tf_xlnet.py index bf0ffbda199b..f9ea93c21c89 100644 --- a/tests/test_modeling_tf_xlnet.py +++ b/tests/test_modeling_tf_xlnet.py @@ -347,6 +347,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase): all_generative_model_classes = ( (TFXLNetLMHeadModel,) if is_tf_available() else () ) # TODO (PVP): Check other models whether language generation is also applicable + test_head_masking = False def setUp(self): self.model_tester = TFXLNetModelTester(self)
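The sketch below (not part of the patch) shows how the new head_mask / decoder_head_mask arguments added above could be exercised once the diff is applied. The "facebook/bart-base" checkpoint and the example sentence are illustrative assumptions only; any TF BART-like checkpoint covered by this patch should behave the same way. A mask value of 1.0 keeps a head, 0.0 masks it.

import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration

# Illustrative checkpoint, assumed for this sketch only.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

inputs = tokenizer("Head masking in TF BART.", return_tensors="tf")

config = model.config
# Mask head 0 of every encoder layer (0.0 = masked, 1.0 = kept); keep all decoder heads.
head_mask = tf.concat(
    [
        tf.zeros((config.encoder_layers, 1)),
        tf.ones((config.encoder_layers, config.encoder_attention_heads - 1)),
    ],
    axis=-1,
)
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))

outputs = model(
    **inputs,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    output_attentions=True,
)

# The attention probabilities of the masked head should sum to zero in every encoder layer.
for layer_attention in outputs.encoder_attentions:
    print(float(tf.reduce_sum(layer_attention[:, 0])))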