Auto feature extractor #11097

Merged · 7 commits · Apr 6, 2021
Changes from 6 commits
3 changes: 1 addition & 2 deletions .gitignore
@@ -9,8 +9,7 @@ __pycache__/
 *.so
 
 # tests and logs
-tests/fixtures/*
-!tests/fixtures/sample_text_no_unicode.txt
+tests/fixtures/cached_*_text.txt
 logs/
 lightning_logs/
 lang_code_data/
7 changes: 7 additions & 0 deletions docs/source/model_doc/auto.rst
@@ -44,6 +44,13 @@ AutoTokenizer
     :members:
 
 
+AutoFeatureExtractor
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.AutoFeatureExtractor
+    :members:
+
+
 AutoModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
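For context, the new class documented above follows the same pattern as the other Auto classes on this page. A minimal usage sketch (the checkpoint name is an illustrative assumption, not taken from this PR):

```python
from transformers import AutoFeatureExtractor

# Reads the checkpoint's feature extractor config and instantiates the
# matching concrete class (e.g. a speech feature extractor for an S2T model).
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr")
print(type(feature_extractor).__name__)  # e.g. "Speech2TextFeatureExtractor"
```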
42 changes: 34 additions & 8 deletions src/transformers/__init__.py
@@ -45,6 +45,7 @@
     _BaseLazyModule,
     is_flax_available,
     is_sentencepiece_available,
+    is_speech_available,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,
@@ -102,6 +103,7 @@
         "is_py3nvml_available",
         "is_sentencepiece_available",
         "is_sklearn_available",
+        "is_speech_available",
         "is_tf_available",
         "is_tokenizers_available",
         "is_torch_available",
@@ -133,9 +135,11 @@
     "models.auto": [
         "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "CONFIG_MAPPING",
+        "FEATURE_EXTRACTOR_MAPPING",
         "MODEL_NAMES_MAPPING",
         "TOKENIZER_MAPPING",
         "AutoConfig",
+        "AutoFeatureExtractor",
         "AutoTokenizer",
     ],
     "models.bart": ["BartConfig", "BartTokenizer"],
@@ -202,7 +206,6 @@
     "models.speech_to_text": [
         "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "Speech2TextConfig",
-        "Speech2TextFeatureExtractor",
     ],
     "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
     "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
@@ -288,7 +291,6 @@
     _import_structure["models.pegasus"].append("PegasusTokenizer")
     _import_structure["models.reformer"].append("ReformerTokenizer")
     _import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
-    _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
     _import_structure["models.t5"].append("T5Tokenizer")
     _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
     _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
@@ -339,13 +341,28 @@
 
     if is_sentencepiece_available():
         _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
 
 else:
     from .utils import dummy_tokenizers_objects
 
     _import_structure["utils.dummy_tokenizers_objects"] = [
         name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
     ]
+
+# Speech-specific objects
+if is_speech_available():
+    _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
+
+    if is_sentencepiece_available():
+        _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
+
+else:
+    from .utils import dummy_speech_objects
+
+    _import_structure["utils.dummy_speech_objects"] = [
+        name for name in dir(dummy_speech_objects) if not name.startswith("_")
+    ]
+
 # Vision-specific objects
 if is_vision_available():
     _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
@@ -1394,6 +1411,7 @@
         is_py3nvml_available,
         is_sentencepiece_available,
         is_sklearn_available,
+        is_speech_available,
         is_tf_available,
         is_tokenizers_available,
         is_torch_available,
@@ -1429,9 +1447,11 @@
     from .models.auto import (
         ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
         CONFIG_MAPPING,
+        FEATURE_EXTRACTOR_MAPPING,
         MODEL_NAMES_MAPPING,
         TOKENIZER_MAPPING,
         AutoConfig,
+        AutoFeatureExtractor,
         AutoTokenizer,
     )
     from .models.bart import BartConfig, BartTokenizer
@@ -1494,11 +1514,7 @@
     from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
     from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
     from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
-    from .models.speech_to_text import (
-        SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        Speech2TextConfig,
-        Speech2TextFeatureExtractor,
-    )
+    from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
     from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
     from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
     from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
@@ -1585,7 +1601,7 @@
     from .models.mt5 import MT5Tokenizer
     from .models.pegasus import PegasusTokenizer
     from .models.reformer import ReformerTokenizer
-    from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer
+    from .models.speech_to_text import Speech2TextTokenizer
     from .models.t5 import T5Tokenizer
     from .models.xlm_prophetnet import XLMProphetNetTokenizer
     from .models.xlm_roberta import XLMRobertaTokenizer
@@ -1627,9 +1643,19 @@
 
         if is_sentencepiece_available():
             from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
 
     else:
         from .utils.dummy_tokenizers_objects import *
+
+    if is_speech_available():
+        from .models.speech_to_text import Speech2TextFeatureExtractor
+
+        if is_sentencepiece_available():
+            from .models.speech_to_text import Speech2TextProcessor
+
+    else:
+        from .utils.dummy_speech_objects import *
+
     if is_vision_available():
         from .image_utils import ImageFeatureExtractionMixin
         from .models.vit import ViTFeatureExtractor
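The net effect of these `__init__.py` changes is that the speech objects resolve to real classes only when torchaudio is installed (and the processor additionally needs sentencepiece); otherwise dummy placeholders are exposed that raise an informative ImportError. A hedged sketch of how downstream code can mirror these guards:

```python
from transformers.file_utils import is_sentencepiece_available, is_speech_available

if is_speech_available():
    # Real implementation, backed by torchaudio.
    from transformers import Speech2TextFeatureExtractor

    if is_sentencepiece_available():
        # The processor bundles the feature extractor with a sentencepiece tokenizer.
        from transformers import Speech2TextProcessor
else:
    # Without torchaudio, the same names come from utils.dummy_speech_objects and
    # raise an ImportError pointing at `pip install torchaudio` when used.
    print("torchaudio is not installed; Speech2Text feature extraction is unavailable")
```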
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -43,6 +43,7 @@
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",
     "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
+    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
     "sphinx": "sphinx==3.2.1",
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
9 changes: 9 additions & 0 deletions src/transformers/feature_extraction_utils.py
@@ -325,6 +325,13 @@ def get_feature_extractor_dict(
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
 
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
         if is_offline_mode() and not local_files_only:
             logger.info("Offline mode: forcing local_files_only=True")
             local_files_only = True
@@ -349,6 +356,7 @@ def get_feature_extractor_dict(
                 resume_download=resume_download,
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
+                user_agent=user_agent,
             )
             # Load feature_extractor dict
             with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
@@ -426,6 +434,7 @@ def to_dict(self) -> Dict[str, Any]:
             :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
         """
         output = copy.deepcopy(self.__dict__)
+        output["feature_extractor_type"] = self.__class__.__name__
 
         return output
 
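The `to_dict()` change is what lets a saved preprocessor config record which class produced it, so `AutoFeatureExtractor` can pick the right class when loading. A small illustration (using Wav2Vec2FeatureExtractor, which is not part of this diff, purely as a convenient example that needs no torchaudio):

```python
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor()        # default settings
config_dict = extractor.to_dict()             # now also records the class name
print(config_dict["feature_extractor_type"])  # -> "Wav2Vec2FeatureExtractor"
```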
18 changes: 18 additions & 0 deletions src/transformers/file_utils.py
@@ -397,6 +397,11 @@ def is_torchaudio_available():
     return _torchaudio_available
 
 
+def is_speech_available():
+    # For now this depends on torchaudio but the exact dependency might evolve in the future.
+    return _torchaudio_available
+
+
 def torch_only_method(fn):
     def wrapper(*args, **kwargs):
         if not _torch_available:
@@ -513,6 +518,13 @@ def wrapper(*args, **kwargs):
 """
 
 
+# docstyle-ignore
+SPEECH_IMPORT_ERROR = """
+{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
+`pip install torchaudio`
+"""
+
+
 # docstyle-ignore
 VISION_IMPORT_ERROR = """
 {0} requires the PIL library but it was not found in your environment. You can install it with pip:
@@ -586,6 +598,12 @@ def requires_scatter(obj):
         raise ImportError(SCATTER_IMPORT_ERROR.format(name))
 
 
+def requires_speech(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_speech_available():
+        raise ImportError(SPEECH_IMPORT_ERROR.format(name))
+
+
 def requires_vision(obj):
     name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
     if not is_vision_available():
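These helpers follow the existing `requires_*` pattern in `file_utils.py`. A hypothetical sketch of how a torchaudio-dependent component would typically use them (the class below is invented purely for illustration):

```python
from transformers.file_utils import is_speech_available, requires_speech


class MyMelFilterBankExtractor:
    """Hypothetical feature extractor used only to illustrate the guard."""

    def __init__(self):
        # Raises ImportError with the new SPEECH_IMPORT_ERROR message
        # ("... requires the torchaudio library ...") if torchaudio is missing.
        requires_speech(self)


if is_speech_available():
    extractor = MyMelFilterBankExtractor()
else:
    print("Install torchaudio to enable speech feature extraction.")
```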
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -23,6 +23,7 @@
 
 _import_structure = {
     "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
+    "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
     "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
 }
 
@@ -104,6 +105,7 @@
 
 if TYPE_CHECKING:
     from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
+    from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
     from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
 
     if is_torch_available():
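`FEATURE_EXTRACTOR_MAPPING` is the lookup table `AutoFeatureExtractor` consults to go from a model's configuration or type to its feature extractor class. An illustrative way to inspect it (the exact key type is defined in `feature_extraction_auto.py`, which is not shown in this excerpt):

```python
from transformers.models.auto import FEATURE_EXTRACTOR_MAPPING

# Print what the auto machinery knows how to build.
for key, feature_extractor_cls in FEATURE_EXTRACTOR_MAPPING.items():
    print(key, "->", feature_extractor_cls.__name__)
```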