diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py
index 0f920a80d488..baf1d962c05d 100644
--- a/src/transformers/modeling_tf_bert.py
+++ b/src/transformers/modeling_tf_bert.py
@@ -46,6 +46,7 @@
     TFCausalLanguageModelingLoss,
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
+    TFNextSentencePredictionLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
@@ -1036,7 +1037,7 @@ def call(
     """Bert Model with a `next sentence prediction (classification)` head on top. """,
     BERT_START_DOCSTRING,
 )
-class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
+class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
 
@@ -1045,7 +1046,20 @@ def __init__(self, config, *inputs, **kwargs):
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        inputs=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        next_sentence_label=None,
+        training=False,
+    ):
         r"""
         Return:
 
@@ -1064,17 +1078,43 @@ def call(self, inputs, **kwargs):
         >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
         >>> assert logits[0][0] < logits[0][1] # the next sentence was random
         """
-        return_dict = kwargs.get("return_dict")
         return_dict = return_dict if return_dict is not None else self.bert.return_dict
+
+        if isinstance(inputs, (tuple, list)):
+            next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
+            if len(inputs) > 9:
+                inputs = inputs[:9]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
+
-        outputs = self.bert(inputs, **kwargs)
+        outputs = self.bert(
+            inputs,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
         pooled_output = outputs[1]
-        seq_relationship_score = self.nsp(pooled_output)
+        seq_relationship_scores = self.nsp(pooled_output)
+
+        next_sentence_loss = (
+            None
+            if next_sentence_label is None
+            else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
+        )
 
         if not return_dict:
-            return (seq_relationship_score,) + outputs[2:]
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
 
         return TFNextSentencePredictorOutput(
-            logits=seq_relationship_score,
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
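(Illustration only, not part of the patch.) A minimal sketch of how the new `next_sentence_label` argument added above could be exercised, assuming the `bert-base-uncased` checkpoint and the usual NSP label convention (0 = sentence B continues sentence A, 1 = sentence B is a random sentence, consistent with the docstring example's assert):

import tensorflow as tf
from transformers import BertTokenizer, TFBertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForNextSentencePrediction.from_pretrained("bert-base-uncased")

prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="tf")

# With the patch applied, passing next_sentence_label also returns the NSP loss
# computed by TFNextSentencePredictionLoss.compute_loss (defined further below).
outputs = model(
    encoding["input_ids"],
    token_type_ids=encoding["token_type_ids"],
    next_sentence_label=tf.constant([1]),  # assumed label: the second sentence is random
    return_dict=True,
)
loss, logits = outputs.loss, outputs.logits

When `next_sentence_label` is omitted, the loss is `None` and the returned logits are unchanged from the previous behaviour.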
diff --git a/src/transformers/modeling_tf_mobilebert.py b/src/transformers/modeling_tf_mobilebert.py
index 9917730887ae..cf3bbde94f7b 100644
--- a/src/transformers/modeling_tf_mobilebert.py
+++ b/src/transformers/modeling_tf_mobilebert.py
@@ -44,6 +44,7 @@
 from .modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
+    TFNextSentencePredictionLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
@@ -1119,7 +1120,7 @@ def call(self, pooled_output):
     """MobileBert Model with a `next sentence prediction (classification)` head on top. """,
     MOBILEBERT_START_DOCSTRING,
 )
-class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
+class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
 
@@ -1128,7 +1129,20 @@ def __init__(self, config, *inputs, **kwargs):
 
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        inputs=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        next_sentence_label=None,
+        training=False,
+    ):
         r"""
         Return:
 
@@ -1146,18 +1160,44 @@ def call(self, inputs, **kwargs):
         >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
 
         """
-        return_dict = kwargs.get("return_dict")
         return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
+
+        if isinstance(inputs, (tuple, list)):
+            next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
+            if len(inputs) > 9:
+                inputs = inputs[:9]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
+
-        outputs = self.mobilebert(inputs, **kwargs)
+        outputs = self.mobilebert(
+            inputs,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
         pooled_output = outputs[1]
-        seq_relationship_score = self.cls(pooled_output)
+        seq_relationship_scores = self.cls(pooled_output)
+
+        next_sentence_loss = (
+            None
+            if next_sentence_label is None
+            else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
+        )
 
         if not return_dict:
-            return (seq_relationship_score,) + outputs[2:]
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
 
         return TFNextSentencePredictorOutput(
-            logits=seq_relationship_score,
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py
index 9a4deb65b0d5..8d7d53814c48 100644
--- a/src/transformers/modeling_tf_outputs.py
+++ b/src/transformers/modeling_tf_outputs.py
@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
     Base class for outputs of models predicting if two sentences are consecutive or not.
 
     Args:
+        loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
+            Next sentence prediction loss.
         logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
             Prediction scores of the next sequence prediction (classification) head (scores of True/False
             continuation before SoftMax).
@@ -323,6 +325,7 @@
             heads.
     """
 
+    loss: tf.Tensor = None
     logits: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
     attentions: Optional[Tuple[tf.Tensor]] = None
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 5ab039967d17..2de2b1f0eecb 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -215,6 +215,27 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
     """
 
 
+class TFNextSentencePredictionLoss:
+    """
+    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
+
+    .. note::
+        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+    """
+
+    def compute_loss(self, labels, logits):
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+        )
+        # make sure only labels that are not equal to -100
+        # are taken into account as loss
+        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+
+        return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+
+
 def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
     """
     Detect missing and unexpected layers.
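(Illustration only, not part of the patch.) A standalone sketch of the masking behaviour implemented by `TFNextSentencePredictionLoss.compute_loss` above: labels equal to -100 are dropped, together with the corresponding logits, before the sparse categorical cross-entropy is evaluated:

import tensorflow as tf

logits = tf.constant([[2.0, -1.0], [0.5, 0.5], [-3.0, 3.0]])  # (batch_size, 2) NSP scores
labels = tf.constant([0, -100, 1])                            # -100 marks an example to ignore

# Same steps as compute_loss: keep only rows whose label is not -100.
active = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), active)
reduced_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
per_example_loss = loss_fn(reduced_labels, reduced_logits)  # shape (2,): the -100 row is excluded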
""" + loss: tf.Tensor = None logits: tf.Tensor = None hidden_states: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 5ab039967d17..2de2b1f0eecb 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -215,6 +215,27 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss): """ +class TFNextSentencePredictionLoss: + """ + Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence. + + .. note:: + Any label of -100 will be ignored (along with the corresponding logits) in the loss computation. + """ + + def compute_loss(self, labels, logits): + loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction=tf.keras.losses.Reduction.NONE + ) + # make sure only labels that are not equal to -100 + # are taken into account as loss + next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) + next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss) + next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss) + + return loss_fn(next_sentence_label, next_sentence_reduced_logits) + + def detect_tf_missing_unexpected_layers(model, resolved_archive_file): """ Detect missing and unexpected layers. diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 3bb40af4ef74..30321fc5e6d9 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -35,6 +35,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, @@ -95,6 +96,8 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(): inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) + elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(): + inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) elif model_class in [ *TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(), *TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),