diff --git a/.gitignore b/.gitignore
index 36cbb4f7ea39..965fbeec77f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,8 +9,7 @@ __pycache__/
 *.so

 # tests and logs
-tests/fixtures/*
-!tests/fixtures/sample_text_no_unicode.txt
+tests/fixtures/cached_*_text.txt
 logs/
 lightning_logs/
 lang_code_data/
diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst
index 464730108624..e0e76c77958d 100644
--- a/docs/source/model_doc/auto.rst
+++ b/docs/source/model_doc/auto.rst
@@ -44,6 +44,13 @@ AutoTokenizer
     :members:


+AutoFeatureExtractor
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.AutoFeatureExtractor
+    :members:
+
+
 AutoModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 0b9d366d3cfb..264ff7b0c456 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -45,6 +45,7 @@
     _BaseLazyModule,
     is_flax_available,
     is_sentencepiece_available,
+    is_speech_available,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,
@@ -102,6 +103,7 @@
         "is_py3nvml_available",
         "is_sentencepiece_available",
         "is_sklearn_available",
+        "is_speech_available",
         "is_tf_available",
         "is_tokenizers_available",
         "is_torch_available",
@@ -133,9 +135,11 @@
     "models.auto": [
         "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "CONFIG_MAPPING",
+        "FEATURE_EXTRACTOR_MAPPING",
         "MODEL_NAMES_MAPPING",
         "TOKENIZER_MAPPING",
         "AutoConfig",
+        "AutoFeatureExtractor",
         "AutoTokenizer",
     ],
     "models.bart": ["BartConfig", "BartTokenizer"],
@@ -202,7 +206,6 @@
     "models.speech_to_text": [
         "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "Speech2TextConfig",
-        "Speech2TextFeatureExtractor",
     ],
     "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
     "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
@@ -288,7 +291,6 @@
     _import_structure["models.pegasus"].append("PegasusTokenizer")
     _import_structure["models.reformer"].append("ReformerTokenizer")
     _import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
-    _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
     _import_structure["models.t5"].append("T5Tokenizer")
     _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
     _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
@@ -339,6 +341,7 @@

     if is_sentencepiece_available():
         _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
+
 else:
     from .utils import dummy_tokenizers_objects

@@ -346,6 +349,20 @@
         name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
     ]

+# Speech-specific objects
+if is_speech_available():
+    _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
+
+    if is_sentencepiece_available():
+        _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
+
+else:
+    from .utils import dummy_speech_objects
+
+    _import_structure["utils.dummy_speech_objects"] = [
+        name for name in dir(dummy_speech_objects) if not name.startswith("_")
+    ]
+
 # Vision-specific objects
 if is_vision_available():
     _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
@@ -1394,6 +1411,7 @@
         is_py3nvml_available,
         is_sentencepiece_available,
         is_sklearn_available,
+        is_speech_available,
         is_tf_available,
         is_tokenizers_available,
         is_torch_available,
@@ -1429,9 +1447,11 @@
     from .models.auto import (
         ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
         CONFIG_MAPPING,
+        FEATURE_EXTRACTOR_MAPPING,
         MODEL_NAMES_MAPPING,
         TOKENIZER_MAPPING,
         AutoConfig,
+        AutoFeatureExtractor,
         AutoTokenizer,
     )
     from .models.bart import BartConfig, BartTokenizer
@@ -1494,11 +1514,7 @@
     from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
     from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
     from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
-    from .models.speech_to_text import (
-        SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        Speech2TextConfig,
-        Speech2TextFeatureExtractor,
-    )
+    from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
     from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
     from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
     from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
@@ -1585,7 +1601,7 @@
         from .models.mt5 import MT5Tokenizer
         from .models.pegasus import PegasusTokenizer
         from .models.reformer import ReformerTokenizer
-        from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer
+        from .models.speech_to_text import Speech2TextTokenizer
         from .models.t5 import T5Tokenizer
         from .models.xlm_prophetnet import XLMProphetNetTokenizer
         from .models.xlm_roberta import XLMRobertaTokenizer
@@ -1627,9 +1643,19 @@

         if is_sentencepiece_available():
             from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
+
     else:
         from .utils.dummy_tokenizers_objects import *

+    if is_speech_available():
+        from .models.speech_to_text import Speech2TextFeatureExtractor
+
+        if is_sentencepiece_available():
+            from .models.speech_to_text import Speech2TextProcessor
+
+    else:
+        from .utils.dummy_speech_objects import *
+
     if is_vision_available():
         from .image_utils import ImageFeatureExtractionMixin
         from .models.vit import ViTFeatureExtractor
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index c7a4bd41d644..b53407ad3eed 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -43,6 +43,7 @@
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",
     "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
+    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
     "sphinx": "sphinx==3.2.1",
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index dbd5f9a6ccd3..f7bf49c4009d 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -325,6 +325,13 @@ def get_feature_extractor_dict(
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)

+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
         if is_offline_mode() and not local_files_only:
             logger.info("Offline mode: forcing local_files_only=True")
             local_files_only = True
@@ -349,6 +356,7 @@ def get_feature_extractor_dict(
                 resume_download=resume_download,
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
+                user_agent=user_agent,
             )
             # Load feature_extractor dict
             with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
@@ -426,6 +434,7 @@ def to_dict(self) -> Dict[str, Any]:
             :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
         """
         output = copy.deepcopy(self.__dict__)
+        output["feature_extractor_type"] = self.__class__.__name__

         return output
diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index ed4b84dc108d..bba9afc3a421 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -397,6 +397,11 @@ def is_torchaudio_available():
     return _torchaudio_available


+def is_speech_available():
+    # For now this depends on torchaudio but the exact dependency might evolve in the future.
+    return _torchaudio_available
+
+
 def torch_only_method(fn):
     def wrapper(*args, **kwargs):
         if not _torch_available:
@@ -513,6 +518,13 @@ def wrapper(*args, **kwargs):
 """


+# docstyle-ignore
+SPEECH_IMPORT_ERROR = """
+{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
+`pip install torchaudio`
+"""
+
+
 # docstyle-ignore
 VISION_IMPORT_ERROR = """
 {0} requires the PIL library but it was not found in your environment. You can install it with pip:
@@ -586,6 +598,12 @@ def requires_scatter(obj):
         raise ImportError(SCATTER_IMPORT_ERROR.format(name))


+def requires_speech(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_speech_available():
+        raise ImportError(SPEECH_IMPORT_ERROR.format(name))
+
+
 def requires_vision(obj):
     name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
     if not is_vision_available():
diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py
index 8bf312231a75..ef255d8b268d 100644
--- a/src/transformers/models/auto/__init__.py
+++ b/src/transformers/models/auto/__init__.py
@@ -23,6 +23,7 @@

 _import_structure = {
     "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
+    "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
     "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
 }

@@ -104,6 +105,7 @@

 if TYPE_CHECKING:
     from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
+    from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
     from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer

     if is_torch_available():
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
new file mode 100644
index 000000000000..097a336c96db
--- /dev/null
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -0,0 +1,150 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" AutoFeatureExtractor class. """
+
+from collections import OrderedDict
+
+from ...feature_extraction_utils import FeatureExtractionMixin
+from ...file_utils import is_speech_available, is_vision_available
+from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
+from .configuration_auto import replace_list_option_in_docstrings
+
+
+if is_speech_available():
+    from ..speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor
+else:
+    Speech2TextFeatureExtractor = None
+
+if is_vision_available():
+    from ..vit.feature_extraction_vit import ViTFeatureExtractor
+else:
+    ViTFeatureExtractor = None
+
+
+# Build the list of all feature extractors
+FEATURE_EXTRACTOR_MAPPING = OrderedDict(
+    [
+        ("s2t", Speech2TextFeatureExtractor),
+        ("vit", ViTFeatureExtractor),
+        ("wav2vec2", Wav2Vec2FeatureExtractor),
+    ]
+)
+
+
+def feature_extractor_class_from_name(class_name: str):
+    for c in FEATURE_EXTRACTOR_MAPPING.values():
+        if c is not None and c.__name__ == class_name:
+            return c
+
+
+class AutoFeatureExtractor:
+    r"""
+    This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of
+    the library when created with the :meth:`AutoFeatureExtractor.from_pretrained` class method.
+
+    This class cannot be instantiated directly using ``__init__()`` (throws an error).
+    """
+
+    def __init__(self):
+        raise EnvironmentError(
+            "AutoFeatureExtractor is designed to be instantiated "
+            "using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
+        )
+
+    @classmethod
+    @replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING)
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        r"""
+        Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
+
+        The feature extractor class to instantiate is selected based on the :obj:`model_type` property of the config
+        object (either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or
+        when it's missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
+
+        List options
+
+        Params:
+            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
+                This can be either:
+
+                - a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
+                  huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
+                  namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
+                - a path to a `directory` containing a feature extractor file saved using the
+                  :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
+                  ``./my_model_directory/``.
+                - a path or url to a saved feature extractor JSON `file`, e.g.,
+                  ``./my_model_directory/feature_extraction_config.json``.
+            cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+                Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
+                standard cache should not be used.
+            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to force to (re-)download the feature extractor files and override the cached versions
+                if they exist.
+            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to delete incompletely received file. Attempts to resume the download if such a file
+                exists.
+            proxies (:obj:`Dict[str, str]`, `optional`):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+            use_auth_token (:obj:`str` or `bool`, `optional`):
+                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be
+                any identifier allowed by git.
+            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`True`,
+                then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where `unused_kwargs` is
+                a dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e.,
+                the part of ``kwargs`` which has not been used to update ``feature_extractor`` and is otherwise
+                ignored.
+            kwargs (:obj:`Dict[str, Any]`, `optional`):
+                The values in kwargs of any keys which are feature extractor attributes will be used to override the
+                loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes
+                is controlled by the ``return_unused_kwargs`` keyword parameter.
+
+        .. note::
+
+            Passing :obj:`use_auth_token=True` is required when you want to use a private model.
+
+        Examples::
+
+            >>> from transformers import AutoFeatureExtractor
+
+            >>> # Download vocabulary from huggingface.co and cache.
+            >>> feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h')
+
+            >>> # If vocabulary files are in a directory (e.g. feature extractor was saved using `save_pretrained('./test/saved_model/')`)
+            >>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/')
+
+        """
+        kwargs["_from_auto"] = True
+        config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+
+        if "feature_extractor_type" in config_dict:
+            feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
+            return feature_extractor_class.from_dict(config_dict, **kwargs)
+        else:
+            # Fallback: use pattern matching on the string.
+            for pattern, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items():
+                if pattern in str(pretrained_model_name_or_path):
+                    return feature_extractor_class.from_dict(config_dict, **kwargs)
+
+        raise ValueError(
+            f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
+            "its feature_extraction_config.json, or contain one of the following strings "
+            f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
+        )
diff --git a/src/transformers/models/speech_to_text/__init__.py b/src/transformers/models/speech_to_text/__init__.py
index 0defd14c0032..026312e8cdab 100644
--- a/src/transformers/models/speech_to_text/__init__.py
+++ b/src/transformers/models/speech_to_text/__init__.py
@@ -17,7 +17,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available
+from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_speech_available, is_torch_available


 _import_structure = {
@@ -25,13 +25,17 @@
         "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "Speech2TextConfig",
     ],
-    "feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
 }

 if is_sentencepiece_available():
-    _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
     _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]

+if is_speech_available():
+    _import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
+
+    if is_sentencepiece_available():
+        _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
+
 if is_torch_available():
     _import_structure["modeling_speech_to_text"] = [
         "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -43,12 +47,16 @@

 if TYPE_CHECKING:
     from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
-    from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor

     if is_sentencepiece_available():
-        from .processing_speech_to_text import Speech2TextProcessor
         from .tokenization_speech_to_text import Speech2TextTokenizer

+    if is_speech_available():
+        from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
+
+        if is_sentencepiece_available():
+            from .processing_speech_to_text import Speech2TextProcessor
+
     if is_torch_available():
         from .modeling_speech_to_text import (
             SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
index e7fdb44aefe4..a7c21a969f9c 100644
--- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -19,19 +19,15 @@
 from typing import List, Optional, Union

 import numpy as np
+import torch
+import torchaudio.compliance.kaldi as ta_kaldi

 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
-from ...file_utils import PaddingStrategy, TensorType, is_torch_available, is_torchaudio_available
+from ...file_utils import PaddingStrategy, TensorType
 from ...utils import logging


-if is_torch_available():
-    import torch
-
-if is_torchaudio_available():
-    import torchaudio.compliance.kaldi as ta_kaldi
-
 logger = logging.get_logger(__name__)


@@ -75,8 +71,6 @@ def __init__(
         normalize_vars=True,
         **kwargs
     ):
-        if not is_torchaudio_available():
-            raise ImportError("`Speech2TextFeatureExtractor` requires torchaudio: `pip install torchaudio`.")
         super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
         self.num_mel_bins = num_mel_bins
         self.do_ceptral_normalize = do_ceptral_normalize
diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py
index 2ef3165d7f08..8dc02dae0977 100644
--- a/src/transformers/utils/dummy_sentencepiece_objects.py
+++ b/src/transformers/utils/dummy_sentencepiece_objects.py
@@ -110,11 +110,6 @@ def from_pretrained(self, *args, **kwargs):
         requires_sentencepiece(self)


-class Speech2TextProcessor:
-    def __init__(self, *args, **kwargs):
-        requires_sentencepiece(self)
-
-
 class Speech2TextTokenizer:
     def __init__(self, *args, **kwargs):
         requires_sentencepiece(self)
diff --git a/src/transformers/utils/dummy_speech_objects.py b/src/transformers/utils/dummy_speech_objects.py
new file mode 100644
index 000000000000..45021250cd0e
--- /dev/null
+++ b/src/transformers/utils/dummy_speech_objects.py
@@ -0,0 +1,12 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..file_utils import requires_speech
+
+
+class Speech2TextFeatureExtractor:
+    def __init__(self, *args, **kwargs):
+        requires_speech(self)
+
+
+class Speech2TextProcessor:
+    def __init__(self, *args, **kwargs):
+        requires_speech(self)
diff --git a/tests/fixtures/dummy_feature_extractor_config.json b/tests/fixtures/dummy_feature_extractor_config.json
new file mode 100644
index 000000000000..cf0c5dce6c42
--- /dev/null
+++ b/tests/fixtures/dummy_feature_extractor_config.json
@@ -0,0 +1,3 @@
+{
+  "feature_extractor_type": "Wav2Vec2FeatureExtractor"
+}
\ No newline at end of file
diff --git a/tests/test_feature_extraction_auto.py b/tests/test_feature_extraction_auto.py
new file mode 100644
index 000000000000..71ee32c230af
--- /dev/null
+++ b/tests/test_feature_extraction_auto.py
@@ -0,0 +1,44 @@
+# coding=utf-8
+# Copyright 2021 the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from transformers import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor, Wav2Vec2FeatureExtractor
+
+
+SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
+)
+
+
+class AutoFeatureExtractorTest(unittest.TestCase):
+    def test_feature_extractor_from_model_shortcut(self):
+        config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
+        self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
+
+    def test_feature_extractor_from_local_file(self):
+        config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
+        self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
+
+    def test_pattern_matching_fallback(self):
+        """
+        In cases where config.json doesn't include a model_type, perform a few safety checks on the config mapping's
+        order.
+        """
+        # no key string should be included in a later key string (typical failure case)
+        keys = list(FEATURE_EXTRACTOR_MAPPING.keys())
+        for i, key in enumerate(keys):
+            self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
diff --git a/tests/test_feature_extraction_speech_to_text.py b/tests/test_feature_extraction_speech_to_text.py
index 5cd2f67f457d..c90beef01377 100644
--- a/tests/test_feature_extraction_speech_to_text.py
+++ b/tests/test_feature_extraction_speech_to_text.py
@@ -20,12 +20,15 @@

 import numpy as np

-from transformers import Speech2TextFeatureExtractor
+from transformers import is_speech_available
 from transformers.testing_utils import require_torch, require_torchaudio

 from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


+if is_speech_available():
+    from transformers import Speech2TextFeatureExtractor
+
 global_rng = random.Random()


@@ -101,7 +104,7 @@ def _flatten(list_of_lists):
 @require_torchaudio
 class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):

-    feature_extraction_class = Speech2TextFeatureExtractor
+    feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None

     def setUp(self):
         self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
diff --git a/tests/test_processor_speech_to_text.py b/tests/test_processor_speech_to_text.py
index cf26e32c1db4..76a7a7446152 100644
--- a/tests/test_processor_speech_to_text.py
+++ b/tests/test_processor_speech_to_text.py
@@ -19,7 +19,7 @@
 from pathlib import Path
 from shutil import copyfile

-from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer
+from transformers import Speech2TextTokenizer, is_speech_available
 from transformers.file_utils import FEATURE_EXTRACTOR_NAME
 from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
 from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
@@ -27,6 +27,10 @@

 from .test_feature_extraction_speech_to_text import floats_list


+if is_speech_available():
+    from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
+
+
 SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
diff --git a/utils/check_dummies.py b/utils/check_dummies.py
index 20b348cea166..e2d16713d5fe 100644
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -26,7 +26,7 @@
 _re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")


-BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]
+BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]


 DUMMY_CONSTANT = """
diff --git a/utils/check_inits.py b/utils/check_inits.py
index 7d024ed39515..969c8a07ffe3 100644
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -18,7 +18,7 @@

 PATH_TO_TRANSFORMERS = "src/transformers"

-BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]
+BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]

 # Catches a line with a key-values pattern: "bla": ["foo", "bar"]
 _re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')