Add next sentence prediction loss computation #8462

Merged 7 commits on Nov 11, 2020
54 changes: 47 additions & 7 deletions src/transformers/modeling_tf_bert.py
@@ -46,6 +46,7 @@
     TFCausalLanguageModelingLoss,
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
+    TFNextSentencePredictionLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
@@ -1036,7 +1037,7 @@ def call(
     """Bert Model with a `next sentence prediction (classification)` head on top. """,
     BERT_START_DOCSTRING,
 )
-class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
+class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)

@@ -1045,7 +1046,20 @@ def __init__(self, config, *inputs, **kwargs):

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        inputs=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        next_sentence_label=None,
+        training=False,
+    ):
         r"""
         Return:

@@ -1064,17 +1078,43 @@ def call(self, inputs, **kwargs):
         >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
         >>> assert logits[0][0] < logits[0][1] # the next sentence was random
         """
-        return_dict = kwargs.get("return_dict")
         return_dict = return_dict if return_dict is not None else self.bert.return_dict
-        outputs = self.bert(inputs, **kwargs)
+
+        if isinstance(inputs, (tuple, list)):
+            next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
+            if len(inputs) > 9:
+                inputs = inputs[:9]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
+
+        outputs = self.bert(
+            inputs,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
         pooled_output = outputs[1]
-        seq_relationship_score = self.nsp(pooled_output)
+        seq_relationship_scores = self.nsp(pooled_output)

+        next_sentence_loss = (
+            None
+            if next_sentence_label is None
+            else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
+        )
+
         if not return_dict:
-            return (seq_relationship_score,) + outputs[2:]
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

         return TFNextSentencePredictorOutput(
-            logits=seq_relationship_score,
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
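With this change, `TFBertForNextSentencePrediction` computes the NSP loss whenever `next_sentence_label` is passed (label 0 means the second sentence follows the first, 1 means it is random). A minimal usage sketch of the new argument; the checkpoint and sentence pair are illustrative, and it assumes TF 2.x plus a transformers build that includes this PR:

```python
import tensorflow as tf
from transformers import BertTokenizer, TFBertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForNextSentencePrediction.from_pretrained("bert-base-uncased")

prompt = "In Italy, pizza served in formal settings is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="tf")

# label 1 marks the pair as a random (non-consecutive) continuation
outputs = model(
    encoding["input_ids"],
    token_type_ids=encoding["token_type_ids"],
    next_sentence_label=tf.constant([1]),
    return_dict=True,
)
print(outputs.loss)    # unreduced per-example NSP loss, shape (1,)
print(outputs.logits)  # shape (1, 2)
```

Without `next_sentence_label`, the loss stays `None` and tuple outputs keep their old shape, so existing `model(...)[0]` call sites still get logits.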
54 changes: 47 additions & 7 deletions src/transformers/modeling_tf_mobilebert.py
@@ -44,6 +44,7 @@
 from .modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
     TFMultipleChoiceLoss,
+    TFNextSentencePredictionLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
@@ -1119,7 +1120,7 @@ def call(self, pooled_output):
     """MobileBert Model with a `next sentence prediction (classification)` head on top. """,
     MOBILEBERT_START_DOCSTRING,
 )
-class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
+class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)

@@ -1128,7 +1129,20 @@ def __init__(self, config, *inputs, **kwargs):

     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
-    def call(self, inputs, **kwargs):
+    def call(
+        self,
+        inputs=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        next_sentence_label=None,
+        training=False,
+    ):
         r"""
         Return:

@@ -1146,18 +1160,44 @@

         >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
         """
-        return_dict = kwargs.get("return_dict")
         return_dict = return_dict if return_dict is not None else self.mobilebert.return_dict
-        outputs = self.mobilebert(inputs, **kwargs)
+
+        if isinstance(inputs, (tuple, list)):
+            next_sentence_label = inputs[9] if len(inputs) > 9 else next_sentence_label
+            if len(inputs) > 9:
+                inputs = inputs[:9]
+        elif isinstance(inputs, (dict, BatchEncoding)):
+            next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
+
+        outputs = self.mobilebert(
+            inputs,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )

         pooled_output = outputs[1]
-        seq_relationship_score = self.cls(pooled_output)
+        seq_relationship_scores = self.cls(pooled_output)

+        next_sentence_loss = (
+            None
+            if next_sentence_label is None
+            else self.compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
+        )
+
         if not return_dict:
-            return (seq_relationship_score,) + outputs[2:]
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

         return TFNextSentencePredictorOutput(
-            logits=seq_relationship_score,
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
3 changes: 3 additions & 0 deletions src/transformers/modeling_tf_outputs.py
@@ -307,6 +307,8 @@ class TFNextSentencePredictorOutput(ModelOutput):
     Base class for outputs of models predicting if two sentences are consecutive or not.

     Args:
+        loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
+            Next sentence prediction loss.
         logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
             Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
             before SoftMax).
@@ -323,6 +325,7 @@ class TFNextSentencePredictorOutput(ModelOutput):
             heads.
     """

+    loss: tf.Tensor = None
     logits: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
     attentions: Optional[Tuple[tf.Tensor]] = None
21 changes: 21 additions & 0 deletions src/transformers/modeling_tf_utils.py
@@ -215,6 +215,27 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
     """


+class TFNextSentencePredictionLoss:
+    """
+    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
+
+    .. note::
+        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+    """
+
+    def compute_loss(self, labels, logits):
+        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
+        )
+        # make sure only labels that are not equal to -100
+        # are taken into account as loss
+        next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+        next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+
+        return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+
+
 def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
     """
     Detect missing and unexpected layers.
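Since `compute_loss` touches no model state, the -100 masking can be sanity-checked in isolation. A small sketch (assumes TF 2.x and the import path of this file in a build containing this PR):

```python
import tensorflow as tf
from transformers.modeling_tf_utils import TFNextSentencePredictionLoss

loss_mixin = TFNextSentencePredictionLoss()

logits = tf.constant([[2.0, -1.0], [0.5, 0.5], [-1.0, 3.0]])  # (batch_size, 2)
labels = tf.constant([0, -100, 1])  # the middle example is ignored

loss = loss_mixin.compute_loss(labels=labels, logits=logits)
print(loss.shape)  # (2,): one unreduced loss term per non-ignored example
```

`Reduction.NONE` leaves the per-example terms unreduced, so callers can average or sum them however they need.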
3 changes: 3 additions & 0 deletions tests/test_modeling_tf_common.py
@@ -35,6 +35,7 @@
     TF_MODEL_FOR_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_MASKED_LM_MAPPING,
     TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
+    TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
     TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -95,6 +96,8 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
             inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
         elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
             inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
+        elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
+            inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
         elif model_class in [
             *TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
             *TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
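For reference, labels can also ride along in the input dict, matching how the test above prepares `inputs_dict`; `call()` pops `next_sentence_label` before forwarding the rest to the encoder. A rough sketch (illustrative token ids, not the verbatim test code):

```python
import tensorflow as tf
from transformers import TFBertForNextSentencePrediction

model = TFBertForNextSentencePrediction.from_pretrained("bert-base-uncased")

# "next_sentence_label" travels inside the input dict, as in the common tests
batch = {
    "input_ids": tf.constant([[101, 7592, 1010, 2088, 102]]),  # [CLS] hello , world [SEP]
    "next_sentence_label": tf.zeros(1, dtype=tf.int32),
}
outputs = model(batch, return_dict=True)
assert outputs.loss is not None  # loss is computed whenever the label is present
```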