Skip to content

Commit

Permalink
[TFBart] Split TF-Bart (#9497)
Browse files Browse the repository at this point in the history
* make templates ready

* make add_new_model_command_ready

* finish tf bart

* prepare tf mbart

* finish tf bart

* add tf mbart

* add marian

* prep pegasus

* add tf pegasus

* push blenderbot tf

* add blenderbot

* add blenderbot small

* clean-up

* make fix copy

* define blend bot tok

* fix

* up

* make style

* add to docs

* add copy statements

* overwrite changes

* improve

* fix docs

* finish

* fix last slow test

* fix missing git conflict line

* fix blenderbot

* up

* fix blenderbot small

* load changes

* finish copied from

* upload fix
  • Loading branch information
patrickvonplaten authored Jan 12, 2021
1 parent 0ecbb69 commit 7f28613
Show file tree
Hide file tree
Showing 39 changed files with 7,866 additions and 588 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall |||| ||
| BlenderbotSmall |||| ||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL ||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
Expand Down
11 changes: 8 additions & 3 deletions docs/source/model_doc/blenderbot.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
:members: forward


TFBlenderbotForConditionalGeneration
TFBlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate`
.. autoclass:: transformers.TFBlenderbotModel
:members: call


TFBlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
:members:
:members: call
14 changes: 14 additions & 0 deletions docs/source/model_doc/blenderbot_small.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,17 @@ BlenderbotSmallForConditionalGeneration

.. autoclass:: transformers.BlenderbotSmallForConditionalGeneration
:members: forward


TFBlenderbotSmallModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBlenderbotSmallModel
:members: call


TFBlenderbotSmallForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFBlenderbotSmallForConditionalGeneration
:members: call
8 changes: 8 additions & 0 deletions docs/source/model_doc/marian.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,15 @@ MarianMTModel
:members: forward


TFMarianModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMarianModel
:members: call


TFMarianMTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMarianMTModel
:members: call
9 changes: 8 additions & 1 deletion docs/source/model_doc/mbart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,15 @@ MBartForSequenceClassification
.. autoclass:: transformers.MBartForSequenceClassification


TFMBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMBartModel
:members: call


TFMBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMBartForConditionalGeneration
:members:
:members: call
8 changes: 8 additions & 0 deletions docs/source/model_doc/pegasus.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,15 @@ PegasusForConditionalGeneration
:members: forward


TFPegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFPegasusModel
:members: call


TFPegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFPegasusForConditionalGeneration
:members: call
20 changes: 12 additions & 8 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,10 @@
"TFBertPreTrainedModel",
]
)
_import_structure["models.blenderbot"].append("TFBlenderbotForConditionalGeneration")
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
)
_import_structure["models.camembert"].extend(
[
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -986,8 +989,8 @@
"TFLxmertVisualFeatureEncoder",
]
)
_import_structure["models.marian"].append("TFMarianMTModel")
_import_structure["models.mbart"].append("TFMBartForConditionalGeneration")
_import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
_import_structure["models.mobilebert"].extend(
[
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1028,7 +1031,7 @@
"TFOpenAIGPTPreTrainedModel",
]
)
_import_structure["models.pegasus"].append("TFPegasusForConditionalGeneration")
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
_import_structure["models.roberta"].extend(
[
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -1855,7 +1858,8 @@
TFBertModel,
TFBertPreTrainedModel,
)
from .models.blenderbot import TFBlenderbotForConditionalGeneration
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM,
Expand Down Expand Up @@ -1953,8 +1957,8 @@
TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder,
)
from .models.marian import TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration
from .models.marian import TFMarian, TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
Expand Down Expand Up @@ -1989,7 +1993,7 @@
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
from .models.pegasus import TFPegasusForConditionalGeneration
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM,
Expand Down
20 changes: 16 additions & 4 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
TFBertLMHeadModel,
TFBertModel,
)
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from ..blenderbot_small.modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
)
from ..camembert.modeling_tf_camembert import (
TFCamembertForMaskedLM,
TFCamembertForMultipleChoice,
Expand Down Expand Up @@ -100,8 +104,8 @@
TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from ..marian.modeling_tf_marian import TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
Expand All @@ -122,7 +126,7 @@
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from ..roberta.modeling_tf_roberta import (
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
Expand Down Expand Up @@ -167,6 +171,7 @@
BartConfig,
BertConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
Expand Down Expand Up @@ -225,6 +230,12 @@
(FunnelConfig, TFFunnelModel),
(DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel),
(MBartConfig, TFMBartModel),
(MarianConfig, TFMarianModel),
(PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
]
)

Expand Down Expand Up @@ -328,6 +339,7 @@
(MBartConfig, TFMBartForConditionalGeneration),
(PegasusConfig, TFPegasusForConditionalGeneration),
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
(BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
(BartConfig, TFBartForConditionalGeneration),
]
)
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..bert.tokenization_bert import BertTokenizer
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
from ..bertweet.tokenization_bertweet import BertweetTokenizer
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
from ..ctrl.tokenization_ctrl import CTRLTokenizer
from ..deberta.tokenization_deberta import DebertaTokenizer
Expand Down Expand Up @@ -58,6 +59,7 @@
BertConfig,
BertGenerationConfig,
BlenderbotConfig,
BlenderbotSmallConfig,
CamembertConfig,
CTRLConfig,
DebertaConfig,
Expand Down Expand Up @@ -201,7 +203,8 @@
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
(MarianConfig, (MarianTokenizer, None)),
(BlenderbotConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
(BlenderbotConfig, (BlenderbotTokenizer, None)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
(BartConfig, (BartTokenizer, BartTokenizerFast)),
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
Expand Down
10 changes: 0 additions & 10 deletions src/transformers/models/bart/configuration_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,6 @@ def __init__(
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN

# IMPORTANT
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
self.extra_pos_embeddings = 2
self.normalize_before = False
self.add_final_layer_norm = False
self.do_blenderbot_90_layernorm = False
self.normalize_embedding = True
self.static_position_embeddings = False
self.add_bias_logits = False

@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
Expand Down
Loading

0 comments on commit 7f28613

Please sign in to comment.