From 68879472c4bd83a10e59d3cee58b61b2108b4971 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 25 Nov 2020 12:02:15 -0500 Subject: [PATCH] Big model table (#8774) * First draft * Styling * With all changes staged * Update docs/source/index.rst Co-authored-by: Julien Chaumond * Styling Co-authored-by: Julien Chaumond --- docs/source/_static/css/huggingface.css | 9 ++ docs/source/index.rst | 95 +++++++++++- src/transformers/__init__.py | 2 + src/transformers/models/auto/__init__.py | 7 +- .../models/auto/modeling_flax_auto.py | 10 +- src/transformers/utils/dummy_flax_objects.py | 12 ++ utils/check_copies.py | 140 ++++++++++++++++-- utils/check_repo.py | 1 + 8 files changed, 257 insertions(+), 19 deletions(-) diff --git a/docs/source/_static/css/huggingface.css b/docs/source/_static/css/huggingface.css index 9b31a2df673c..cee1aac5bc1d 100644 --- a/docs/source/_static/css/huggingface.css +++ b/docs/source/_static/css/huggingface.css @@ -2,6 +2,15 @@ /* Colab dropdown */ +table.center-aligned-table td { + text-align: center; +} + +table.center-aligned-table th { + text-align: center; + vertical-align: middle; +} + .colab-dropdown { position: relative; display: inline-block; diff --git a/docs/source/index.rst b/docs/source/index.rst index 4051ecbc8bdf..43dff49658e9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -35,6 +35,8 @@ Choose the right framework for every part of a model's lifetime: - Move a single model between TF2.0/PyTorch frameworks at will - Seamlessly pick the right framework for training, evaluation, production +Experimental support for Flax with a few models right now, expected to grow in the coming months. + Contents ----------------------------------------------------------------------------------------------------------------------- @@ -52,8 +54,8 @@ The documentation is organized in five parts: - **MODELS** for the classes and functions related to each model implemented in the library. - **INTERNAL HELPERS** for the classes and functions we use internally. -The library currently contains PyTorch and Tensorflow implementations, pre-trained model weights, usage scripts and -conversion utilities for the following models: +The library currently contains PyTorch, Tensorflow and Flax implementations, pretrained model weights, usage scripts +and conversion utilities for the following models: .. This list is updated automatically from the README with `make fix-copies`. Do not update manually! @@ -166,6 +168,95 @@ conversion utilities for the following models: 34. `Other community models `__, contributed by the `community `__. + +The table below represents the current support in the library for each of those models, whether they have a Python +tokenizer (called "slow"). A "fast" tokenizer backed by the 🤗 Tokenizers library, whether they have support in PyTorch, +TensorFlow and/or Flax. + +.. + This table is updated automatically from the auto modules with `make fix-copies`. Do not update manually! + +.. rst-class:: center-aligned-table + ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support | ++=============================+================+================+=================+====================+==============+ +| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| BART | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| BERT | ✅ | ✅ | ✅ | ✅ | ✅ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Bert Generation | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| DPR | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| DeBERTa | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| LayoutLM | ✅ | ✅ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Marian | ✅ | ❌ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| RAG | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| T5 | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| XLM | ✅ | ❌ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| mBART | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| mT5 | ✅ | ✅ | ✅ | ✅ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ + + .. toctree:: :maxdepth: 2 :caption: Get started diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a92fb488125b..f18e76eee723 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -98,6 +98,7 @@ from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, + MODEL_NAMES_MAPPING, TOKENIZER_MAPPING, AutoConfig, AutoTokenizer, @@ -876,6 +877,7 @@ if is_flax_available(): + from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel from .models.bert import FlaxBertModel from .models.roberta import FlaxRobertaModel else: diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 86ab29b89156..fb308c408768 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -2,8 +2,8 @@ # There's no way to ignore "F401 '...' imported but unused" warnings in this # module, but to preserve other warnings. So, don't check this module at all. -from ...file_utils import is_tf_available, is_torch_available -from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig +from ...file_utils import is_flax_available, is_tf_available, is_torch_available +from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer @@ -57,3 +57,6 @@ TFAutoModelForTokenClassification, TFAutoModelWithLMHead, ) + +if is_flax_available(): + from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index bc44f881128d..dab92814a772 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -36,7 +36,7 @@ for key, value, in pretrained_map.items() ) -MODEL_MAPPING = OrderedDict( +FLAX_MODEL_MAPPING = OrderedDict( [ (RobertaConfig, FlaxRobertaModel), (BertConfig, FlaxBertModel), @@ -79,13 +79,13 @@ def from_config(cls, config): model = FlaxAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` """ - for config_class, model_class in MODEL_MAPPING.items(): + for config_class, model_class in FLAX_MODEL_MAPPING.items(): if isinstance(config, config_class): return model_class(config) raise ValueError( f"Unrecognized configuration class {config.__class__} " f"for this kind of FlaxAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}." + f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}." ) @classmethod @@ -173,11 +173,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): if not isinstance(config, PretrainedConfig): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - for config_class, model_class in MODEL_MAPPING.items(): + for config_class, model_class in FLAX_MODEL_MAPPING.items(): if isinstance(config, config_class): return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) raise ValueError( f"Unrecognized configuration class {config.__class__} " f"for this kind of FlaxAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}" + f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}" ) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 77e932652def..84f9853842b9 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -2,6 +2,18 @@ from ..file_utils import requires_flax +FLAX_MODEL_MAPPING = None + + +class FlaxAutoModel: + def __init__(self, *args, **kwargs): + requires_flax(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_flax(self) + + class FlaxBertModel: def __init__(self, *args, **kwargs): requires_flax(self) diff --git a/utils/check_copies.py b/utils/check_copies.py index dc1803ce5082..734d91e5559f 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -15,6 +15,7 @@ import argparse import glob +import importlib import os import re import tempfile @@ -250,20 +251,21 @@ def _rep_link(match): return "\n".join(result) -def check_model_list_copy(overwrite=False, max_per_line=119): - """ Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """ - _start_prompt = " This list is updated automatically from the README" - _end_prompt = ".. toctree::" - with open(os.path.join(PATH_TO_DOCS, "index.rst"), "r", encoding="utf-8", newline="\n") as f: +def _find_text_in_file(filename, start_prompt, end_prompt): + """ + Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty + lines. + """ + with open(filename, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() - # Find the start of the list. + # Find the start prompt. start_index = 0 - while not lines[start_index].startswith(_start_prompt): + while not lines[start_index].startswith(start_prompt): start_index += 1 start_index += 1 end_index = start_index - while not lines[end_index].startswith(_end_prompt): + while not lines[end_index].startswith(end_prompt): end_index += 1 end_index -= 1 @@ -272,8 +274,16 @@ def check_model_list_copy(overwrite=False, max_per_line=119): while len(lines[end_index]) <= 1: end_index -= 1 end_index += 1 + return "".join(lines[start_index:end_index]), start_index, end_index, lines + - rst_list = "".join(lines[start_index:end_index]) +def check_model_list_copy(overwrite=False, max_per_line=119): + """ Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """ + rst_list, start_index, end_index, lines = _find_text_in_file( + filename=os.path.join(PATH_TO_DOCS, "index.rst"), + start_prompt=" This list is updated automatically from the README", + end_prompt="The table below represents the current support", + ) md_list = get_model_list() converted_list = convert_to_rst(md_list, max_per_line=max_per_line) @@ -283,7 +293,116 @@ def check_model_list_copy(overwrite=False, max_per_line=119): f.writelines(lines[:start_index] + [converted_list] + lines[end_index:]) else: raise ValueError( - "The model list in the README changed and the list in `index.rst` has not been updated. Run `make fix-copies` to fix this." + "The model list in the README changed and the list in `index.rst` has not been updated. Run " + "`make fix-copies` to fix this." + ) + + +def _center_text(text, width): + text_length = 2 if text == "✅" or text == "❌" else len(text) + left_indent = (width - text_length) // 2 + right_indent = width - text_length - left_indent + return " " * left_indent + text + " " * right_indent + + +def get_model_table_from_auto_modules(): + """Generates an up-to-date model table from the content of the auto modules.""" + # This is to make sure the transformers module imported is the one in the repo. + spec = importlib.util.spec_from_file_location( + "transformers", + os.path.join(TRANSFORMERS_PATH, "__init__.py"), + submodule_search_locations=[TRANSFORMERS_PATH], + ) + transformers = spec.loader.load_module() + + # Dictionary model names to config. + model_name_to_config = { + name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items() + } + # All tokenizer tuples. + tokenizers = { + name: transformers.TOKENIZER_MAPPING[config] + for name, config in model_name_to_config.items() + if config in transformers.TOKENIZER_MAPPING + } + # Model names that a slow/fast tokenizer. + has_slow_tokenizers = [name for name, tok in tokenizers.items() if tok[0] is not None] + has_fast_tokenizers = [name for name, tok in tokenizers.items() if tok[1] is not None] + + # Model names that have a PyTorch implementation. + has_pt_model = [name for name, config in model_name_to_config.items() if config in transformers.MODEL_MAPPING] + # Some of the GenerationModel don't have a base model. + has_pt_model.extend( + [ + name + for name, config in model_name_to_config.items() + if config in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + ] + ) + # Special exception for RAG + has_pt_model.append("RAG") + + # Model names that have a TensorFlow implementation. + has_tf_model = [name for name, config in model_name_to_config.items() if config in transformers.TF_MODEL_MAPPING] + # Some of the GenerationModel don't have a base model. + has_tf_model.extend( + [ + name + for name, config in model_name_to_config.items() + if config in transformers.TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + ] + ) + + # Model names that have a Flax implementation. + has_flax_model = [ + name for name, config in model_name_to_config.items() if config in transformers.FLAX_MODEL_MAPPING + ] + + # Let's build that table! + model_names = list(model_name_to_config.keys()) + model_names.sort() + columns = ["Model", "Tokenizer slow", "Tokenizer fast", "PyTorch support", "TensorFlow support", "Flax Support"] + # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side). + widths = [len(c) + 2 for c in columns] + widths[0] = max([len(name) for name in model_names]) + 2 + + # Rst table per se + table = ".. rst-class:: center-aligned-table\n\n" + table += "+" + "+".join(["-" * w for w in widths]) + "+\n" + table += "|" + "|".join([_center_text(c, w) for c, w in zip(columns, widths)]) + "|\n" + table += "+" + "+".join(["=" * w for w in widths]) + "+\n" + + check = {True: "✅", False: "❌"} + for name in model_names: + line = [ + name, + check[name in has_slow_tokenizers], + check[name in has_fast_tokenizers], + check[name in has_pt_model], + check[name in has_tf_model], + check[name in has_flax_model], + ] + table += "|" + "|".join([_center_text(l, w) for l, w in zip(line, widths)]) + "|\n" + table += "+" + "+".join(["-" * w for w in widths]) + "+\n" + return table + + +def check_model_table(overwrite=False): + """ Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """ + current_table, start_index, end_index, lines = _find_text_in_file( + filename=os.path.join(PATH_TO_DOCS, "index.rst"), + start_prompt=" This table is updated automatically from the auto module", + end_prompt=".. toctree::", + ) + new_table = get_model_table_from_auto_modules() + + if current_table != new_table: + if overwrite: + with open(os.path.join(PATH_TO_DOCS, "index.rst"), "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines[:start_index] + [new_table] + lines[end_index:]) + else: + raise ValueError( + "The model table in the `index.rst` has not been updated. Run `make fix-copies` to fix this." ) @@ -293,3 +412,4 @@ def check_model_list_copy(overwrite=False, max_per_line=119): args = parser.parse_args() check_copies(args.fix_and_overwrite) + check_model_table(args.fix_and_overwrite) diff --git a/utils/check_repo.py b/utils/check_repo.py index 291101ec3e12..a1522dd9fcfa 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -126,6 +126,7 @@ def get_model_modules(): "modeling_outputs", "modeling_retribert", "modeling_utils", + "modeling_flax_auto", "modeling_flax_utils", "modeling_transfo_xl_utilities", "modeling_tf_auto",