From 30677dc743bae5485f755392adf4d543f71db6e4 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Thu, 1 Apr 2021 17:16:05 +0200 Subject: [PATCH] Add Vision Transformer and ViTFeatureExtractor (#10950) * Squash all commits into one * Update ViTFeatureExtractor to use image_utils instead of torchvision * Remove torchvision and add Pillow * Small docs improvement * Address most comments by @sgugger * Fix tests * Clean up conversion script * Pooler first draft * Fix quality * Improve conversion script * Make style and quality * Make fix-copies * Minor docs improvements * Should use fix-copies instead of manual handling * Revert "Should use fix-copies instead of manual handling" This reverts commit fd4e591bce4496d41406425c82606a8fdaf8a50b. * Place ViT in alphabetical order Co-authored-by: Lysandre Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .circleci/config.yml | 16 +- README.md | 1 + docs/source/index.rst | 19 +- docs/source/model_doc/vit.rst | 102 ++ setup.py | 4 +- src/transformers/__init__.py | 24 +- src/transformers/dependency_versions_table.py | 1 + src/transformers/file_utils.py | 5 +- src/transformers/image_utils.py | 4 +- src/transformers/models/__init__.py | 1 + src/transformers/models/auto/__init__.py | 4 + .../models/auto/configuration_auto.py | 4 + src/transformers/models/auto/modeling_auto.py | 107 ++ src/transformers/models/vit/__init__.py | 70 ++ .../models/vit/configuration_vit.py | 116 ++ .../models/vit/convert_vit_timm_to_pytorch.py | 228 ++++ .../models/vit/feature_extraction_vit.py | 130 +++ src/transformers/models/vit/modeling_vit.py | 629 +++++++++++ src/transformers/utils/dummy_pt_objects.py | 29 + .../utils/dummy_vision_objects.py | 5 + src/transformers/utils/imagenet_classes.py | 1003 +++++++++++++++++ tests/test_feature_extraction_vit.py | 221 ++++ tests/test_image_utils.py | 4 +- tests/test_modeling_common.py | 2 + tests/test_modeling_vit.py | 365 ++++++ 25 files changed, 3072 insertions(+), 22 deletions(-) create mode 100644 docs/source/model_doc/vit.rst create mode 100644 src/transformers/models/vit/__init__.py create mode 100644 src/transformers/models/vit/configuration_vit.py create mode 100644 src/transformers/models/vit/convert_vit_timm_to_pytorch.py create mode 100644 src/transformers/models/vit/feature_extraction_vit.py create mode 100644 src/transformers/models/vit/modeling_vit.py create mode 100644 src/transformers/utils/imagenet_classes.py create mode 100644 tests/test_feature_extraction_vit.py create mode 100644 tests/test_modeling_vit.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 28b4f52abd3d..56d551a9465a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -80,8 +80,8 @@ jobs: - v0.4-{{ checksum "setup.py" }} - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech] - - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html + - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech,vision] + - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html - save_cache: key: v0.4-{{ checksum "setup.py" }} paths: @@ -110,8 +110,8 @@ jobs: - v0.4-{{ checksum "setup.py" }} - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech] - - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html + - run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech,vision] + - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html - save_cache: key: v0.4-{{ checksum "setup.py" }} paths: @@ -139,8 +139,8 @@ jobs: - v0.4-{{ checksum "setup.py" }} - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: pip install .[sklearn,torch,testing,sentencepiece,speech] - - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html + - run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision] + - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} paths: @@ -223,8 +223,8 @@ jobs: - v0.4-{{ checksum "setup.py" }} - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev - run: pip install --upgrade pip - - run: pip install .[sklearn,torch,testing,sentencepiece,speech] - - run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html + - run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision] + - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html - save_cache: key: v0.4-torch-{{ checksum "setup.py" }} paths: diff --git a/README.md b/README.md index a643fe825307..dd535688cb93 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. 1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. +1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. 1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau. 1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. diff --git a/docs/source/index.rst b/docs/source/index.rst index 03652a77cae4..16164a761ae4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -210,22 +210,26 @@ and conversion utilities for the following models: 43. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -44. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +44. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 + Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, + Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias + Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. +45. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -45. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +46. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -46. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +47. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -47. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +48. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -48. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +49. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -49. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +50. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -328,6 +332,8 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| ViT | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | XLM | ✅ | ❌ | ✅ | ✅ | ❌ | @@ -460,6 +466,7 @@ TensorFlow and/or Flax. model_doc/t5 model_doc/tapas model_doc/transformerxl + model_doc/vit model_doc/wav2vec2 model_doc/xlm model_doc/xlmprophetnet diff --git a/docs/source/model_doc/vit.rst b/docs/source/model_doc/vit.rst new file mode 100644 index 000000000000..831d4f484de7 --- /dev/null +++ b/docs/source/model_doc/vit.rst @@ -0,0 +1,102 @@ +.. + Copyright 2020 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +Vision Transformer (ViT) +----------------------------------------------------------------------------------------------------------------------- + +.. note:: + + This is a recently introduced model so the API hasn't been tested extensively. There may be some bugs or slight + breaking changes to fix it in the future. If you see something strange, file a `Github Issue + `__. + + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The Vision Transformer (ViT) model was proposed in `An Image is Worth 16x16 Words: Transformers for Image Recognition +at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk +Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob +Uszkoreit, Neil Houlsby. It's the first paper that successfully trains a Transformer encoder on ImageNet, attaining +very good results compared to familiar convolutional architectures. + + +The abstract from the paper is the following: + +*While the Transformer architecture has become the de-facto standard for natural language processing tasks, its +applications to computer vision remain limited. In vision, attention is either applied in conjunction with +convolutional networks, or used to replace certain components of convolutional networks while keeping their overall +structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to +sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of +data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), +Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring +substantially fewer computational resources to train.* + +Tips: + +- To feed images to the Transformer encoder, each image is split into a sequence of fixed-size non-overlapping patches, + which are then linearly embedded. A [CLS] token is added to serve as representation of an entire image, which can be + used for classification. The authors also add absolute position embeddings, and feed the resulting sequence of + vectors to a standard Transformer encoder. +- The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to + use a higher resolution than pre-training `(Touvron et al., 2019) `__, `(Kolesnikov + et al., 2020) `__. The authors report the best results with a resolution of 384x384 + during fine-tuning. +- As the Vision Transformer expects each image to be of the same size (resolution), one can use + :class:`~transformers.ViTFeatureExtractor` to resize (or rescale) and normalize images for the model. +- Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of + each checkpoint. For example, :obj:`google/vit-base-patch16-224` refers to a base-sized architecture with patch + resolution of 16x16 and fine-tuning resolution of 224x224. All checkpoints can be found on the `hub + `__. +- The available checkpoints are either (1) pre-trained on `ImageNet-21k `__ (a collection of + 14 million images and 21k classes) only, or (2) also fine-tuned on `ImageNet + `__ (also referred to as ILSVRC 2012, a collection of 1.3 million + images and 1,000 classes). +- The best results are obtained with supervised pre-training, which is not the case in NLP. The authors also performed + an experiment with a self-supervised pre-training objective, namely masked patched prediction (inspired by masked + language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant + improvement of 2% to training from scratch, but still 4% behind supervised pre-training. + + +The original code (written in JAX) can be found `here `__. + +Note that we converted the weights from Ross Wightman's `timm library +`__, who already converted the weights from JAX to PyTorch. Credits +go to him! + + +ViTConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ViTConfig + :members: + + +ViTFeatureExtractor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ViTFeatureExtractor + :members: __call__ + + +ViTModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ViTModel + :members: forward + + +ViTForImageClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.ViTForImageClassification + :members: forward diff --git a/setup.py b/setup.py index d25376fa7caa..cbf1bc4ecb3c 100644 --- a/setup.py +++ b/setup.py @@ -107,6 +107,7 @@ "onnxruntime>=1.4.0", "packaging", "parameterized", + "Pillow", "protobuf", "psutil", "pydantic", @@ -230,6 +231,7 @@ def run(self): extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") extras["speech"] = deps_list("soundfile", "torchaudio") +extras["vision"] = deps_list("Pillow") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["testing"] = ( @@ -242,7 +244,7 @@ def run(self): extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton") extras["quality"] = deps_list("black", "isort", "flake8") -extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"] +extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"] + extras["speech"] + extras["vision"] extras["dev"] = ( extras["all"] diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 39b65b70b795..f5954696e9ba 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -213,6 +213,7 @@ "TransfoXLCorpus", "TransfoXLTokenizer", ], + "models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], "models.wav2vec2": [ "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config", @@ -299,7 +300,7 @@ name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_") ] -# tokenziers-backed objects +# tokenizers-backed objects if is_tokenizers_available(): # Fast tokenizers _import_structure["models.convbert"].append("ConvBertTokenizerFast") @@ -348,6 +349,7 @@ # Vision-specific objects if is_vision_available(): _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] + _import_structure["models.vit"].append("ViTFeatureExtractor") else: from .utils import dummy_vision_objects @@ -426,6 +428,7 @@ _import_structure["models.auto"].extend( [ "MODEL_FOR_CAUSAL_LM_MAPPING", + "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -867,6 +870,14 @@ "load_tf_weights_in_transfo_xl", ] ) + _import_structure["models.vit"].extend( + [ + "VIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "ViTForImageClassification", + "ViTModel", + "ViTPreTrainedModel", + ] + ) _import_structure["models.wav2vec2"].extend( [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1311,7 +1322,6 @@ name for name in dir(dummy_flax_objects) if not name.startswith("_") ] - # Direct imports for type-checking if TYPE_CHECKING: # Configuration @@ -1479,6 +1489,7 @@ TransfoXLCorpus, TransfoXLTokenizer, ) + from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig from .models.wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config, @@ -1601,6 +1612,7 @@ if is_vision_available(): from .image_utils import ImageFeatureExtractionMixin + from .models.vit import ViTFeatureExtractor else: from .utils.dummy_vision_objects import * @@ -1666,6 +1678,7 @@ ) from .models.auto import ( MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -2025,6 +2038,12 @@ TransfoXLPreTrainedModel, load_tf_weights_in_transfo_xl, ) + from .models.vit import ( + VIT_PRETRAINED_MODEL_ARCHIVE_LIST, + ViTForImageClassification, + ViTModel, + ViTPreTrainedModel, + ) from .models.wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2ForCTC, @@ -2400,6 +2419,7 @@ # Import the same objects as dummies to get them in the namespace. # They will raise an import error if the user tries to instantiate / use them. from .utils.dummy_flax_objects import * + else: import importlib import os diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 1b89ed9d5c3a..fafecff49898 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -24,6 +24,7 @@ "onnxruntime": "onnxruntime>=1.4.0", "packaging": "packaging", "parameterized": "parameterized", + "Pillow": "Pillow", "protobuf": "protobuf", "psutil": "psutil", "pydantic": "pydantic", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 8e62eca94acb..24020ea8c7b6 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -175,10 +175,11 @@ except importlib_metadata.PackageNotFoundError: _soundfile_available = False -_torchaudio_available = importlib.util.find_spec("torchaudio") + +_torchaudio_available = importlib.util.find_spec("torchaudio") is not None try: _torchaudio_version = importlib_metadata.version("torchaudio") - logger.debug(f"Successfully imported soundfile version {_torchaudio_version}") + logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}") except importlib_metadata.PackageNotFoundError: _torchaudio_available = False diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 8f54303c957c..2fd5b4528d76 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -120,9 +120,9 @@ def normalize(self, image, mean, std): if isinstance(image, np.ndarray): if not isinstance(mean, np.ndarray): - mean = np.array(mean) + mean = np.array(mean).astype(image.dtype) if not isinstance(std, np.ndarray): - std = np.array(std) + std = np.array(std).astype(image.dtype) elif is_torch_tensor(image): import torch diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 776c336f3f37..efc6aedef391 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -67,6 +67,7 @@ t5, tapas, transfo_xl, + vit, wav2vec2, xlm, xlm_roberta, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 0fd4e9041f3d..0a47a6cb2b80 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -29,6 +29,7 @@ if is_torch_available(): _import_structure["modeling_auto"] = [ "MODEL_FOR_CAUSAL_LM_MAPPING", + "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -42,6 +43,7 @@ "MODEL_WITH_LM_HEAD_MAPPING", "AutoModel", "AutoModelForCausalLM", + "AutoModelForImageClassification", "AutoModelForMaskedLM", "AutoModelForMultipleChoice", "AutoModelForNextSentencePrediction", @@ -90,6 +92,7 @@ if is_torch_available(): from .modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -103,6 +106,7 @@ MODEL_WITH_LM_HEAD_MAPPING, AutoModel, AutoModelForCausalLM, + AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForNextSentencePrediction, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9636d7a5ef63..b32140c7c1c1 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -68,6 +68,7 @@ from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig +from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig from ..xlm_prophetnet.configuration_xlm_prophetnet import ( @@ -85,6 +86,7 @@ GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, + VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -134,6 +136,7 @@ ("gpt_neo", GPTNeoConfig), ("big_bird", BigBirdConfig), ("speech_to_text", Speech2TextConfig), + ("vit", ViTConfig), ("wav2vec2", Wav2Vec2Config), ("m2m_100", M2M100Config), ("convbert", ConvBertConfig), @@ -189,6 +192,7 @@ ("gpt_neo", "GPT Neo"), ("big_bird", "BigBird"), ("speech_to_text", "Speech2Text"), + ("vit", "ViT"), ("wav2vec2", "Wav2Vec2"), ("m2m_100", "M2M100"), ("convbert", "ConvBERT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 600c8ece2d9d..aecd7aa96715 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -237,6 +237,7 @@ TapasModel, ) from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel +from ..vit.modeling_vit import ViTForImageClassification, ViTModel from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model from ..xlm.modeling_xlm import ( XLMForMultipleChoice, @@ -313,6 +314,7 @@ T5Config, TapasConfig, TransfoXLConfig, + ViTConfig, Wav2Vec2Config, XLMConfig, XLMProphetNetConfig, @@ -331,6 +333,7 @@ (GPTNeoConfig, GPTNeoModel), (BigBirdConfig, BigBirdModel), (Speech2TextConfig, Speech2TextModel), + (ViTConfig, ViTModel), (Wav2Vec2Config, Wav2Vec2Model), (M2M100Config, M2M100Model), (ConvBertConfig, ConvBertModel), @@ -490,6 +493,13 @@ ] ) +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( + [ + # Model for Image Classification mapping + (ViTConfig, ViTForImageClassification), + ] +) + MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( [ # Model for Masked LM mapping @@ -1864,3 +1874,100 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys())}." ) + + +class AutoModelForImageClassification: + r""" + This is a generic model class that will be instantiated as one of the model classes of the library---with an image + classification head---when created with the :meth:`~transformers.AutoModelForImageClassification.from_pretrained` + class method or the :meth:`~transformers.AutoModelForImageClassification.from_config` class method. + + This class cannot be instantiated directly using ``__init__()`` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "AutoModelForImageClassification is designed to be instantiated " + "using the `AutoModelForImageClassification.from_pretrained(pretrained_model_name_or_path)` or " + "`AutoModelForImageClassification.from_config(config)` methods." + ) + + @classmethod + @replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, use_model_types=False) + def from_config(cls, config): + r""" + Instantiates one of the model classes of the library---with an image classification head---from a + configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. It only affects the + model's configuration. Use :meth:`~transformers.AutoModelForImageClassification.from_pretrained` to load + the model weights. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The model class to instantiate is selected based on the configuration class: + + List options + + Examples:: + + >>> from transformers import AutoConfig, AutoModelForImageClassification + >>> # Download configuration from huggingface.co and cache. + >>> config = AutoConfig.from_pretrained('google/vit_base_patch16_224') + >>> model = AutoModelForImageClassification.from_config(config) + """ + if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys(): + return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)](config) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, + cls.__name__, + ", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()), + ) + ) + + @classmethod + @replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) + @add_start_docstrings( + "Instantiate one of the model classes of the library---with an image classification head---from a " + "pretrained model.", + AUTO_MODEL_PRETRAINED_DOCSTRING, + ) + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Examples:: + + >>> from transformers import AutoConfig, AutoModelForImageClassification + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224') + + >>> # Update configuration during loading + >>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224', output_attentions=True) + >>> model.config.output_attentions + True + + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) + >>> config = AutoConfig.from_json_file('./tf_model/vit_tf_model_config.json') + >>> model = AutoModelForImageClassification.from_pretrained('./tf_model/vit_tf_checkpoint.ckpt.index', from_tf=True, config=config) + """ + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) + + if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys(): + return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)].from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) + raise ValueError( + "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" + "Model type should be one of {}.".format( + config.__class__, + cls.__name__, + ", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()), + ) + ) diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py new file mode 100644 index 000000000000..a8164e2bfe59 --- /dev/null +++ b/src/transformers/models/vit/__init__.py @@ -0,0 +1,70 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], +} + +if is_vision_available(): + _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"] + +if is_torch_available(): + _import_structure["modeling_vit"] = [ + "VIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "ViTForImageClassification", + "ViTModel", + "ViTPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig + + if is_vision_available(): + from .feature_extraction_vit import ViTFeatureExtractor + + if is_torch_available(): + from .modeling_vit import ( + VIT_PRETRAINED_MODEL_ARCHIVE_LIST, + ViTForImageClassification, + ViTModel, + ViTPreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/vit/configuration_vit.py b/src/transformers/models/vit/configuration_vit.py new file mode 100644 index 000000000000..5e53df4cddfd --- /dev/null +++ b/src/transformers/models/vit/configuration_vit.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ViT model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VIT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "nielsr/vit-base-patch16-224": "https://huggingface.co/vit-base-patch16-224/resolve/main/config.json", + # See all ViT models at https://huggingface.co/models?filter=vit +} + + +class ViTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.ViTModel`. It is used to + instantiate an ViT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the ViT `google/vit-base-patch16-224 + `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + image_size (:obj:`int`, `optional`, defaults to :obj:`224`): + The size (resolution) of each image. + patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): + The size (resolution) of each patch. + num_channels (:obj:`int`, `optional`, defaults to :obj:`3`): + The number of input channels. + + + Example:: + + >>> from transformers import ViTModel, ViTConfig + + >>> # Initializing a ViT vit-base-patch16-224 style configuration + >>> configuration = ViTConfig() + + >>> # Initializing a model from the vit-base-patch16-224 style configuration + >>> model = ViTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "vit" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + is_encoder_decoder=False, + image_size=224, + patch_size=16, + num_channels=3, + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py new file mode 100644 index 000000000000..06b5f1344684 --- /dev/null +++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ViT checkpoints from the timm library.""" + + +import argparse +from pathlib import Path + +import torch +from PIL import Image + +import requests +import timm +from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel +from transformers.utils import logging +from transformers.utils.imagenet_classes import id2label + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "vit.embeddings.cls_token"), + ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"), + ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"), + ("pos_embed", "vit.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ("pre_logits.fc.weight", "pooler.dense.weight"), + ("pre_logits.fc.bias", "pooler.dense.bias"), + ] + ) + + # if just the base model, we should remove "vit" from all keys that start with "vit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "vit.layernorm.weight"), + ("norm.bias", "vit.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "vit." + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our ViT structure. + """ + + # define default ViT configuration + config = ViTConfig() + base_model = False + # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size + if vit_name[-5:] == "in21k": + base_model = True + config.patch_size = int(vit_name[-12:-10]) + config.image_size = int(vit_name[-9:-6]) + else: + config.num_labels = 1000 + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + config.patch_size = int(vit_name[-6:-4]) + config.image_size = int(vit_name[-3:]) + # size of the architecture + if vit_name[4:].startswith("small"): + config.hidden_size = 768 + config.intermediate_size = 2304 + config.num_hidden_layers = 8 + config.num_attention_heads = 8 + if vit_name[4:].startswith("base"): + pass + elif vit_name[4:].startswith("large"): + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + elif vit_name[4:].startswith("huge"): + config.hidden_size = 1280 + config.intermediate_size = 5120 + config.num_hidden_layers = 32 + config.num_attention_heads = 16 + + # load original model from timm + timm_model = timm.create_model(vit_name, pretrained=True) + timm_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = timm_model.state_dict() + if base_model: + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config, base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + # load HuggingFace model + if vit_name[-5:] == "in21k": + model = ViTModel(config).eval() + else: + model = ViTForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by ViTFeatureExtractor + feature_extractor = ViTFeatureExtractor(size=config.image_size) + encoding = feature_extractor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + + if base_model: + timm_pooled_output = timm_model.forward_features(pixel_values) + assert timm_pooled_output.shape == outputs.pooler_output.shape + assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3) + else: + timm_logits = timm_model(pixel_values) + assert timm_logits.shape == outputs.logits.shape + assert torch.allclose(timm_logits, outputs.logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {vit_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving feature extractor to {pytorch_dump_folder_path}") + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--vit_name", + default="vit_base_patch16_224", + type=str, + help="Name of the ViT timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/vit/feature_extraction_vit.py b/src/transformers/models/vit/feature_extraction_vit.py new file mode 100644 index 000000000000..c4cf52ebb954 --- /dev/null +++ b/src/transformers/models/vit/feature_extraction_vit.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright Google AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for ViT.""" + +from typing import List, Optional, Union + +import numpy as np +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...file_utils import TensorType +from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a ViT feature extractor. + + This feature extractor inherits from :class:`~transformers.FeatureExtractionMixin` which contains most of the main + methods. Users should refer to this superclass for more information regarding those methods. + + Args: + image_mean (:obj:`int`, defaults to :obj:`[0.5, 0.5, 0.5]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (:obj:`int`, defaults to :obj:`[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to normalize the input with mean and standard deviation. + do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to resize the input to a certain :obj:`size`. + size (:obj:`int`, `optional`, defaults to 224): + Resize the input to the given size. Only has an effect if :obj:`do_resize` is set to :obj:`True`. + """ + + model_input_names = ["pixel_values"] + + def __init__(self, image_mean=None, image_std=None, do_normalize=True, do_resize=True, size=224, **kwargs): + super().__init__(**kwargs) + self.image_mean = [0.5, 0.5, 0.5] + self.image_std = [0.5, 0.5, 0.5] + self.do_normalize = do_normalize + self.do_resize = do_resize + self.size = size + + def __call__( + self, + images: Union[ + Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa + ], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + .. warning:: + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + Args: + images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.s + * :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects. + + Returns: + :class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + """ + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example)," + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." + ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (resizing + normalization) + if self.do_resize and self.size is not None: + images = [self.resize(image=image, size=self.size) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py new file mode 100644 index 000000000000..99bd60c463ed --- /dev/null +++ b/src/transformers/models/vit/modeling_vit.py @@ -0,0 +1,629 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Weightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch ViT model. """ + + +import collections.abc +import math + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import logging +from .configuration_vit import ViTConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ViTConfig" + +VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "nielsr/vit-base-patch16-224", + # See all ViT models at https://huggingface.co/models?filter=vit +] + + +# Inspired by +# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py +# From PyTorch internals +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + + +class ViTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. + + """ + + def __init__(self, config): + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.patch_embeddings = PatchEmbeddings( + image_size=config.image_size, + patch_size=config.patch_size, + num_channels=config.num_channels, + embed_dim=config.hidden_size, + ) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values): + batch_size = pixel_values.shape[0] + embeddings = self.patch_embeddings(pixel_values) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class PatchEmbeddings(nn.Module): + """ + Image to Patch Embedding. + + """ + + def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + super().__init__() + image_size = to_2tuple(image_size) + patch_size = to_2tuple(patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + # FIXME look at relaxing size constraints + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + x = self.projection(pixel_values).flatten(2).transpose(1, 2) + return x + + +class ViTSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class ViTSelfOutput(nn.Module): + """ + The residual connection is defined in VitLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class ViTAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = ViTSelfAttention(config) + self.output = ViTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ViTIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class ViTOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +class ViTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ViTAttention(config) + self.intermediate = ViTIntermediate(config) + self.output = ViTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + # TODO feedforward chunking not working for now + # layer_output = apply_chunking_to_forward( + # self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layer_output + # ) + + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output) + return layer_output + + +class ViTEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTConfig + base_model_prefix = "vit" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +VIT_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ subclass. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config (:class:`~transformers.ViTConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +VIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + :class:`~transformers.ViTFeatureExtractor`. See :meth:`transformers.ViTFeatureExtractor.__call__` for + details. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", + VIT_START_DOCSTRING, +) +class ViTModel(ViTPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = ViTEmbeddings(config) + self.encoder = ViTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = ViTPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples:: + + >>> from transformers import ViTFeatureExtractor, ViTModel + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') + >>> model = ViTModel.from_pretrained('google/vit-base-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ViTPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@add_start_docstrings( + """ + ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + VIT_START_DOCSTRING, +) +class ViTForImageClassification(ViTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.vit = ViTModel(config, add_pooling_layer=False) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + self.init_weights() + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples:: + + >>> from transformers import ViTFeatureExtractor, ViTForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') + >>> model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 139d229a879c..59649a3c02bd 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -302,6 +302,9 @@ def load_tf_weights_in_albert(*args, **kwargs): MODEL_FOR_CAUSAL_LM_MAPPING = None +MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None + + MODEL_FOR_MASKED_LM_MAPPING = None @@ -2512,6 +2515,32 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs): requires_pytorch(load_tf_weights_in_transfo_xl) +VIT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class ViTForImageClassification: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class ViTModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class ViTPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 7875ca953df0..d05d43f2046f 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -5,3 +5,8 @@ class ImageFeatureExtractionMixin: def __init__(self, *args, **kwargs): requires_vision(self) + + +class ViTFeatureExtractor: + def __init__(self, *args, **kwargs): + requires_vision(self) diff --git a/src/transformers/utils/imagenet_classes.py b/src/transformers/utils/imagenet_classes.py new file mode 100644 index 000000000000..73d831095c59 --- /dev/null +++ b/src/transformers/utils/imagenet_classes.py @@ -0,0 +1,1003 @@ +# ImageNet 2012 id's to class names +id2label = { + 0: "tench, Tinca tinca", + 1: "goldfish, Carassius auratus", + 2: "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", + 3: "tiger shark, Galeocerdo cuvieri", + 4: "hammerhead, hammerhead shark", + 5: "electric ray, crampfish, numbfish, torpedo", + 6: "stingray", + 7: "cock", + 8: "hen", + 9: "ostrich, Struthio camelus", + 10: "brambling, Fringilla montifringilla", + 11: "goldfinch, Carduelis carduelis", + 12: "house finch, linnet, Carpodacus mexicanus", + 13: "junco, snowbird", + 14: "indigo bunting, indigo finch, indigo bird, Passerina cyanea", + 15: "robin, American robin, Turdus migratorius", + 16: "bulbul", + 17: "jay", + 18: "magpie", + 19: "chickadee", + 20: "water ouzel, dipper", + 21: "kite", + 22: "bald eagle, American eagle, Haliaeetus leucocephalus", + 23: "vulture", + 24: "great grey owl, great gray owl, Strix nebulosa", + 25: "European fire salamander, Salamandra salamandra", + 26: "common newt, Triturus vulgaris", + 27: "eft", + 28: "spotted salamander, Ambystoma maculatum", + 29: "axolotl, mud puppy, Ambystoma mexicanum", + 30: "bullfrog, Rana catesbeiana", + 31: "tree frog, tree-frog", + 32: "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", + 33: "loggerhead, loggerhead turtle, Caretta caretta", + 34: "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", + 35: "mud turtle", + 36: "terrapin", + 37: "box turtle, box tortoise", + 38: "banded gecko", + 39: "common iguana, iguana, Iguana iguana", + 40: "American chameleon, anole, Anolis carolinensis", + 41: "whiptail, whiptail lizard", + 42: "agama", + 43: "frilled lizard, Chlamydosaurus kingi", + 44: "alligator lizard", + 45: "Gila monster, Heloderma suspectum", + 46: "green lizard, Lacerta viridis", + 47: "African chameleon, Chamaeleo chamaeleon", + 48: "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", + 49: "African crocodile, Nile crocodile, Crocodylus niloticus", + 50: "American alligator, Alligator mississipiensis", + 51: "triceratops", + 52: "thunder snake, worm snake, Carphophis amoenus", + 53: "ringneck snake, ring-necked snake, ring snake", + 54: "hognose snake, puff adder, sand viper", + 55: "green snake, grass snake", + 56: "king snake, kingsnake", + 57: "garter snake, grass snake", + 58: "water snake", + 59: "vine snake", + 60: "night snake, Hypsiglena torquata", + 61: "boa constrictor, Constrictor constrictor", + 62: "rock python, rock snake, Python sebae", + 63: "Indian cobra, Naja naja", + 64: "green mamba", + 65: "sea snake", + 66: "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", + 67: "diamondback, diamondback rattlesnake, Crotalus adamanteus", + 68: "sidewinder, horned rattlesnake, Crotalus cerastes", + 69: "trilobite", + 70: "harvestman, daddy longlegs, Phalangium opilio", + 71: "scorpion", + 72: "black and gold garden spider, Argiope aurantia", + 73: "barn spider, Araneus cavaticus", + 74: "garden spider, Aranea diademata", + 75: "black widow, Latrodectus mactans", + 76: "tarantula", + 77: "wolf spider, hunting spider", + 78: "tick", + 79: "centipede", + 80: "black grouse", + 81: "ptarmigan", + 82: "ruffed grouse, partridge, Bonasa umbellus", + 83: "prairie chicken, prairie grouse, prairie fowl", + 84: "peacock", + 85: "quail", + 86: "partridge", + 87: "African grey, African gray, Psittacus erithacus", + 88: "macaw", + 89: "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", + 90: "lorikeet", + 91: "coucal", + 92: "bee eater", + 93: "hornbill", + 94: "hummingbird", + 95: "jacamar", + 96: "toucan", + 97: "drake", + 98: "red-breasted merganser, Mergus serrator", + 99: "goose", + 100: "black swan, Cygnus atratus", + 101: "tusker", + 102: "echidna, spiny anteater, anteater", + 103: "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", + 104: "wallaby, brush kangaroo", + 105: "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", + 106: "wombat", + 107: "jellyfish", + 108: "sea anemone, anemone", + 109: "brain coral", + 110: "flatworm, platyhelminth", + 111: "nematode, nematode worm, roundworm", + 112: "conch", + 113: "snail", + 114: "slug", + 115: "sea slug, nudibranch", + 116: "chiton, coat-of-mail shell, sea cradle, polyplacophore", + 117: "chambered nautilus, pearly nautilus, nautilus", + 118: "Dungeness crab, Cancer magister", + 119: "rock crab, Cancer irroratus", + 120: "fiddler crab", + 121: "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", + 122: "American lobster, Northern lobster, Maine lobster, Homarus americanus", + 123: "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", + 124: "crayfish, crawfish, crawdad, crawdaddy", + 125: "hermit crab", + 126: "isopod", + 127: "white stork, Ciconia ciconia", + 128: "black stork, Ciconia nigra", + 129: "spoonbill", + 130: "flamingo", + 131: "little blue heron, Egretta caerulea", + 132: "American egret, great white heron, Egretta albus", + 133: "bittern", + 134: "crane", + 135: "limpkin, Aramus pictus", + 136: "European gallinule, Porphyrio porphyrio", + 137: "American coot, marsh hen, mud hen, water hen, Fulica americana", + 138: "bustard", + 139: "ruddy turnstone, Arenaria interpres", + 140: "red-backed sandpiper, dunlin, Erolia alpina", + 141: "redshank, Tringa totanus", + 142: "dowitcher", + 143: "oystercatcher, oyster catcher", + 144: "pelican", + 145: "king penguin, Aptenodytes patagonica", + 146: "albatross, mollymawk", + 147: "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", + 148: "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", + 149: "dugong, Dugong dugon", + 150: "sea lion", + 151: "Chihuahua", + 152: "Japanese spaniel", + 153: "Maltese dog, Maltese terrier, Maltese", + 154: "Pekinese, Pekingese, Peke", + 155: "Shih-Tzu", + 156: "Blenheim spaniel", + 157: "papillon", + 158: "toy terrier", + 159: "Rhodesian ridgeback", + 160: "Afghan hound, Afghan", + 161: "basset, basset hound", + 162: "beagle", + 163: "bloodhound, sleuthhound", + 164: "bluetick", + 165: "black-and-tan coonhound", + 166: "Walker hound, Walker foxhound", + 167: "English foxhound", + 168: "redbone", + 169: "borzoi, Russian wolfhound", + 170: "Irish wolfhound", + 171: "Italian greyhound", + 172: "whippet", + 173: "Ibizan hound, Ibizan Podenco", + 174: "Norwegian elkhound, elkhound", + 175: "otterhound, otter hound", + 176: "Saluki, gazelle hound", + 177: "Scottish deerhound, deerhound", + 178: "Weimaraner", + 179: "Staffordshire bullterrier, Staffordshire bull terrier", + 180: "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", + 181: "Bedlington terrier", + 182: "Border terrier", + 183: "Kerry blue terrier", + 184: "Irish terrier", + 185: "Norfolk terrier", + 186: "Norwich terrier", + 187: "Yorkshire terrier", + 188: "wire-haired fox terrier", + 189: "Lakeland terrier", + 190: "Sealyham terrier, Sealyham", + 191: "Airedale, Airedale terrier", + 192: "cairn, cairn terrier", + 193: "Australian terrier", + 194: "Dandie Dinmont, Dandie Dinmont terrier", + 195: "Boston bull, Boston terrier", + 196: "miniature schnauzer", + 197: "giant schnauzer", + 198: "standard schnauzer", + 199: "Scotch terrier, Scottish terrier, Scottie", + 200: "Tibetan terrier, chrysanthemum dog", + 201: "silky terrier, Sydney silky", + 202: "soft-coated wheaten terrier", + 203: "West Highland white terrier", + 204: "Lhasa, Lhasa apso", + 205: "flat-coated retriever", + 206: "curly-coated retriever", + 207: "golden retriever", + 208: "Labrador retriever", + 209: "Chesapeake Bay retriever", + 210: "German short-haired pointer", + 211: "vizsla, Hungarian pointer", + 212: "English setter", + 213: "Irish setter, red setter", + 214: "Gordon setter", + 215: "Brittany spaniel", + 216: "clumber, clumber spaniel", + 217: "English springer, English springer spaniel", + 218: "Welsh springer spaniel", + 219: "cocker spaniel, English cocker spaniel, cocker", + 220: "Sussex spaniel", + 221: "Irish water spaniel", + 222: "kuvasz", + 223: "schipperke", + 224: "groenendael", + 225: "malinois", + 226: "briard", + 227: "kelpie", + 228: "komondor", + 229: "Old English sheepdog, bobtail", + 230: "Shetland sheepdog, Shetland sheep dog, Shetland", + 231: "collie", + 232: "Border collie", + 233: "Bouvier des Flandres, Bouviers des Flandres", + 234: "Rottweiler", + 235: "German shepherd, German shepherd dog, German police dog, alsatian", + 236: "Doberman, Doberman pinscher", + 237: "miniature pinscher", + 238: "Greater Swiss Mountain dog", + 239: "Bernese mountain dog", + 240: "Appenzeller", + 241: "EntleBucher", + 242: "boxer", + 243: "bull mastiff", + 244: "Tibetan mastiff", + 245: "French bulldog", + 246: "Great Dane", + 247: "Saint Bernard, St Bernard", + 248: "Eskimo dog, husky", + 249: "malamute, malemute, Alaskan malamute", + 250: "Siberian husky", + 251: "dalmatian, coach dog, carriage dog", + 252: "affenpinscher, monkey pinscher, monkey dog", + 253: "basenji", + 254: "pug, pug-dog", + 255: "Leonberg", + 256: "Newfoundland, Newfoundland dog", + 257: "Great Pyrenees", + 258: "Samoyed, Samoyede", + 259: "Pomeranian", + 260: "chow, chow chow", + 261: "keeshond", + 262: "Brabancon griffon", + 263: "Pembroke, Pembroke Welsh corgi", + 264: "Cardigan, Cardigan Welsh corgi", + 265: "toy poodle", + 266: "miniature poodle", + 267: "standard poodle", + 268: "Mexican hairless", + 269: "timber wolf, grey wolf, gray wolf, Canis lupus", + 270: "white wolf, Arctic wolf, Canis lupus tundrarum", + 271: "red wolf, maned wolf, Canis rufus, Canis niger", + 272: "coyote, prairie wolf, brush wolf, Canis latrans", + 273: "dingo, warrigal, warragal, Canis dingo", + 274: "dhole, Cuon alpinus", + 275: "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", + 276: "hyena, hyaena", + 277: "red fox, Vulpes vulpes", + 278: "kit fox, Vulpes macrotis", + 279: "Arctic fox, white fox, Alopex lagopus", + 280: "grey fox, gray fox, Urocyon cinereoargenteus", + 281: "tabby, tabby cat", + 282: "tiger cat", + 283: "Persian cat", + 284: "Siamese cat, Siamese", + 285: "Egyptian cat", + 286: "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", + 287: "lynx, catamount", + 288: "leopard, Panthera pardus", + 289: "snow leopard, ounce, Panthera uncia", + 290: "jaguar, panther, Panthera onca, Felis onca", + 291: "lion, king of beasts, Panthera leo", + 292: "tiger, Panthera tigris", + 293: "cheetah, chetah, Acinonyx jubatus", + 294: "brown bear, bruin, Ursus arctos", + 295: "American black bear, black bear, Ursus americanus, Euarctos americanus", + 296: "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", + 297: "sloth bear, Melursus ursinus, Ursus ursinus", + 298: "mongoose", + 299: "meerkat, mierkat", + 300: "tiger beetle", + 301: "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", + 302: "ground beetle, carabid beetle", + 303: "long-horned beetle, longicorn, longicorn beetle", + 304: "leaf beetle, chrysomelid", + 305: "dung beetle", + 306: "rhinoceros beetle", + 307: "weevil", + 308: "fly", + 309: "bee", + 310: "ant, emmet, pismire", + 311: "grasshopper, hopper", + 312: "cricket", + 313: "walking stick, walkingstick, stick insect", + 314: "cockroach, roach", + 315: "mantis, mantid", + 316: "cicada, cicala", + 317: "leafhopper", + 318: "lacewing, lacewing fly", + 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", + 320: "damselfly", + 321: "admiral", + 322: "ringlet, ringlet butterfly", + 323: "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", + 324: "cabbage butterfly", + 325: "sulphur butterfly, sulfur butterfly", + 326: "lycaenid, lycaenid butterfly", + 327: "starfish, sea star", + 328: "sea urchin", + 329: "sea cucumber, holothurian", + 330: "wood rabbit, cottontail, cottontail rabbit", + 331: "hare", + 332: "Angora, Angora rabbit", + 333: "hamster", + 334: "porcupine, hedgehog", + 335: "fox squirrel, eastern fox squirrel, Sciurus niger", + 336: "marmot", + 337: "beaver", + 338: "guinea pig, Cavia cobaya", + 339: "sorrel", + 340: "zebra", + 341: "hog, pig, grunter, squealer, Sus scrofa", + 342: "wild boar, boar, Sus scrofa", + 343: "warthog", + 344: "hippopotamus, hippo, river horse, Hippopotamus amphibius", + 345: "ox", + 346: "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", + 347: "bison", + 348: "ram, tup", + 349: "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", + 350: "ibex, Capra ibex", + 351: "hartebeest", + 352: "impala, Aepyceros melampus", + 353: "gazelle", + 354: "Arabian camel, dromedary, Camelus dromedarius", + 355: "llama", + 356: "weasel", + 357: "mink", + 358: "polecat, fitch, foulmart, foumart, Mustela putorius", + 359: "black-footed ferret, ferret, Mustela nigripes", + 360: "otter", + 361: "skunk, polecat, wood pussy", + 362: "badger", + 363: "armadillo", + 364: "three-toed sloth, ai, Bradypus tridactylus", + 365: "orangutan, orang, orangutang, Pongo pygmaeus", + 366: "gorilla, Gorilla gorilla", + 367: "chimpanzee, chimp, Pan troglodytes", + 368: "gibbon, Hylobates lar", + 369: "siamang, Hylobates syndactylus, Symphalangus syndactylus", + 370: "guenon, guenon monkey", + 371: "patas, hussar monkey, Erythrocebus patas", + 372: "baboon", + 373: "macaque", + 374: "langur", + 375: "colobus, colobus monkey", + 376: "proboscis monkey, Nasalis larvatus", + 377: "marmoset", + 378: "capuchin, ringtail, Cebus capucinus", + 379: "howler monkey, howler", + 380: "titi, titi monkey", + 381: "spider monkey, Ateles geoffroyi", + 382: "squirrel monkey, Saimiri sciureus", + 383: "Madagascar cat, ring-tailed lemur, Lemur catta", + 384: "indri, indris, Indri indri, Indri brevicaudatus", + 385: "Indian elephant, Elephas maximus", + 386: "African elephant, Loxodonta africana", + 387: "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", + 388: "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", + 389: "barracouta, snoek", + 390: "eel", + 391: "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", + 392: "rock beauty, Holocanthus tricolor", + 393: "anemone fish", + 394: "sturgeon", + 395: "gar, garfish, garpike, billfish, Lepisosteus osseus", + 396: "lionfish", + 397: "puffer, pufferfish, blowfish, globefish", + 398: "abacus", + 399: "abaya", + 400: "academic gown, academic robe, judge's robe", + 401: "accordion, piano accordion, squeeze box", + 402: "acoustic guitar", + 403: "aircraft carrier, carrier, flattop, attack aircraft carrier", + 404: "airliner", + 405: "airship, dirigible", + 406: "altar", + 407: "ambulance", + 408: "amphibian, amphibious vehicle", + 409: "analog clock", + 410: "apiary, bee house", + 411: "apron", + 412: "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + 413: "assault rifle, assault gun", + 414: "backpack, back pack, knapsack, packsack, rucksack, haversack", + 415: "bakery, bakeshop, bakehouse", + 416: "balance beam, beam", + 417: "balloon", + 418: "ballpoint, ballpoint pen, ballpen, Biro", + 419: "Band Aid", + 420: "banjo", + 421: "bannister, banister, balustrade, balusters, handrail", + 422: "barbell", + 423: "barber chair", + 424: "barbershop", + 425: "barn", + 426: "barometer", + 427: "barrel, cask", + 428: "barrow, garden cart, lawn cart, wheelbarrow", + 429: "baseball", + 430: "basketball", + 431: "bassinet", + 432: "bassoon", + 433: "bathing cap, swimming cap", + 434: "bath towel", + 435: "bathtub, bathing tub, bath, tub", + 436: "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", + 437: "beacon, lighthouse, beacon light, pharos", + 438: "beaker", + 439: "bearskin, busby, shako", + 440: "beer bottle", + 441: "beer glass", + 442: "bell cote, bell cot", + 443: "bib", + 444: "bicycle-built-for-two, tandem bicycle, tandem", + 445: "bikini, two-piece", + 446: "binder, ring-binder", + 447: "binoculars, field glasses, opera glasses", + 448: "birdhouse", + 449: "boathouse", + 450: "bobsled, bobsleigh, bob", + 451: "bolo tie, bolo, bola tie, bola", + 452: "bonnet, poke bonnet", + 453: "bookcase", + 454: "bookshop, bookstore, bookstall", + 455: "bottlecap", + 456: "bow", + 457: "bow tie, bow-tie, bowtie", + 458: "brass, memorial tablet, plaque", + 459: "brassiere, bra, bandeau", + 460: "breakwater, groin, groyne, mole, bulwark, seawall, jetty", + 461: "breastplate, aegis, egis", + 462: "broom", + 463: "bucket, pail", + 464: "buckle", + 465: "bulletproof vest", + 466: "bullet train, bullet", + 467: "butcher shop, meat market", + 468: "cab, hack, taxi, taxicab", + 469: "caldron, cauldron", + 470: "candle, taper, wax light", + 471: "cannon", + 472: "canoe", + 473: "can opener, tin opener", + 474: "cardigan", + 475: "car mirror", + 476: "carousel, carrousel, merry-go-round, roundabout, whirligig", + 477: "carpenter's kit, tool kit", + 478: "carton", + 479: "car wheel", + 480: "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", + 481: "cassette", + 482: "cassette player", + 483: "castle", + 484: "catamaran", + 485: "CD player", + 486: "cello, violoncello", + 487: "cellular telephone, cellular phone, cellphone, cell, mobile phone", + 488: "chain", + 489: "chainlink fence", + 490: "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", + 491: "chain saw, chainsaw", + 492: "chest", + 493: "chiffonier, commode", + 494: "chime, bell, gong", + 495: "china cabinet, china closet", + 496: "Christmas stocking", + 497: "church, church building", + 498: "cinema, movie theater, movie theatre, movie house, picture palace", + 499: "cleaver, meat cleaver, chopper", + 500: "cliff dwelling", + 501: "cloak", + 502: "clog, geta, patten, sabot", + 503: "cocktail shaker", + 504: "coffee mug", + 505: "coffeepot", + 506: "coil, spiral, volute, whorl, helix", + 507: "combination lock", + 508: "computer keyboard, keypad", + 509: "confectionery, confectionary, candy store", + 510: "container ship, containership, container vessel", + 511: "convertible", + 512: "corkscrew, bottle screw", + 513: "cornet, horn, trumpet, trump", + 514: "cowboy boot", + 515: "cowboy hat, ten-gallon hat", + 516: "cradle", + 517: "crane", + 518: "crash helmet", + 519: "crate", + 520: "crib, cot", + 521: "Crock Pot", + 522: "croquet ball", + 523: "crutch", + 524: "cuirass", + 525: "dam, dike, dyke", + 526: "desk", + 527: "desktop computer", + 528: "dial telephone, dial phone", + 529: "diaper, nappy, napkin", + 530: "digital clock", + 531: "digital watch", + 532: "dining table, board", + 533: "dishrag, dishcloth", + 534: "dishwasher, dish washer, dishwashing machine", + 535: "disk brake, disc brake", + 536: "dock, dockage, docking facility", + 537: "dogsled, dog sled, dog sleigh", + 538: "dome", + 539: "doormat, welcome mat", + 540: "drilling platform, offshore rig", + 541: "drum, membranophone, tympan", + 542: "drumstick", + 543: "dumbbell", + 544: "Dutch oven", + 545: "electric fan, blower", + 546: "electric guitar", + 547: "electric locomotive", + 548: "entertainment center", + 549: "envelope", + 550: "espresso maker", + 551: "face powder", + 552: "feather boa, boa", + 553: "file, file cabinet, filing cabinet", + 554: "fireboat", + 555: "fire engine, fire truck", + 556: "fire screen, fireguard", + 557: "flagpole, flagstaff", + 558: "flute, transverse flute", + 559: "folding chair", + 560: "football helmet", + 561: "forklift", + 562: "fountain", + 563: "fountain pen", + 564: "four-poster", + 565: "freight car", + 566: "French horn, horn", + 567: "frying pan, frypan, skillet", + 568: "fur coat", + 569: "garbage truck, dustcart", + 570: "gasmask, respirator, gas helmet", + 571: "gas pump, gasoline pump, petrol pump, island dispenser", + 572: "goblet", + 573: "go-kart", + 574: "golf ball", + 575: "golfcart, golf cart", + 576: "gondola", + 577: "gong, tam-tam", + 578: "gown", + 579: "grand piano, grand", + 580: "greenhouse, nursery, glasshouse", + 581: "grille, radiator grille", + 582: "grocery store, grocery, food market, market", + 583: "guillotine", + 584: "hair slide", + 585: "hair spray", + 586: "half track", + 587: "hammer", + 588: "hamper", + 589: "hand blower, blow dryer, blow drier, hair dryer, hair drier", + 590: "hand-held computer, hand-held microcomputer", + 591: "handkerchief, hankie, hanky, hankey", + 592: "hard disc, hard disk, fixed disk", + 593: "harmonica, mouth organ, harp, mouth harp", + 594: "harp", + 595: "harvester, reaper", + 596: "hatchet", + 597: "holster", + 598: "home theater, home theatre", + 599: "honeycomb", + 600: "hook, claw", + 601: "hoopskirt, crinoline", + 602: "horizontal bar, high bar", + 603: "horse cart, horse-cart", + 604: "hourglass", + 605: "iPod", + 606: "iron, smoothing iron", + 607: "jack-o'-lantern", + 608: "jean, blue jean, denim", + 609: "jeep, landrover", + 610: "jersey, T-shirt, tee shirt", + 611: "jigsaw puzzle", + 612: "jinrikisha, ricksha, rickshaw", + 613: "joystick", + 614: "kimono", + 615: "knee pad", + 616: "knot", + 617: "lab coat, laboratory coat", + 618: "ladle", + 619: "lampshade, lamp shade", + 620: "laptop, laptop computer", + 621: "lawn mower, mower", + 622: "lens cap, lens cover", + 623: "letter opener, paper knife, paperknife", + 624: "library", + 625: "lifeboat", + 626: "lighter, light, igniter, ignitor", + 627: "limousine, limo", + 628: "liner, ocean liner", + 629: "lipstick, lip rouge", + 630: "Loafer", + 631: "lotion", + 632: "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + 633: "loupe, jeweler's loupe", + 634: "lumbermill, sawmill", + 635: "magnetic compass", + 636: "mailbag, postbag", + 637: "mailbox, letter box", + 638: "maillot", + 639: "maillot, tank suit", + 640: "manhole cover", + 641: "maraca", + 642: "marimba, xylophone", + 643: "mask", + 644: "matchstick", + 645: "maypole", + 646: "maze, labyrinth", + 647: "measuring cup", + 648: "medicine chest, medicine cabinet", + 649: "megalith, megalithic structure", + 650: "microphone, mike", + 651: "microwave, microwave oven", + 652: "military uniform", + 653: "milk can", + 654: "minibus", + 655: "miniskirt, mini", + 656: "minivan", + 657: "missile", + 658: "mitten", + 659: "mixing bowl", + 660: "mobile home, manufactured home", + 661: "Model T", + 662: "modem", + 663: "monastery", + 664: "monitor", + 665: "moped", + 666: "mortar", + 667: "mortarboard", + 668: "mosque", + 669: "mosquito net", + 670: "motor scooter, scooter", + 671: "mountain bike, all-terrain bike, off-roader", + 672: "mountain tent", + 673: "mouse, computer mouse", + 674: "mousetrap", + 675: "moving van", + 676: "muzzle", + 677: "nail", + 678: "neck brace", + 679: "necklace", + 680: "nipple", + 681: "notebook, notebook computer", + 682: "obelisk", + 683: "oboe, hautboy, hautbois", + 684: "ocarina, sweet potato", + 685: "odometer, hodometer, mileometer, milometer", + 686: "oil filter", + 687: "organ, pipe organ", + 688: "oscilloscope, scope, cathode-ray oscilloscope, CRO", + 689: "overskirt", + 690: "oxcart", + 691: "oxygen mask", + 692: "packet", + 693: "paddle, boat paddle", + 694: "paddlewheel, paddle wheel", + 695: "padlock", + 696: "paintbrush", + 697: "pajama, pyjama, pj's, jammies", + 698: "palace", + 699: "panpipe, pandean pipe, syrinx", + 700: "paper towel", + 701: "parachute, chute", + 702: "parallel bars, bars", + 703: "park bench", + 704: "parking meter", + 705: "passenger car, coach, carriage", + 706: "patio, terrace", + 707: "pay-phone, pay-station", + 708: "pedestal, plinth, footstall", + 709: "pencil box, pencil case", + 710: "pencil sharpener", + 711: "perfume, essence", + 712: "Petri dish", + 713: "photocopier", + 714: "pick, plectrum, plectron", + 715: "pickelhaube", + 716: "picket fence, paling", + 717: "pickup, pickup truck", + 718: "pier", + 719: "piggy bank, penny bank", + 720: "pill bottle", + 721: "pillow", + 722: "ping-pong ball", + 723: "pinwheel", + 724: "pirate, pirate ship", + 725: "pitcher, ewer", + 726: "plane, carpenter's plane, woodworking plane", + 727: "planetarium", + 728: "plastic bag", + 729: "plate rack", + 730: "plow, plough", + 731: "plunger, plumber's helper", + 732: "Polaroid camera, Polaroid Land camera", + 733: "pole", + 734: "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", + 735: "poncho", + 736: "pool table, billiard table, snooker table", + 737: "pop bottle, soda bottle", + 738: "pot, flowerpot", + 739: "potter's wheel", + 740: "power drill", + 741: "prayer rug, prayer mat", + 742: "printer", + 743: "prison, prison house", + 744: "projectile, missile", + 745: "projector", + 746: "puck, hockey puck", + 747: "punching bag, punch bag, punching ball, punchball", + 748: "purse", + 749: "quill, quill pen", + 750: "quilt, comforter, comfort, puff", + 751: "racer, race car, racing car", + 752: "racket, racquet", + 753: "radiator", + 754: "radio, wireless", + 755: "radio telescope, radio reflector", + 756: "rain barrel", + 757: "recreational vehicle, RV, R.V.", + 758: "reel", + 759: "reflex camera", + 760: "refrigerator, icebox", + 761: "remote control, remote", + 762: "restaurant, eating house, eating place, eatery", + 763: "revolver, six-gun, six-shooter", + 764: "rifle", + 765: "rocking chair, rocker", + 766: "rotisserie", + 767: "rubber eraser, rubber, pencil eraser", + 768: "rugby ball", + 769: "rule, ruler", + 770: "running shoe", + 771: "safe", + 772: "safety pin", + 773: "saltshaker, salt shaker", + 774: "sandal", + 775: "sarong", + 776: "sax, saxophone", + 777: "scabbard", + 778: "scale, weighing machine", + 779: "school bus", + 780: "schooner", + 781: "scoreboard", + 782: "screen, CRT screen", + 783: "screw", + 784: "screwdriver", + 785: "seat belt, seatbelt", + 786: "sewing machine", + 787: "shield, buckler", + 788: "shoe shop, shoe-shop, shoe store", + 789: "shoji", + 790: "shopping basket", + 791: "shopping cart", + 792: "shovel", + 793: "shower cap", + 794: "shower curtain", + 795: "ski", + 796: "ski mask", + 797: "sleeping bag", + 798: "slide rule, slipstick", + 799: "sliding door", + 800: "slot, one-armed bandit", + 801: "snorkel", + 802: "snowmobile", + 803: "snowplow, snowplough", + 804: "soap dispenser", + 805: "soccer ball", + 806: "sock", + 807: "solar dish, solar collector, solar furnace", + 808: "sombrero", + 809: "soup bowl", + 810: "space bar", + 811: "space heater", + 812: "space shuttle", + 813: "spatula", + 814: "speedboat", + 815: "spider web, spider's web", + 816: "spindle", + 817: "sports car, sport car", + 818: "spotlight, spot", + 819: "stage", + 820: "steam locomotive", + 821: "steel arch bridge", + 822: "steel drum", + 823: "stethoscope", + 824: "stole", + 825: "stone wall", + 826: "stopwatch, stop watch", + 827: "stove", + 828: "strainer", + 829: "streetcar, tram, tramcar, trolley, trolley car", + 830: "stretcher", + 831: "studio couch, day bed", + 832: "stupa, tope", + 833: "submarine, pigboat, sub, U-boat", + 834: "suit, suit of clothes", + 835: "sundial", + 836: "sunglass", + 837: "sunglasses, dark glasses, shades", + 838: "sunscreen, sunblock, sun blocker", + 839: "suspension bridge", + 840: "swab, swob, mop", + 841: "sweatshirt", + 842: "swimming trunks, bathing trunks", + 843: "swing", + 844: "switch, electric switch, electrical switch", + 845: "syringe", + 846: "table lamp", + 847: "tank, army tank, armored combat vehicle, armoured combat vehicle", + 848: "tape player", + 849: "teapot", + 850: "teddy, teddy bear", + 851: "television, television system", + 852: "tennis ball", + 853: "thatch, thatched roof", + 854: "theater curtain, theatre curtain", + 855: "thimble", + 856: "thresher, thrasher, threshing machine", + 857: "throne", + 858: "tile roof", + 859: "toaster", + 860: "tobacco shop, tobacconist shop, tobacconist", + 861: "toilet seat", + 862: "torch", + 863: "totem pole", + 864: "tow truck, tow car, wrecker", + 865: "toyshop", + 866: "tractor", + 867: "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", + 868: "tray", + 869: "trench coat", + 870: "tricycle, trike, velocipede", + 871: "trimaran", + 872: "tripod", + 873: "triumphal arch", + 874: "trolleybus, trolley coach, trackless trolley", + 875: "trombone", + 876: "tub, vat", + 877: "turnstile", + 878: "typewriter keyboard", + 879: "umbrella", + 880: "unicycle, monocycle", + 881: "upright, upright piano", + 882: "vacuum, vacuum cleaner", + 883: "vase", + 884: "vault", + 885: "velvet", + 886: "vending machine", + 887: "vestment", + 888: "viaduct", + 889: "violin, fiddle", + 890: "volleyball", + 891: "waffle iron", + 892: "wall clock", + 893: "wallet, billfold, notecase, pocketbook", + 894: "wardrobe, closet, press", + 895: "warplane, military plane", + 896: "washbasin, handbasin, washbowl, lavabo, wash-hand basin", + 897: "washer, automatic washer, washing machine", + 898: "water bottle", + 899: "water jug", + 900: "water tower", + 901: "whiskey jug", + 902: "whistle", + 903: "wig", + 904: "window screen", + 905: "window shade", + 906: "Windsor tie", + 907: "wine bottle", + 908: "wing", + 909: "wok", + 910: "wooden spoon", + 911: "wool, woolen, woollen", + 912: "worm fence, snake fence, snake-rail fence, Virginia fence", + 913: "wreck", + 914: "yawl", + 915: "yurt", + 916: "web site, website, internet site, site", + 917: "comic book", + 918: "crossword puzzle, crossword", + 919: "street sign", + 920: "traffic light, traffic signal, stoplight", + 921: "book jacket, dust cover, dust jacket, dust wrapper", + 922: "menu", + 923: "plate", + 924: "guacamole", + 925: "consomme", + 926: "hot pot, hotpot", + 927: "trifle", + 928: "ice cream, icecream", + 929: "ice lolly, lolly, lollipop, popsicle", + 930: "French loaf", + 931: "bagel, beigel", + 932: "pretzel", + 933: "cheeseburger", + 934: "hotdog, hot dog, red hot", + 935: "mashed potato", + 936: "head cabbage", + 937: "broccoli", + 938: "cauliflower", + 939: "zucchini, courgette", + 940: "spaghetti squash", + 941: "acorn squash", + 942: "butternut squash", + 943: "cucumber, cuke", + 944: "artichoke, globe artichoke", + 945: "bell pepper", + 946: "cardoon", + 947: "mushroom", + 948: "Granny Smith", + 949: "strawberry", + 950: "orange", + 951: "lemon", + 952: "fig", + 953: "pineapple, ananas", + 954: "banana", + 955: "jackfruit, jak, jack", + 956: "custard apple", + 957: "pomegranate", + 958: "hay", + 959: "carbonara", + 960: "chocolate sauce, chocolate syrup", + 961: "dough", + 962: "meat loaf, meatloaf", + 963: "pizza, pizza pie", + 964: "potpie", + 965: "burrito", + 966: "red wine", + 967: "espresso", + 968: "cup", + 969: "eggnog", + 970: "alp", + 971: "bubble", + 972: "cliff, drop, drop-off", + 973: "coral reef", + 974: "geyser", + 975: "lakeside, lakeshore", + 976: "promontory, headland, head, foreland", + 977: "sandbar, sand bar", + 978: "seashore, coast, seacoast, sea-coast", + 979: "valley, vale", + 980: "volcano", + 981: "ballplayer, baseball player", + 982: "groom, bridegroom", + 983: "scuba diver", + 984: "rapeseed", + 985: "daisy", + 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", + 987: "corn", + 988: "acorn", + 989: "hip, rose hip, rosehip", + 990: "buckeye, horse chestnut, conker", + 991: "coral fungus", + 992: "agaric", + 993: "gyromitra", + 994: "stinkhorn, carrion fungus", + 995: "earthstar", + 996: "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", + 997: "bolete", + 998: "ear, spike, capitulum", + 999: "toilet tissue, toilet paper, bathroom tissue", +} diff --git a/tests/test_feature_extraction_vit.py b/tests/test_feature_extraction_vit.py new file mode 100644 index 000000000000..d80b51841d0f --- /dev/null +++ b/tests/test_feature_extraction_vit.py @@ -0,0 +1,221 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.file_utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision + +from .test_feature_extraction_common import FeatureExtractionSavingTestMixin + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import ViTFeatureExtractor + + +class ViTFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + do_normalize=True, + do_resize=True, + size=18, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.image_mean = image_mean + self.image_std = image_std + self.do_normalize = do_normalize + self.do_resize = do_resize + self.size = size + + def prepare_feat_extract_dict(self): + return { + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_normalize": self.do_normalize, + "do_resize": self.do_resize, + "size": self.size, + } + + def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" + + if equal_resolution: + image_inputs = [] + for i in range(self.batch_size): + image_inputs.append( + np.random.randint( + 255, size=(self.num_channels, self.max_resolution, self.max_resolution), dtype=np.uint8 + ) + ) + else: + image_inputs = [] + for i in range(self.batch_size): + width, height = np.random.choice(np.arange(self.min_resolution, self.max_resolution), 2) + image_inputs.append(np.random.randint(255, size=(self.num_channels, width, height), dtype=np.uint8)) + + if not numpify and not torchify: + # PIL expects the channel dimension as last dimension + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + if torchify: + image_inputs = [torch.from_numpy(x) for x in image_inputs] + + return image_inputs + + +@require_torch +@require_vision +class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = ViTFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = ViTFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + def test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index 352ef48c6b5f..7f65c25f6d6a 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -264,7 +264,9 @@ def test_normalize_image(self): # During the conversion rescale and channel first will be applied. expected = array.transpose(2, 0, 1).astype(np.float32) / 255.0 - expected = (expected - np.array(mean)[:, None, None]) / np.array(std)[:, None, None] + np_mean = np.array(mean).astype(np.float32)[:, None, None] + np_std = np.array(std).astype(np.float32)[:, None, None] + expected = (expected - np_mean) / np_std self.assertTrue(np.array_equal(normalized_image, expected)) def test_normalize_array(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 402691dc989e..9ce171e64938 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -34,6 +34,7 @@ from transformers import ( BERT_PRETRAINED_MODEL_ARCHIVE_LIST, MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -99,6 +100,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): elif model_class in [ *MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(), *MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(), + *MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(), ]: inputs_dict["labels"] = torch.zeros( self.model_tester.batch_size, dtype=torch.long, device=torch_device diff --git a/tests/test_modeling_vit.py b/tests/test_modeling_vit.py new file mode 100644 index 000000000000..ec060c9da68e --- /dev/null +++ b/tests/test_modeling_vit.py @@ -0,0 +1,365 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch ViT model. """ + + +import inspect +import unittest + +from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ViTConfig, ViTForImageClassification, ViTModel + from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple + + +if is_vision_available(): + from PIL import Image + + from transformers import ViTFeatureExtractor + + +class ViTModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = ViTConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + return config, pixel_values, labels + + def create_and_check_model(self, config, pixel_values, labels): + model = ViTModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.image_size) + patch_size = to_2tuple(self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = ViTForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class ViTModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as ViT does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = ( + ( + ViTModel, + ViTForImageClassification, + ) + if is_torch_available() + else () + ) + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = ViTModelTester(self) + self.config_tester = ConfigTester(self, config_class=ViTConfig, hidden_size=37) + + def test_config(self): + config = self.config_tester.config_class(**self.config_tester.inputs_dict) + # we omit vocab_size since ViT does not use this + self.config_tester.parent.assertTrue(hasattr(config, "hidden_size")) + self.config_tester.parent.assertTrue(hasattr(config, "num_attention_heads")) + self.config_tester.parent.assertTrue(hasattr(config, "num_hidden_layers")) + + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + + def test_inputs_embeds(self): + # ViT does not use inputs_embeds + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, torch.nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # ViT has a different seq_length + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_length = num_patches + 1 + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in VIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = ViTModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/cats.png") + return image + + +@require_vision +class ViTModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None + + @slow + def test_inference_image_classification_head(self): + model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + # currently failing + # see https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-double-but-got-scalar-type-float-for-argument-2-weight/38961/2 + outputs = model(inputs["pixel_values"]) + # outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))