Remove unused token_type_ids in MPNet #9564

Merged · 8 commits · Jan 15, 2021
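MPNet has no segment embeddings, so the `token_type_ids` argument accepted by its PyTorch and TensorFlow classes was silently ignored. This PR removes the argument from every `forward`/`call` signature and stops the tokenizer from emitting it. Below is a minimal sketch of the intended behavior after the merge (the `microsoft/mpnet-base` checkpoint name is an illustrative assumption, not part of this diff):

```python
# Hedged sketch, not taken from the PR itself: after this change the
# models are called without token_type_ids.
import torch
from transformers import MPNetModel, MPNetTokenizer

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = MPNetModel.from_pretrained("microsoft/mpnet-base")

# The tokenizer output now contains only input_ids and attention_mask.
inputs = tokenizer("MPNet does not use segment embeddings.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
```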
14 changes: 3 additions & 11 deletions src/transformers/models/mpnet/modeling_mpnet.py
@@ -597,7 +597,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -617,7 +616,6 @@ def forward(
outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -701,7 +699,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -716,12 +713,12 @@ def forward(
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -783,7 +780,6 @@ def __init__(self, config):
def forward(
self,
input_ids=None,
- token_type_ids=None,
attention_mask=None,
position_ids=None,
head_mask=None,
@@ -799,12 +795,12 @@ def forward(
num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
- flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
flat_inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
@@ -815,7 +811,6 @@ def forward(
outputs = self.mpnet(
flat_input_ids,
position_ids=flat_position_ids,
- token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
@@ -878,7 +873,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -892,12 +886,12 @@ def forward(
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -987,7 +981,6 @@ def forward(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -1013,7 +1006,6 @@ def forward(
outputs = self.mpnet(
input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
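With the PyTorch signatures above trimmed, passing `token_type_ids` now raises a `TypeError` instead of being silently ignored. A short usage sketch for one of the updated heads (checkpoint name and `num_labels` are illustrative assumptions; the classification head is randomly initialized):

```python
import torch
from transformers import MPNetForSequenceClassification, MPNetTokenizer

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = MPNetForSequenceClassification.from_pretrained("microsoft/mpnet-base", num_labels=2)

enc = tokenizer("One segment is all MPNet needs.", return_tensors="pt")
with torch.no_grad():
    # Only the remaining keyword arguments are accepted.
    logits = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"]).logits
print(logits.shape)  # torch.Size([1, 2])
```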
52 changes: 10 additions & 42 deletions src/transformers/models/mpnet/modeling_tf_mpnet.py
@@ -147,7 +147,6 @@ def call(
self,
input_ids=None,
position_ids=None,
- token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
@@ -156,7 +155,7 @@
Get token embeddings of inputs

Args:
- inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
+ inputs: list of two int64 tensors with shape [batch_size, length]: (input_ids, position_ids)
mode: string, a valid value is one of "embedding" and "linear"

Returns:
@@ -169,13 +168,13 @@
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
if mode == "embedding":
- return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
+ return self._embedding(input_ids, position_ids, inputs_embeds, training=training)
elif mode == "linear":
return self._linear(input_ids)
else:
raise ValueError("mode {} is not valid.".format(mode))

- def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
+ def _embedding(self, input_ids, position_ids, inputs_embeds, training=False):
"""Applies embedding based on inputs tensor."""
assert not (input_ids is None and inputs_embeds is None)

@@ -552,12 +551,10 @@ class PreTrainedModel
"""
raise NotImplementedError

- # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call
def call(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -572,7 +569,6 @@ def call(
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -582,6 +578,7 @@
training=training,
kwargs_call=kwargs,
)
+
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
@@ -594,13 +591,9 @@ def call(
if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(input_shape, 1)

if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(input_shape, 0)

embedding_output = self.embeddings(
inputs["input_ids"],
inputs["position_ids"],
inputs["token_type_ids"],
inputs["inputs_embeds"],
training=inputs["training"],
)
@@ -682,9 +675,9 @@ def call(

- a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
- :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
+ :obj:`model([input_ids, attention_mask])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
:obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
:obj:`model({"input_ids": input_ids, "attention_mask": attention_mask})`

Args:
config (:class:`~transformers.MPNetConfig`): Model configuration class with all the parameters of the model.
@@ -710,14 +703,6 @@ def call(
- 0 for tokens that are **masked**.

`What are attention masks? <../glossary.html#attention-mask>`__
- token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
- Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
- 1]``:
-
- - 0 corresponds to a `sentence A` token,
- - 1 corresponds to a `sentence B` token.
-
- `What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
@@ -767,7 +752,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -782,7 +766,6 @@ def call(
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -795,7 +778,6 @@ def call(
outputs = self.mpnet(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
@@ -895,7 +877,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -912,12 +893,12 @@ def call(
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
+
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -931,7 +912,6 @@ def call(
outputs = self.mpnet(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1018,7 +998,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -1035,12 +1014,12 @@ def call(
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
+
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -1054,7 +1033,6 @@ def call(
outputs = self.mpnet(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1126,7 +1104,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -1148,7 +1125,6 @@ def call(
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -1171,9 +1147,6 @@ def call(
flat_attention_mask = (
tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
)
- flat_token_type_ids = (
- tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
- )
flat_position_ids = (
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
)
@@ -1185,7 +1158,6 @@ def call(
outputs = self.mpnet(
flat_input_ids,
flat_attention_mask,
- flat_token_type_ids,
flat_position_ids,
inputs["head_mask"],
flat_inputs_embeds,
@@ -1264,7 +1236,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -1280,12 +1251,12 @@ def call(
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
+
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -1299,7 +1270,6 @@ def call(
outputs = self.mpnet(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1365,7 +1335,6 @@ def call(
self,
input_ids=None,
attention_mask=None,
- token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -1387,12 +1356,12 @@ def call(
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
+
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
- token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -1407,7 +1376,6 @@ def call(
outputs = self.mpnet(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
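The TensorFlow classes mirror the change: list and dict inputs simply omit the `token_type_ids` entry. A parallel sketch under the same assumed checkpoint (`from_pt=True` converts the PyTorch weights in case no native TF checkpoint is published):

```python
from transformers import MPNetTokenizer, TFMPNetModel

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
model = TFMPNetModel.from_pretrained("microsoft/mpnet-base", from_pt=True)

# Dict input without a token_type_ids key.
inputs = tokenizer("No segment embeddings here.", return_tensors="tf")
outputs = model(inputs)
print(outputs.last_hidden_state.shape)
```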
1 change: 1 addition & 0 deletions src/transformers/models/mpnet/tokenization_mpnet.py
@@ -122,6 +122,7 @@ class MPNetTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ model_input_names = ["attention_mask"]

def __init__(
self,
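Setting `model_input_names = ["attention_mask"]` is what keeps the round trip consistent: the tokenizer stops returning `token_type_ids`, so `model(**tokenizer(...))` matches the trimmed model signatures. A quick check under the same assumed checkpoint:

```python
from transformers import MPNetTokenizer

tokenizer = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
encoded = tokenizer("hello world")
print(list(encoded.keys()))  # expected: ['input_ids', 'attention_mask']
```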