From 16f94797b5f10ada5eeea413ed551038cbf5bbaf Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 18 Mar 2021 20:09:16 -0400 Subject: [PATCH 1/6] Initial script --- src/transformers/__init__.py | 166 ++++++++++++------------ utils/custom_init_isort.py | 238 +++++++++++++++++++++++++++++++++++ 2 files changed, 321 insertions(+), 83 deletions(-) create mode 100644 utils/custom_init_isort.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 57854cbefcb0..84e4ea26c726 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -78,6 +78,7 @@ "xnli_processors", "xnli_tasks_num_labels", ], + "feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"], "file_utils": [ "CONFIG_NAME", "MODEL_CARD_NAME", @@ -124,23 +125,8 @@ "load_tf2_model_in_pytorch_model", "load_tf2_weights_in_pytorch_model", ], - "models": [], # Models - "models.wav2vec2": [ - "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", - "Wav2Vec2Config", - "Wav2Vec2CTCTokenizer", - "Wav2Vec2Tokenizer", - "Wav2Vec2FeatureExtractor", - "Wav2Vec2Processor", - ], - "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], - "models.speech_to_text": [ - "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", - "Speech2TextConfig", - "Speech2TextFeatureExtractor", - ], - "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], + "models": [], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -169,6 +155,7 @@ "BlenderbotSmallTokenizer", ], "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"], + "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"], "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], @@ -193,6 +180,7 @@ "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"], "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"], "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"], + "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], "models.marian": ["MarianConfig"], "models.mbart": ["MBartConfig"], "models.mmbt": ["MMBTConfig"], @@ -207,6 +195,11 @@ "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], "models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"], "models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"], + "models.speech_to_text": [ + "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Speech2TextConfig", + "Speech2TextFeatureExtractor", + ], "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"], "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"], "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"], @@ -216,6 +209,14 @@ "TransfoXLCorpus", "TransfoXLTokenizer", ], + "models.wav2vec2": [ + "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", + "Wav2Vec2Config", + "Wav2Vec2CTCTokenizer", + "Wav2Vec2FeatureExtractor", + "Wav2Vec2Processor", + "Wav2Vec2Tokenizer", + ], "models.xlm": 
["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], @@ -251,7 +252,6 @@ "SpecialTokensMixin", "TokenSpan", ], - "feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"], "trainer_callback": [ "DefaultFlowCallback", "EarlyStoppingCallback", @@ -383,54 +383,15 @@ "TopPLogitsWarper", ] _import_structure["generation_stopping_criteria"] = [ - "StoppingCriteria", - "StoppingCriteriaList", "MaxLengthCriteria", "MaxTimeCriteria", + "StoppingCriteria", + "StoppingCriteriaList", ] _import_structure["generation_utils"] = ["top_k_top_p_filtering"] _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] # PyTorch models structure - _import_structure["models.speech_to_text"].extend( - [ - "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", - "Speech2TextForConditionalGeneration", - "Speech2TextModel", - ] - ) - - _import_structure["models.wav2vec2"].extend( - [ - "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", - "Wav2Vec2ForCTC", - "Wav2Vec2ForMaskedLM", - "Wav2Vec2Model", - "Wav2Vec2PreTrainedModel", - ] - ) - _import_structure["models.m2m_100"].extend( - [ - "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", - "M2M100ForConditionalGeneration", - "M2M100Model", - ] - ) - - _import_structure["models.convbert"].extend( - [ - "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST", - "ConvBertForMaskedLM", - "ConvBertForMultipleChoice", - "ConvBertForQuestionAnswering", - "ConvBertForSequenceClassification", - "ConvBertForTokenClassification", - "ConvBertLayer", - "ConvBertModel", - "ConvBertPreTrainedModel", - "load_tf_weights_in_convbert", - ] - ) _import_structure["models.albert"].extend( [ "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -445,6 +406,7 @@ "load_tf_weights_in_albert", ] ) + _import_structure["models.auto"].extend( [ "MODEL_FOR_CAUSAL_LM_MAPPING", @@ -485,6 +447,7 @@ "PretrainedBartModel", ] ) + _import_structure["models.bert"].extend( [ "BERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -512,17 +475,17 @@ _import_structure["models.blenderbot"].extend( [ "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST", + "BlenderbotForCausalLM", "BlenderbotForConditionalGeneration", "BlenderbotModel", - "BlenderbotForCausalLM", ] ) _import_structure["models.blenderbot_small"].extend( [ "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST", + "BlenderbotSmallForCausalLM", "BlenderbotSmallForConditionalGeneration", "BlenderbotSmallModel", - "BlenderbotSmallForCausalLM", ] ) _import_structure["models.camembert"].extend( @@ -537,6 +500,20 @@ "CamembertModel", ] ) + _import_structure["models.convbert"].extend( + [ + "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "ConvBertForMaskedLM", + "ConvBertForMultipleChoice", + "ConvBertForQuestionAnswering", + "ConvBertForSequenceClassification", + "ConvBertForTokenClassification", + "ConvBertLayer", + "ConvBertModel", + "ConvBertPreTrainedModel", + "load_tf_weights_in_convbert", + ] + ) _import_structure["models.ctrl"].extend( [ "CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -549,23 +526,23 @@ _import_structure["models.deberta"].extend( [ "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "DebertaForMaskedLM", + "DebertaForQuestionAnswering", "DebertaForSequenceClassification", + "DebertaForTokenClassification", "DebertaModel", - "DebertaForMaskedLM", "DebertaPreTrainedModel", - "DebertaForTokenClassification", - "DebertaForQuestionAnswering", ] ) 
_import_structure["models.deberta_v2"].extend( [ "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST", + "DebertaV2ForMaskedLM", + "DebertaV2ForQuestionAnswering", "DebertaV2ForSequenceClassification", + "DebertaV2ForTokenClassification", "DebertaV2Model", - "DebertaV2ForMaskedLM", "DebertaV2PreTrainedModel", - "DebertaV2ForTokenClassification", - "DebertaV2ForQuestionAnswering", ] ) _import_structure["models.distilbert"].extend( @@ -699,7 +676,14 @@ "LxmertXLayer", ] ) - _import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"]) + _import_structure["models.m2m_100"].extend( + [ + "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", + "M2M100ForConditionalGeneration", + "M2M100Model", + ] + ) + _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"]) _import_structure["models.mbart"].extend( [ "MBartForCausalLM", @@ -752,7 +736,7 @@ ] ) _import_structure["models.pegasus"].extend( - ["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"] + ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"] ) _import_structure["models.prophetnet"].extend( [ @@ -793,6 +777,13 @@ "RobertaModel", ] ) + _import_structure["models.speech_to_text"].extend( + [ + "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", + "Speech2TextForConditionalGeneration", + "Speech2TextModel", + ] + ) _import_structure["models.squeezebert"].extend( [ "SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -836,6 +827,15 @@ "load_tf_weights_in_transfo_xl", ] ) + _import_structure["models.wav2vec2"].extend( + [ + "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Wav2Vec2ForCTC", + "Wav2Vec2ForMaskedLM", + "Wav2Vec2Model", + "Wav2Vec2PreTrainedModel", + ] + ) _import_structure["models.xlm"].extend( [ "XLM_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -917,19 +917,6 @@ ] # TensorFlow models structure - _import_structure["models.convbert"].extend( - [ - "TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST", - "TFConvBertForMaskedLM", - "TFConvBertForMultipleChoice", - "TFConvBertForQuestionAnswering", - "TFConvBertForSequenceClassification", - "TFConvBertForTokenClassification", - "TFConvBertLayer", - "TFConvBertModel", - "TFConvBertPreTrainedModel", - ] - ) _import_structure["models.albert"].extend( [ "TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1002,6 +989,19 @@ "TFCamembertModel", ] ) + _import_structure["models.convbert"].extend( + [ + "TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFConvBertForMaskedLM", + "TFConvBertForMultipleChoice", + "TFConvBertForQuestionAnswering", + "TFConvBertForSequenceClassification", + "TFConvBertForTokenClassification", + "TFConvBertLayer", + "TFConvBertModel", + "TFConvBertPreTrainedModel", + ] + ) _import_structure["models.ctrl"].extend( [ "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1108,7 +1108,7 @@ "TFLxmertVisualFeatureEncoder", ] ) - _import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"]) + _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"]) _import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"]) _import_structure["models.mobilebert"].extend( [ @@ -2170,7 +2170,7 @@ TFLxmertPreTrainedModel, TFLxmertVisualFeatureEncoder, ) - from .models.marian import TFMarian, TFMarianMTModel + from .models.marian import TFMarianModel, TFMarianMTModel from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel from .models.mobilebert import ( TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/utils/custom_init_isort.py 
b/utils/custom_init_isort.py
new file mode 100644
index 000000000000..9ecac8fb08ff
--- /dev/null
+++ b/utils/custom_init_isort.py
@@ -0,0 +1,238 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import re
+
+
+PATH_TO_TRANSFORMERS = "src/transformers"
+
+# Pattern that looks at the indentation in a line.
+_re_indent = re.compile(r"^(\s*)\S")
+# Pattern that matches `"key":` and puts `key` in group 0.
+_re_direct_key = re.compile(r'^\s*"([^"]+)":')
+# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
+_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
+# Pattern that matches `"key",` and puts `key` in group 0.
+_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
+# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
+_re_bracket_content = re.compile(r"\[([^\]]+)\]")
+
+
+def get_indent(line):
+    """Returns the indent in `line`."""
+    search = _re_indent.search(line)
+    return "" if search is None else search.groups()[0]
+
+
+def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
+    """
+    Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
+    `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
+    after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
+    """
+    # Let's split the code into lines and move to the start index.
+    index = 0
+    lines = code.split("\n")
+    if start_prompt is not None:
+        while not lines[index].startswith(start_prompt):
+            index += 1
+        blocks = ["\n".join(lines[:index])]
+    else:
+        blocks = []
+
+    # We split into blocks until we get to the `end_prompt` (or the end of the code).
+    current_block = [lines[index]]
+    index += 1
+    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
+        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
+            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
+                current_block.append(lines[index])
+                blocks.append("\n".join(current_block))
+                if index < len(lines) - 2:
+                    current_block = [lines[index + 1]]
+                    index += 1
+                else:
+                    current_block = []
+            else:
+                blocks.append("\n".join(current_block))
+                current_block = [lines[index]]
+        else:
+            current_block.append(lines[index])
+        index += 1
+
+    # Add the current block if it's nonempty.
+    if len(current_block) > 0:
+        blocks.append("\n".join(current_block))
+
+    # Add the final block after end_prompt if provided.
+    if end_prompt is not None and index < len(lines):
+        blocks.append("\n".join(lines[index:]))
+
+    return blocks
+
+
+def ignore_underscore(key):
+    "Wraps a `key` (that maps an object to a string) to lower-case the result and remove underscores."
+
+    def _inner(x):
+        return key(x).lower().replace("_", "")
+
+    return _inner
+
+
+def sort_objects(objects, key=None):
+    "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+    # If no key is provided, we use a noop.
+    if key is None:
+        key = lambda x: x
+    # Constants are all uppercase, they go first.
+    constants = [obj for obj in objects if key(obj).isupper()]
+    # Classes are not all uppercase but start with a capital, they go second.
+    classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
+    # Functions begin with a lowercase, they go last.
+    functions = [obj for obj in objects if not key(obj)[0].isupper()]
+
+    key1 = ignore_underscore(key)
+    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
+
+
+def sort_objects_in_import(import_statement):
+    """
+    Return the same `import_statement` but with objects properly sorted.
+    """
+    # This inner function sorts imports between [ ].
+    def _replace(match):
+        imports = match.groups()[0]
+        if "," not in imports:
+            return f"[{imports}]"
+        keys = [part.strip().replace('"', "") for part in imports.split(",")]
+        # We will have a final empty element if the line ended with a comma.
+        if len(keys[-1]) == 0:
+            keys = keys[:-1]
+        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"
+
+    lines = import_statement.split("\n")
+    if len(lines) > 3:
+        # Here we have to sort internal imports that are on several lines (one per name):
+        # key: [
+        #     "object1",
+        #     "object2",
+        #     ...
+        # ]
+
+        # We may have to ignore one or two lines on each side.
+        idx = 2 if lines[1].strip() == "[" else 1
+        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
+        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
+        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
+        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
+    elif len(lines) == 3:
+        # Here we have to sort internal imports that are on one separate line:
+        # key: [
+        #     "object1", "object2", ...
+        # ]
+        if _re_bracket_content.search(lines[1]) is not None:
+            lines[1] = _re_bracket_content.sub(_replace, lines[1])
+        else:
+            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
+            # We will have a final empty element if the line ended with a comma.
+            if len(keys[-1]) == 0:
+                keys = keys[:-1]
+            lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
+        return "\n".join(lines)
+    else:
+        # Finally we have to deal with imports fitting on one line.
+        import_statement = _re_bracket_content.sub(_replace, import_statement)
+        return import_statement
+
+
+def sort_imports(file, check_only=True):
+    """
+    Sort `_import_structure` imports in `file`; `check_only` determines if we only check or overwrite.
+    """
+    with open(file, "r") as f:
+        code = f.read()
+
+    if "_import_structure" not in code:
+        return
+
+    # Blocks of indent level 0
+    main_blocks = split_code_in_indented_blocks(
+        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
+    )
+
+    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
+    for block_idx in range(1, len(main_blocks) - 1):
+        # Check if the block contains some `_import_structure` entries to sort.
+        block = main_blocks[block_idx]
+        block_lines = block.split("\n")
+        if len(block_lines) < 3 or "_import_structure" not in "".join(block_lines[:2]):
+            continue
+
+        # Ignore the first and last lines: they don't contain anything.
+        internal_block_code = "\n".join(block_lines[1:-1])
+        indent = get_indent(block_lines[1])
+        # Split the internal block into blocks of indent level 1.
+        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
+        # We have two categories of import key: direct `"key": [...]` entries or `_import_structure["key"].append/extend` calls.
+        pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
+        # Grab the keys, but there is a trap: some lines are empty or just comments.
+        keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
+        # We only sort the blocks with a key.
+        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
+        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]
+
+        # We reorder the blocks by leaving empty lines/comments as they were and reordering the rest.
+        count = 0
+        reordered_blocks = []
+        for i in range(len(internal_blocks)):
+            if keys[i] is None:
+                reordered_blocks.append(internal_blocks[i])
+            else:
+                block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
+                reordered_blocks.append(block)
+                count += 1
+
+        # And we put our main block back together with its first and last line.
+        main_blocks[block_idx] = "\n".join([block_lines[0]] + reordered_blocks + [block_lines[-1]])
+
+    if code != "\n".join(main_blocks):
+        if check_only:
+            return True
+        else:
+            print(f"Overwriting {file}.")
+            with open(file, "w") as f:
+                f.write("\n".join(main_blocks))
+
+
+def sort_imports_in_all_inits(check_only=True):
+    failures = []
+    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
+        if "__init__.py" in files:
+            result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
+            if result:
+                failures.append(os.path.join(root, "__init__.py"))
+    if len(failures) > 0:
+        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
+    args = parser.parse_args()
+
+    sort_imports_in_all_inits(check_only=args.check_only)

From ce520faa5cee3de2d1037ee6b0a4c0b03a7c73ae Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 18 Mar 2021 20:19:30 -0400
Subject: [PATCH 2/6] Add script to properly sort imports in init.
---
 Makefile | 5 +++--
 src/transformers/models/blenderbot/__init__.py | 2 +-
 src/transformers/models/blenderbot_small/__init__.py | 2 +-
 src/transformers/models/deberta/__init__.py | 6 +++---
 src/transformers/models/deberta_v2/__init__.py | 6 +++---
 src/transformers/models/ibert/__init__.py | 2 +-
 src/transformers/models/marian/__init__.py | 4 ++--
 src/transformers/models/mbart/__init__.py | 2 +-
 src/transformers/models/pegasus/__init__.py | 2 +-
 src/transformers/models/speech_to_text/__init__.py | 3 +--
 src/transformers/models/wav2vec2/__init__.py | 4 ++--
 src/transformers/utils/dummy_tf_objects.py | 6 +++++-
 utils/custom_init_isort.py | 11 +++++++----
 13 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/Makefile b/Makefile
index 7974335c14d2..50f3c14bae78 100644
--- a/Makefile
+++ b/Makefile
@@ -26,15 +26,15 @@ extra_quality_checks: deps_table_update
 	python utils/check_table.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
-	python utils/style_doc.py src/transformers docs/source --max_len 119
+	python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
 	python utils/class_mapping_update.py
 
 # this target runs checks on all files
 quality:
 	black --check $(check_dirs)
 	isort --check-only $(check_dirs)
+	python utils/custom_init_isort.py --check_only
 	flake8 $(check_dirs)
-	python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
 	${MAKE} extra_quality_checks
 
 # Format source code automatically and check if there are any problems left that need manual fixing
@@ -42,6 +42,7 @@ quality:
 style: deps_table_update
 	black $(check_dirs)
 	isort $(check_dirs)
+	python utils/custom_init_isort.py
 	python utils/style_doc.py src/transformers docs/source --max_len 119
 
 # Super fast fix and check target that only works on relevant modified files since the branch was made
diff --git a/src/transformers/models/blenderbot/__init__.py b/src/transformers/models/blenderbot/__init__.py
index cd46ae57036a..daf0b3dc4ed4 100644
--- a/src/transformers/models/blenderbot/__init__.py
+++ b/src/transformers/models/blenderbot/__init__.py
@@ -29,10 +29,10 @@
 if is_torch_available():
     _import_structure["modeling_blenderbot"] = [
         "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "BlenderbotForCausalLM",
         "BlenderbotForConditionalGeneration",
         "BlenderbotModel",
         "BlenderbotPreTrainedModel",
-        "BlenderbotForCausalLM",
     ]
diff --git a/src/transformers/models/blenderbot_small/__init__.py b/src/transformers/models/blenderbot_small/__init__.py
index 2f60bc77c098..a40ab18ff1b8 100644
--- a/src/transformers/models/blenderbot_small/__init__.py
+++ b/src/transformers/models/blenderbot_small/__init__.py
@@ -28,10 +28,10 @@
 if is_torch_available():
     _import_structure["modeling_blenderbot_small"] = [
         "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "BlenderbotSmallForCausalLM",
         "BlenderbotSmallForConditionalGeneration",
         "BlenderbotSmallModel",
         "BlenderbotSmallPreTrainedModel",
-        "BlenderbotSmallForCausalLM",
     ]
 
 if is_tf_available():
diff --git a/src/transformers/models/deberta/__init__.py b/src/transformers/models/deberta/__init__.py
index 2a489b124033..ff9b6274f17b 100644
--- a/src/transformers/models/deberta/__init__.py
+++ b/src/transformers/models/deberta/__init__.py
@@ -29,12 +29,12 @@
 if is_torch_available():
     _import_structure["modeling_deberta"] = [
         "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "DebertaForMaskedLM",
+        "DebertaForQuestionAnswering",
         "DebertaForSequenceClassification",
+        "DebertaForTokenClassification",
         "DebertaModel",
-        "DebertaForMaskedLM",
"DebertaPreTrainedModel", - "DebertaForTokenClassification", - "DebertaForQuestionAnswering", ] diff --git a/src/transformers/models/deberta_v2/__init__.py b/src/transformers/models/deberta_v2/__init__.py index 6783455cf634..236c7dc9fc35 100644 --- a/src/transformers/models/deberta_v2/__init__.py +++ b/src/transformers/models/deberta_v2/__init__.py @@ -29,12 +29,12 @@ if is_torch_available(): _import_structure["modeling_deberta_v2"] = [ "DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST", + "DebertaV2ForMaskedLM", + "DebertaV2ForQuestionAnswering", "DebertaV2ForSequenceClassification", + "DebertaV2ForTokenClassification", "DebertaV2Model", - "DebertaV2ForMaskedLM", "DebertaV2PreTrainedModel", - "DebertaV2ForTokenClassification", - "DebertaV2ForQuestionAnswering", ] diff --git a/src/transformers/models/ibert/__init__.py b/src/transformers/models/ibert/__init__.py index af1df0b80a87..c43ad8e6d0a4 100644 --- a/src/transformers/models/ibert/__init__.py +++ b/src/transformers/models/ibert/__init__.py @@ -28,13 +28,13 @@ if is_torch_available(): _import_structure["modeling_ibert"] = [ "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST", - "IBertPreTrainedModel", "IBertForMaskedLM", "IBertForMultipleChoice", "IBertForQuestionAnswering", "IBertForSequenceClassification", "IBertForTokenClassification", "IBertModel", + "IBertPreTrainedModel", ] if TYPE_CHECKING: diff --git a/src/transformers/models/marian/__init__.py b/src/transformers/models/marian/__init__.py index 34a35922c84f..4ec04e192a6c 100644 --- a/src/transformers/models/marian/__init__.py +++ b/src/transformers/models/marian/__init__.py @@ -36,14 +36,14 @@ if is_torch_available(): _import_structure["modeling_marian"] = [ "MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST", + "MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel", - "MarianForCausalLM", ] if is_tf_available(): - _import_structure["modeling_tf_marian"] = ["TFMarianMTModel", "TFMarianModel"] + _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"] if TYPE_CHECKING: diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index ed4856c45177..3367c3c43ba2 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -35,8 +35,8 @@ _import_structure["tokenization_mbart50"] = ["MBart50Tokenizer"] if is_tokenizers_available(): - _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"] _import_structure["tokenization_mbart50_fast"] = ["MBart50TokenizerFast"] + _import_structure["tokenization_mbart_fast"] = ["MBartTokenizerFast"] if is_torch_available(): _import_structure["modeling_mbart"] = [ diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py index 50e6284be863..daecd7825b4a 100644 --- a/src/transformers/models/pegasus/__init__.py +++ b/src/transformers/models/pegasus/__init__.py @@ -39,10 +39,10 @@ if is_torch_available(): _import_structure["modeling_pegasus"] = [ "PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST", + "PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel", - "PegasusForCausalLM", ] if is_tf_available(): diff --git a/src/transformers/models/speech_to_text/__init__.py b/src/transformers/models/speech_to_text/__init__.py index d431ce4fa6d6..0defd14c0032 100644 --- a/src/transformers/models/speech_to_text/__init__.py +++ b/src/transformers/models/speech_to_text/__init__.py @@ -29,9 +29,8 @@ } if is_sentencepiece_available(): - _import_structure["tokenization_speech_to_text"] = 
["Speech2TextTokenizer"] _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"] - + _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"] if is_torch_available(): _import_structure["modeling_speech_to_text"] = [ diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index 37456c17aa5f..183f85b82d3a 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -22,16 +22,16 @@ _import_structure = { "configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"], - "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"], "feature_extraction_wav2vec2": ["Wav2Vec2FeatureExtractor"], "processing_wav2vec2": ["Wav2Vec2Processor"], + "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"], } if is_torch_available(): _import_structure["modeling_wav2vec2"] = [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", - "Wav2Vec2ForMaskedLM", "Wav2Vec2ForCTC", + "Wav2Vec2ForMaskedLM", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index e6080a864280..baa20328edf1 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1050,10 +1050,14 @@ def __init__(self, *args, **kwargs): requires_tf(self) -class TFMarian: +class TFMarianModel: def __init__(self, *args, **kwargs): requires_tf(self) + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + class TFMarianMTModel: def __init__(self, *args, **kwargs): diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py index 9ecac8fb08ff..1891623f283f 100644 --- a/utils/custom_init_isort.py +++ b/utils/custom_init_isort.py @@ -62,7 +62,7 @@ def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_ if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "): current_block.append(lines[index]) blocks.append("\n".join(current_block)) - if index < len(lines) - 2: + if index < len(lines) - 1: current_block = [lines[index + 1]] index += 1 else: @@ -97,8 +97,11 @@ def _inner(x): def sort_objects(objects, key=None): "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str." # If no key is provided, we use a noop. + def noop(x): + return x + if key is None: - key = lambda x: x + key = noop # Constants are all uppercase, they go first. constants = [obj for obj in objects if key(obj).isupper()] # Classes are not all uppercase but start with a capital, they go second. 
@@ -166,10 +169,10 @@ def sort_imports(file, check_only=True): """ with open(file, "r") as f: code = f.read() - + if "_import_structure" not in code: return - + # Blocks of indent level 0 main_blocks = split_code_in_indented_blocks( code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:" From ae154f634701a2ad8fc0431c542bc127407ff783 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 18 Mar 2021 20:23:03 -0400 Subject: [PATCH 3/6] Add to the CI --- .circleci/config.yml | 1 + src/transformers/__init__.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f8040e7553f7..342c538bc1b5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -383,6 +383,7 @@ jobs: - '~/.cache/pip' - run: black --check examples tests src utils - run: isort --check-only examples tests src utils + - run: python utils/custom_init_isort.py --check_only - run: flake8 examples tests src utils - run: python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only - run: python utils/check_copies.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 84e4ea26c726..5d8aa3e427bb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -391,7 +391,6 @@ _import_structure["generation_utils"] = ["top_k_top_p_filtering"] _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] # PyTorch models structure - _import_structure["models.albert"].extend( [ "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -406,7 +405,6 @@ "load_tf_weights_in_albert", ] ) - _import_structure["models.auto"].extend( [ "MODEL_FOR_CAUSAL_LM_MAPPING", @@ -447,7 +445,6 @@ "PretrainedBartModel", ] ) - _import_structure["models.bert"].extend( [ "BERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -916,7 +913,6 @@ "shape_list", ] # TensorFlow models structure - _import_structure["models.albert"].extend( [ "TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", From 3e78efb577e20744132db0196ee2c9a79a00cf65 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 19 Mar 2021 09:52:50 -0400 Subject: [PATCH 4/6] Update utils/custom_init_isort.py Co-authored-by: Lysandre Debut --- utils/custom_init_isort.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py index 1891623f283f..06a89b166a5a 100644 --- a/utils/custom_init_isort.py +++ b/utils/custom_init_isort.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. +# Copyright 2021 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
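Across every `__init__.py` hunk in patches 1-4, the ordering being enforced is the same three-tier rule: all-uppercase constants (the `*_PRETRAINED_*_ARCHIVE_LIST`/`*_ARCHIVE_MAP` names) come first, then classes, then lowercase-initial functions such as `load_tf_weights_in_*`, with each tier compared case-insensitively and with underscores ignored. Below is a minimal standalone sketch of that rule; it mirrors `ignore_underscore` and `sort_objects` from `utils/custom_init_isort.py`, but the helper names (`_cmp_key`, `three_tier_sort`) and the sample list are illustrative, not part of the patch.

    # Standalone sketch of the three-tier ordering rule; not the script itself.
    def _cmp_key(name):
        # Compare case-insensitively and ignore underscores.
        return name.lower().replace("_", "")

    def three_tier_sort(objects):
        constants = [o for o in objects if o.isupper()]                       # archive lists/maps
        classes = [o for o in objects if o[0].isupper() and not o.isupper()]  # configs, models
        functions = [o for o in objects if not o[0].isupper()]                # helper functions
        return (
            sorted(constants, key=_cmp_key)
            + sorted(classes, key=_cmp_key)
            + sorted(functions, key=_cmp_key)
        )

    print(three_tier_sort([
        "Wav2Vec2Model",
        "load_tf_weights_in_convbert",
        "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "Wav2Vec2ForCTC",
    ]))
    # ['WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST', 'Wav2Vec2ForCTC',
    #  'Wav2Vec2Model', 'load_tf_weights_in_convbert']

The underscore-insensitive comparison is what makes `WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST` sort as if it were spelled `wav2vec2...`, keeping it consistent with the `Wav2Vec2*` class names it accompanies.
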
From ff01e893fd0527795d2107aaa0d909cbdc61a5fc Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Fri, 19 Mar 2021 09:54:35 -0400
Subject: [PATCH 5/6] Separate scripts that change content from quality

---
 Makefile | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/Makefile b/Makefile
index 50f3c14bae78..dfa92755f074 100644
--- a/Makefile
+++ b/Makefile
@@ -21,12 +21,11 @@ deps_table_update:
 
 # Check that source code meets quality standards
 
-extra_quality_checks: deps_table_update
+extra_quality_checks:
 	python utils/check_copies.py
 	python utils/check_table.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
-	python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
 	python utils/class_mapping_update.py
 
 # this target runs checks on all files
@@ -39,15 +38,19 @@ quality:
 
 # Format source code automatically and check if there are any problems left that need manual fixing
-style: deps_table_update
-	black $(check_dirs)
-	isort $(check_dirs)
+extra_style_checks: deps_table_update
 	python utils/custom_init_isort.py
 	python utils/style_doc.py src/transformers docs/source --max_len 119
 
+# this target runs checks on all files
+style:
+	black $(check_dirs)
+	isort $(check_dirs)
+	${MAKE} extra_style_checks
+
 # Super fast fix and check target that only works on relevant modified files since the branch was made
-fixup: modified_only_fixup extra_quality_checks
+fixup: modified_only_fixup extra_style_checks extra_quality_checks
 
 # Make marked copies of snippets of code conform to the original

From 66ba6b07339fbbc91cdda5752ce28676225988b7 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Fri, 19 Mar 2021 13:42:18 -0400
Subject: [PATCH 6/6] Move class_mapping_update to style_checks

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index dfa92755f074..b659fcb546a3 100644
--- a/Makefile
+++ b/Makefile
@@ -26,7 +26,6 @@ extra_quality_checks:
 	python utils/check_table.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
-	python utils/class_mapping_update.py
 
 # this target runs checks on all files
 quality:
@@ -41,6 +40,7 @@ quality:
 extra_style_checks: deps_table_update
 	python utils/custom_init_isort.py
 	python utils/style_doc.py src/transformers docs/source --max_len 119
+	python utils/class_mapping_update.py
 
 # this target runs checks on all files
 style:
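
With patches 5 and 6 in place, the script runs automatically: `make style` fixes the inits via `extra_style_checks`, while `make quality` and the CircleCI job invoke it with `--check_only` and fail if anything would change. As a quick sanity check of the single-line code path, `sort_objects_in_import` can also be exercised directly. The snippet below is a hypothetical usage sketch, assuming `utils/` is on `sys.path` so the module imports; the `before` string is the Marian line from PATCH 1/6.

    # Hypothetical usage sketch; assumes utils/ is on sys.path.
    from custom_init_isort import sort_objects_in_import

    # The one-line Marian entry exactly as it stood before PATCH 1/6.
    before = '_import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"])'

    # _re_bracket_content leaves ["models.marian"] untouched (no comma inside),
    # then rewrites the bracketed object list in sorted order.
    after = sort_objects_in_import(before)
    assert after == '_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])'

The result matches the change PATCH 1/6 applies to `src/transformers/__init__.py`, which is exactly the property the CI check relies on: running the fixer on an already-sorted file is a no-op.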