Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TFBart] Split TF-Bart #9497

Merged
merged 34 commits into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
80ff820
make templates ready
patrickvonplaten Jan 10, 2021
44bd521
make add_new_model_command_ready
patrickvonplaten Jan 10, 2021
c892c0c
finish tf bart
patrickvonplaten Jan 10, 2021
ec560e2
prepare tf mbart
patrickvonplaten Jan 10, 2021
22131dc
finish tf bart
patrickvonplaten Jan 10, 2021
8f3e8ae
add tf mbart
patrickvonplaten Jan 10, 2021
4b7319b
add marian
patrickvonplaten Jan 10, 2021
8d6a45f
prep pegasus
patrickvonplaten Jan 10, 2021
ea6ac7d
add tf pegasus
patrickvonplaten Jan 10, 2021
7b6710c
push blenderbot tf
patrickvonplaten Jan 10, 2021
f2422c8
add blenderbot
patrickvonplaten Jan 10, 2021
e3649f1
add blenderbot small
patrickvonplaten Jan 10, 2021
303af5d
clean-up
patrickvonplaten Jan 10, 2021
b91962d
make fix copy
patrickvonplaten Jan 10, 2021
69ded7d
define blend bot tok
patrickvonplaten Jan 10, 2021
64e2642
fix
patrickvonplaten Jan 10, 2021
5a61a71
up
patrickvonplaten Jan 11, 2021
38da04c
make style
patrickvonplaten Jan 11, 2021
d7689bd
add to docs
patrickvonplaten Jan 11, 2021
7d6e40d
add copy statements
patrickvonplaten Jan 11, 2021
1b2bc4f
overwrite changes
patrickvonplaten Jan 11, 2021
3dfc371
improve
patrickvonplaten Jan 11, 2021
37f4578
fix docs
patrickvonplaten Jan 11, 2021
b6514fe
finish
patrickvonplaten Jan 11, 2021
97987c9
fix last slow test
patrickvonplaten Jan 11, 2021
3b76392
Merge branch 'split_tf_bart' of https://github.com/patrickvonplaten/t…
patrickvonplaten Jan 11, 2021
09689bb
merge conficlts
patrickvonplaten Jan 11, 2021
7c8c807
fix missing git conflict line
patrickvonplaten Jan 11, 2021
2366e24
fix blenderbot
patrickvonplaten Jan 11, 2021
4822ec3
up
patrickvonplaten Jan 11, 2021
3537503
fix blenderbot small
patrickvonplaten Jan 11, 2021
58800bf
load changes
patrickvonplaten Jan 12, 2021
273733f
finish copied from
patrickvonplaten Jan 12, 2021
a363188
upload fix
patrickvonplaten Jan 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall |||| ||
| BlenderbotSmall |||| ||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
11 changes: 8 additions & 3 deletions docs/source/model_doc/blenderbot.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
:members: forward


TFBlenderbotForConditionalGeneration
TFBlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate`
.. autoclass:: transformers.TFBlenderbotModel
:members: call


TFBlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
:members:
:members: call
14 changes: 14 additions & 0 deletions docs/source/model_doc/blenderbot_small.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,17 @@ BlenderbotSmallForConditionalGeneration

.. autoclass:: transformers.BlenderbotSmallForConditionalGeneration
:members: forward


TFBlenderbotSmallModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBlenderbotSmallModel
:members: call


TFBlenderbotSmallForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBlenderbotSmallForConditionalGeneration
:members: call
8 changes: 8 additions & 0 deletions docs/source/model_doc/marian.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,15 @@ MarianMTModel
:members: forward


TFMarianModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMarianModel
:members: call


TFMarianMTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMarianMTModel
:members: call
9 changes: 8 additions & 1 deletion docs/source/model_doc/mbart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,15 @@ MBartForSequenceClassification
.. autoclass:: transformers.MBartForSequenceClassification


TFMBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMBartModel
:members: call


TFMBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMBartForConditionalGeneration
:members:
:members: call
8 changes: 8 additions & 0 deletions docs/source/model_doc/pegasus.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,15 @@ PegasusForConditionalGeneration
:members: forward


TFPegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFPegasusModel
:members: call


TFPegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFPegasusForConditionalGeneration
:members: call
20 changes: 12 additions & 8 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,10 @@
"TFBertPreTrainedModel",
]
)
_import_structure["models.blenderbot"].append("TFBlenderbotForConditionalGeneration")
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
)
_import_structure["models.camembert"].extend(
[
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -986,8 +989,8 @@
"TFLxmertVisualFeatureEncoder",
]
)
_import_structure["models.marian"].append("TFMarianMTModel")
_import_structure["models.mbart"].append("TFMBartForConditionalGeneration")
_import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
_import_structure["models.mobilebert"].extend(
[
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1028,7 +1031,7 @@
"TFOpenAIGPTPreTrainedModel",
]
)
_import_structure["models.pegasus"].append("TFPegasusForConditionalGeneration")
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
_import_structure["models.roberta"].extend(
[
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1855,7 +1858,8 @@
TFBertModel,
TFBertPreTrainedModel,
)
from .models.blenderbot import TFBlenderbotForConditionalGeneration
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM,
Expand Down Expand Up @@ -1953,8 +1957,8 @@
TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder,
)
from .models.marian import TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration
from .models.marian import TFMarian, TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
Expand Down Expand Up @@ -1989,7 +1993,7 @@
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
from .models.pegasus import TFPegasusForConditionalGeneration
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM,
Expand Down
20 changes: 16 additions & 4 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
TFBertLMHeadModel,
TFBertModel,
)
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from ..blenderbot_small.modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
)
from ..camembert.modeling_tf_camembert import (
TFCamembertForMaskedLM,
TFCamembertForMultipleChoice,
Expand Down Expand Up @@ -100,8 +104,8 @@
TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from ..marian.modeling_tf_marian import TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
Expand All @@ -122,7 +126,7 @@
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
Expand Down Expand Up @@ -167,6 +171,7 @@
BartConfig,
BertConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
Expand Down Expand Up @@ -225,6 +230,12 @@
(FunnelConfig, TFFunnelModel),
(DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),
(MBartConfig, TFMBartModel),
(MarianConfig, TFMarianModel),
(PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
]
)

Expand Down Expand Up @@ -328,6 +339,7 @@
(MBartConfig, TFMBartForConditionalGeneration),
(PegasusConfig, TFPegasusForConditionalGeneration),
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
(BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
(BartConfig, TFBartForConditionalGeneration),
]
)
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer
Expand Down Expand Up @@ -58,6 +59,7 @@
BertConfig,
BertGenerationConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CTRLConfig,
DebertaConfig,
Expand Down Expand Up @@ -201,7 +203,8 @@
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
(MarianConfig, (MarianTokenizer, None)),
(BlenderbotConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotConfig, (BlenderbotTokenizer, None)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
(BartConfig, (BartTokenizer, BartTokenizerFast)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
Expand Down
10 changes: 0 additions & 10 deletions src/transformers/models/bart/configuration_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,6 @@ def __init__(
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN

# IMPORTANT
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
self.extra_pos_embeddings = 2
self.normalize_before = False
self.add_final_layer_norm = False
self.do_blenderbot_90_layernorm = False
self.normalize_embedding = True
self.static_position_embeddings = False
self.add_bias_logits = False

@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
Expand Down
Loading