Skip to content

Commit

Permalink
Remove unused token_type_ids in MPNet (#9564)
Browse files Browse the repository at this point in the history
* Add warning

* Remove unused import

* Fix missing call

* Fix missing call

* Completely remove token_type_ids

* Apply style

* Remove unused import

* Update src/transformers/models/mpnet/modeling_tf_mpnet.py

Co-authored-by: Lysandre Debut <[email protected]>

Co-authored-by: Lysandre Debut <[email protected]>
  • Loading branch information
jplu and LysandreJik authored Jan 15, 2021
1 parent 90ca8d3 commit 8eba1f8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 53 deletions.
14 changes: 3 additions & 11 deletions src/transformers/models/mpnet/modeling_mpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -617,7 +616,6 @@ def forward(
outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand Down Expand Up @@ -701,7 +699,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -716,12 +713,12 @@ def forward(
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand Down Expand Up @@ -783,7 +780,6 @@ def __init__(self, config):
def forward(
self,
input_ids=None,
token_type_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
Expand All @@ -799,12 +795,12 @@ def forward(
num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
flat_inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
Expand All @@ -815,7 +811,6 @@ def forward(
outputs = self.mpnet(
flat_input_ids,
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
Expand Down Expand Up @@ -878,7 +873,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -892,12 +886,12 @@ def forward(
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand Down Expand Up @@ -987,7 +981,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -1013,7 +1006,6 @@ def forward(
outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand Down
52 changes: 10 additions & 42 deletions src/transformers/models/mpnet/modeling_tf_mpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def call(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
Expand All @@ -156,7 +155,7 @@ def call(
Get token embeddings of inputs
Args:
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
inputs: list of two int64 tensors with shape [batch_size, length]: (input_ids, position_ids)
mode: string, a valid value is one of "embedding" and "linear"
Returns:
Expand All @@ -169,13 +168,13 @@ def call(
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
return self._embedding(input_ids, position_ids, inputs_embeds, training=training)
elif mode == "linear":
return self._linear(input_ids)
else:
raise ValueError("mode {} is not valid.".format(mode))

def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
def _embedding(self, input_ids, position_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor."""
assert not (input_ids is None and inputs_embeds is None)

Expand Down Expand Up @@ -552,12 +551,10 @@ class PreTrainedModel
"""
raise NotImplementedError

# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -572,7 +569,6 @@ def call(
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand All @@ -582,6 +578,7 @@ def call(
training=training,
kwargs_call=kwargs,
)

if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
Expand All @@ -594,13 +591,9 @@ def call(
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)

if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)

embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
Expand Down Expand Up @@ -682,9 +675,9 @@ def call(
- a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
:obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
:obj:`model([input_ids, attention_mask])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
:obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
:obj:`model({"input_ids": input_ids, "attention_mask": attention_mask})`
Args:
config (:class:`~transformers.MPNetConfig`): Model configuration class with all the parameters of the model.
Expand All @@ -710,14 +703,6 @@ def call(
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
Expand Down Expand Up @@ -767,7 +752,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -782,7 +766,6 @@ def call(
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand All @@ -795,7 +778,6 @@ def call(
outputs = self.mpnet(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
Expand Down Expand Up @@ -895,7 +877,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -912,12 +893,12 @@ def call(
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""

inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand All @@ -931,7 +912,6 @@ def call(
outputs = self.mpnet(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
Expand Down Expand Up @@ -1018,7 +998,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -1035,12 +1014,12 @@ def call(
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""

inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand All @@ -1054,7 +1033,6 @@ def call(
outputs = self.mpnet(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
Expand Down Expand Up @@ -1126,7 +1104,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -1148,7 +1125,6 @@ def call(
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand All @@ -1171,9 +1147,6 @@ def call(
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
flat_token_type_ids = (
tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
)
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
Expand All @@ -1185,7 +1158,6 @@ def call(
outputs = self.mpnet(
flat_input_ids,
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
inputs["head_mask"],
flat_inputs_embeds,
Expand Down Expand Up @@ -1264,7 +1236,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -1280,12 +1251,12 @@ def call(
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""

inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand All @@ -1299,7 +1270,6 @@ def call(
outputs = self.mpnet(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
Expand Down Expand Up @@ -1365,7 +1335,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
Expand All @@ -1387,12 +1356,12 @@ def call(
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""

inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
Expand All @@ -1407,7 +1376,6 @@ def call(
outputs = self.mpnet(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mpnet/tokenization_mpnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class MPNetTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["attention_mask"]

def __init__(
self,
Expand Down
Loading

0 comments on commit 8eba1f8

Please sign in to comment.