Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix model templates #9999

Merged
merged 1 commit into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ def forward(
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.

Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/blenderbot/modeling_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,7 +1425,7 @@ def forward(
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.

Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
Indices can be obtained using :class:`~transformers.BlenderbotTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ def forward(
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.

Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
Indices can be obtained using :class:`~transformers.BlenderbotSmallTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/marian/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,7 +1411,7 @@ def forward(
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.

Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
Indices can be obtained using :class:`~transformers.MarianTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,7 @@ def forward(
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.

Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
Indices can be obtained using :class:`~transformers.MBartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/pegasus/modeling_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,7 @@ def forward(
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.

Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
Expand Down Expand Up @@ -114,6 +115,7 @@
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,7 @@ def forward(
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
CausalLMOutputWithCrossAttentions
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
Expand Down Expand Up @@ -1952,7 +1953,7 @@ def forward(
return outputs


# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ClassificationHead with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""

Expand Down Expand Up @@ -3066,8 +3067,8 @@ def forward(
encoder_attentions=outputs.encoder_attentions,
)

# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}DecoderWrapper with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PretrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
Expand All @@ -3081,8 +3082,8 @@ def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ForCausalLM with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PretrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
Expand Down Expand Up @@ -3199,8 +3200,8 @@ def forward(

>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM

>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', add_cross_attention=False)
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('facebook/bart-large')
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def test_inference_masked_lm(self):

from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor
from .test_modeling_common import ModelTesterMixin, ids_tensor


if is_torch_available():
Expand All @@ -498,6 +498,7 @@ def test_inference_masked_lm(self):
{{cookiecutter.camelcase_modelname}}Config,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}Tokenizer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
[
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
Expand Down Expand Up @@ -115,6 +116,7 @@
from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
Expand Down Expand Up @@ -209,6 +211,7 @@
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
Expand All @@ -232,10 +235,7 @@

# Below: "# Model for Causal LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
{% else -%}
{% endif -%}
# End.

# Below: "# Model for Masked LM mapping"
Expand Down Expand Up @@ -384,6 +384,7 @@
{% else -%}
"{{cookiecutter.camelcase_modelname}}Encoder",
"{{cookiecutter.camelcase_modelname}}Decoder",
"{{cookiecutter.camelcase_modelname}}DecoderWrapper",
{% endif -%}
# End.

Expand All @@ -393,5 +394,6 @@
{% else -%}
"{{cookiecutter.camelcase_modelname}}Encoder", # Building part of bigger (tested) model.
"{{cookiecutter.camelcase_modelname}}Decoder", # Building part of bigger (tested) model.
"{{cookiecutter.camelcase_modelname}}DecoderWrapper", # Building part of bigger (tested) model.
{% endif -%}
# End.
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ Tips:
:members: forward


{{cookiecutter.camelcase_modelname}}ForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.{{cookiecutter.camelcase_modelname}}ForCausalLM
:members: forward


{% endif -%}
{% endif -%}
{% if "TensorFlow" in cookiecutter.generate_tensorflow_and_pytorch -%}
Expand Down Expand Up @@ -180,5 +187,7 @@ TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration

.. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
:members: call


{% endif -%}
{% endif -%}