diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index 5945a150be0c..464730108624 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -189,3 +189,52 @@ FlaxAutoModel .. autoclass:: transformers.FlaxAutoModel :members: + + +FlaxAutoModelForPreTraining +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForPreTraining + :members: + + +FlaxAutoModelForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForMaskedLM + :members: + + +FlaxAutoModelForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForSequenceClassification + :members: + + +FlaxAutoModelForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForQuestionAnswering + :members: + + +FlaxAutoModelForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForTokenClassification + :members: + + +FlaxAutoModelForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForMultipleChoice + :members: + + +FlaxAutoModelForNextSentencePrediction +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction + :members: diff --git a/hubconf.py b/hubconf.py index c2fa2d18a983..c23d5ed8ed2f 100644 --- a/hubconf.py +++ b/hubconf.py @@ -22,9 +22,10 @@ from transformers import ( AutoConfig, AutoModel, + AutoModelForCausalLM, + AutoModelForMaskedLM, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, - AutoModelWithLMHead, AutoTokenizer, add_start_docstrings, ) @@ -86,22 +87,41 @@ def model(*args, **kwargs): return AutoModel.from_pretrained(*args, **kwargs) -@add_start_docstrings(AutoModelWithLMHead.__doc__) -def modelWithLMHead(*args, **kwargs): +@add_start_docstrings(AutoModelForCausalLM.__doc__) +def modelForCausalLM(*args, **kwargs): r""" # Using torch.hub ! import torch - model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache. - model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` - model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attentions=True) # Update configuration during loading + model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2') # Download model and configuration from huggingface.co and cache. + model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './test/saved_model/') # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` + model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2', output_attentions=True) # Update configuration during loading assert model.config.output_attentions == True # Loading from a TF checkpoint file instead of a PyTorch model (slower) - config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + config = AutoConfig.from_pretrained('./tf_model/gpt_tf_model_config.json') + model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './tf_model/gpt_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - return AutoModelWithLMHead.from_pretrained(*args, **kwargs) + return AutoModelForCausalLM.from_pretrained(*args, **kwargs) + + +@add_start_docstrings(AutoModelForMaskedLM.__doc__) +def modelForMaskedLM(*args, **kwargs): + r""" + # Using torch.hub ! + import torch + + model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache. + model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased', output_attentions=True) # Update configuration during loading + assert model.config.output_attentions == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') + model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + + return AutoModelForMaskedLM.from_pretrained(*args, **kwargs) @add_start_docstrings(AutoModelForSequenceClassification.__doc__) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f5954696e9ba..0cf332314d69 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1300,7 +1300,26 @@ # FLAX-backed objects if is_flax_available(): _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] - _import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"]) + _import_structure["models.auto"].extend( + [ + "FLAX_MODEL_FOR_MASKED_LM_MAPPING", + "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "FLAX_MODEL_FOR_PRETRAINING_MAPPING", + "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "FLAX_MODEL_MAPPING", + "FlaxAutoModel", + "FlaxAutoModelForMaskedLM", + "FlaxAutoModelForMultipleChoice", + "FlaxAutoModelForNextSentencePrediction", + "FlaxAutoModelForPreTraining", + "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForTokenClassification", + ] + ) _import_structure["models.bert"].extend( [ "FlaxBertForMaskedLM", @@ -2402,7 +2421,24 @@ if is_flax_available(): from .modeling_flax_utils import FlaxPreTrainedModel - from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel + from .models.auto import ( + FLAX_MODEL_FOR_MASKED_LM_MAPPING, + FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + FLAX_MODEL_FOR_PRETRAINING_MAPPING, + 
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + FLAX_MODEL_MAPPING, + FlaxAutoModel, + FlaxAutoModelForMaskedLM, + FlaxAutoModelForMultipleChoice, + FlaxAutoModelForNextSentencePrediction, + FlaxAutoModelForPreTraining, + FlaxAutoModelForQuestionAnswering, + FlaxAutoModelForSequenceClassification, + FlaxAutoModelForTokenClassification, + ) from .models.bert import ( FlaxBertForMaskedLM, FlaxBertForMultipleChoice, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 0a47a6cb2b80..8bf312231a75 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -82,7 +82,24 @@ ] if is_flax_available(): - _import_structure["modeling_flax_auto"] = ["FLAX_MODEL_MAPPING", "FlaxAutoModel"] + _import_structure["modeling_flax_auto"] = [ + "FLAX_MODEL_FOR_MASKED_LM_MAPPING", + "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "FLAX_MODEL_FOR_PRETRAINING_MAPPING", + "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "FLAX_MODEL_MAPPING", + "FlaxAutoModel", + "FlaxAutoModelForMaskedLM", + "FlaxAutoModelForMultipleChoice", + "FlaxAutoModelForNextSentencePrediction", + "FlaxAutoModelForPreTraining", + "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForTokenClassification", + ] if TYPE_CHECKING: @@ -145,7 +162,24 @@ ) if is_flax_available(): - from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel + from .modeling_flax_auto import ( + FLAX_MODEL_FOR_MASKED_LM_MAPPING, + FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + FLAX_MODEL_FOR_PRETRAINING_MAPPING, + FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + FLAX_MODEL_MAPPING, + FlaxAutoModel, + FlaxAutoModelForMaskedLM, + FlaxAutoModelForMultipleChoice, + FlaxAutoModelForNextSentencePrediction, + FlaxAutoModelForPreTraining, + FlaxAutoModelForQuestionAnswering, + FlaxAutoModelForSequenceClassification, + FlaxAutoModelForTokenClassification, + ) else: import importlib diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py new file mode 100644 index 000000000000..1c96f13199e8 --- /dev/null +++ b/src/transformers/models/auto/auto_factory.py @@ -0,0 +1,420 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
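The exports added above give Flax users the same task-specific entry points that already exist for PyTorch and TensorFlow. A minimal usage sketch (assuming ``flax`` is installed and the checkpoint ships Flax weights; ``bert-base-uncased`` is used as in the docstrings below):

    from transformers import AutoConfig, FlaxAutoModelForMaskedLM

    # Download Flax weights and configuration from huggingface.co and cache them.
    model = FlaxAutoModelForMaskedLM.from_pretrained("bert-base-uncased")

    # Or build a freshly initialized model from a configuration alone (no weights loaded).
    config = AutoConfig.from_pretrained("bert-base-uncased")
    model = FlaxAutoModelForMaskedLM.from_config(config)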
+"""Factory function to build auto-model classes.""" + +import functools +import types + +from ...configuration_utils import PretrainedConfig +from .configuration_auto import AutoConfig, replace_list_option_in_docstrings + + +CLASS_DOCSTRING = """ + This is a generic model class that will be instantiated as one of the model classes of the library when created + with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the + :meth:`~transformers.BaseAutoModelClass.from_config` class method. + + This class cannot be instantiated directly using ``__init__()`` (throws an error). +""" + +FROM_CONFIG_DOCSTRING = """ + Instantiates one of the model classes of the library from a configuration. + + Note: + Loading a model from its configuration file does **not** load the model weights. It only affects the + model's configuration. Use :meth:`~transformers.BaseAutoModelClass.from_pretrained` to load the model + weights. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The model class to instantiate is selected based on the configuration class: + + List options + + Examples:: + + >>> from transformers import AutoConfig, BaseAutoModelClass + >>> # Download configuration from huggingface.co and cache. + >>> config = AutoConfig.from_pretrained('checkpoint_placeholder') + >>> model = BaseAutoModelClass.from_config(config) +""" + +FROM_PRETRAINED_TORCH_DOCSTRING = """ + Instantiate one of the model classes of the library from a pretrained model. + + The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either + passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, + by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: + + List options + + The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are + deactivated). To train the model, you should first set it back in training mode with ``model.train()`` + + Args: + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In + this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided + as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in + a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + model_args (additional positional arguments, `optional`): + Will be passed along to the underlying model ``__init__()`` method. + config (:class:`~transformers.PretrainedConfig`, `optional`): + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the `model id` string of a pretrained + model). + - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded + by supplying the save directory. 
+ - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
+ configuration JSON file named `config.json` is found in the directory.
+ state_dict (`Dict[str, torch.Tensor]`, `optional`):
+ A state dictionary to use instead of a state dictionary loaded from the saved weights file.
+
+ This option can be used if you want to create a model from a pretrained configuration but load your own
+ weights. In this case though, you should check if using
+ :func:`~transformers.PreTrainedModel.save_pretrained` and
+ :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
+ cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Load the model weights from a TensorFlow checkpoint save file (see docstring of
+ ``pretrained_model_name_or_path`` argument).
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (:obj:`Dict[str, str]`, `optional`):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to only look at local files (e.g., not try downloading the model).
+ revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+ identifier allowed by git.
+ kwargs (additional keyword arguments, `optional`):
+ Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+ :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
+ underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
+ already been done).
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
+ initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
+ ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
+ with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
+ attribute will be passed to the underlying model's ``__init__`` function.
+
+ Examples::
+
+ >>> from transformers import AutoConfig, BaseAutoModelClass
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
+
+ >>> # Update configuration during loading
+ >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
+ >>> model.config.output_attentions
+ True
+
+ >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
+ >>> config = AutoConfig.from_pretrained('./tf_model/shortcut_placeholder_tf_model_config.json')
+ >>> model = BaseAutoModelClass.from_pretrained('./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index', from_tf=True, config=config)
+"""
+
+FROM_PRETRAINED_TF_DOCSTRING = """
+ Instantiate one of the model classes of the library from a pretrained model.
+
+ The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
+ passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
+ by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
+
+ List options
+
+ Args:
+ pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
+ Can be either:
+
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
+ - A path to a `directory` containing model weights saved using
+ :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+ - A path or url to a `PyTorch state_dict save file` (e.g., ``./pt_model/pytorch_model.bin``). In
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
+ as ``config`` argument. This loading path is slower than converting the PyTorch model into a
+ TensorFlow model using the provided conversion scripts and loading the TensorFlow model
+ afterwards.
+ model_args (additional positional arguments, `optional`):
+ Will be passed along to the underlying model ``__init__()`` method.
+ config (:class:`~transformers.PretrainedConfig`, `optional`):
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the `model id` string of a pretrained
+ model).
+ - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
+ by supplying the save directory.
+ - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
+ configuration JSON file named `config.json` is found in the directory.
+ cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
+ ``pretrained_model_name_or_path`` argument).
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (:obj:`Dict[str, str]`, `optional`):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to only look at local files (e.g., not try downloading the model).
+ revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+ identifier allowed by git.
+ kwargs (additional keyword arguments, `optional`):
+ Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+ :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
+ underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
+ already been done).
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
+ initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
+ ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
+ with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
+ attribute will be passed to the underlying model's ``__init__`` function.
+
+ Examples::
+
+ >>> from transformers import AutoConfig, BaseAutoModelClass
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
+
+ >>> # Update configuration during loading
+ >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
+ >>> model.config.output_attentions
+ True
+
+ >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
+ >>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
+ >>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
+"""
+
+FROM_PRETRAINED_FLAX_DOCSTRING = """
+ Instantiate one of the model classes of the library from a pretrained model.
+
+ The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either
+ passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing,
+ by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
+
+ List options
+
+ Args:
+ pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
+ Can be either:
+
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
+ - A path to a `directory` containing model weights saved using
+ :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+ - A path or url to a `PyTorch state_dict save file` (e.g., ``./pt_model/pytorch_model.bin``). In
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
+ as ``config`` argument. This loading path is slower than converting the PyTorch model into a
+ Flax model using the provided conversion scripts and loading the Flax model
+ afterwards.
+ model_args (additional positional arguments, `optional`):
+ Will be passed along to the underlying model ``__init__()`` method.
+ config (:class:`~transformers.PretrainedConfig`, `optional`):
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+ be automatically loaded when:
+
+ - The model is a model provided by the library (loaded with the `model id` string of a pretrained
+ model).
+ - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
+ by supplying the save directory.
+ - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
+ configuration JSON file named `config.json` is found in the directory.
+ cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
+ ``pretrained_model_name_or_path`` argument).
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (:obj:`Dict[str, str]`, `optional`):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to only look at local files (e.g., not try downloading the model).
+ revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+ identifier allowed by git.
+ kwargs (additional keyword arguments, `optional`):
+ Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+ :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
+ automatically loaded:
+
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
+ underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
+ already been done).
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
+ initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`).
Each key of
+ ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
+ with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
+ attribute will be passed to the underlying model's ``__init__`` function.
+
+ Examples::
+
+ >>> from transformers import AutoConfig, BaseAutoModelClass
+
+ >>> # Download model and configuration from huggingface.co and cache.
+ >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder')
+
+ >>> # Update configuration during loading
+ >>> model = BaseAutoModelClass.from_pretrained('checkpoint_placeholder', output_attentions=True)
+ >>> model.config.output_attentions
+ True
+
+ >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
+ >>> config = AutoConfig.from_pretrained('./pt_model/shortcut_placeholder_pt_model_config.json')
+ >>> model = BaseAutoModelClass.from_pretrained('./pt_model/shortcut_placeholder_pytorch_model.bin', from_pt=True, config=config)
+"""
+
+
+class _BaseAutoModelClass:
+ # Base class for auto models.
+ _model_mapping = None
+
+ def __init__(self):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+ f"`{self.__class__.__name__}.from_config(config)` methods."
+ )
+
+ def from_config(cls, config, **kwargs):
+ if type(config) in cls._model_mapping.keys():
+ return cls._model_mapping[type(config)](config, **kwargs)
+ raise ValueError(
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
+ )
+
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ config = kwargs.pop("config", None)
+ kwargs["_from_auto"] = True
+ if not isinstance(config, PretrainedConfig):
+ config, kwargs = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+ )
+
+ if type(config) in cls._model_mapping.keys():
+ return cls._model_mapping[type(config)].from_pretrained(
+ pretrained_model_name_or_path, *model_args, config=config, **kwargs
+ )
+ raise ValueError(
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
+ )
+
+
+def copy_func(f):
+ """Returns a copy of a function f."""
+ # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
+ g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
+ g = functools.update_wrapper(g, f)
+ g.__kwdefaults__ = f.__kwdefaults__
+ return g
+
+
+def insert_head_doc(docstring, head_doc=""):
+ if len(head_doc) > 0:
+ return docstring.replace(
+ "one of the model classes of the library ",
+ f"one of the model classes of the library (with a {head_doc} head) ",
+ )
+ return docstring.replace(
+ "one of the model classes of the library ", "one of the base model classes of the library "
+ )
+
+
+def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-cased", head_doc=""):
+ # Create a new class with the right name from the base class
+ new_class = types.new_class(name, (_BaseAutoModelClass,))
+ new_class._model_mapping = model_mapping
+ class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
+ new_class.__doc__ = class_docstring.replace("BaseAutoModelClass", name)
+
+ # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
+ # have specific docstrings for them.
+ from_config = copy_func(_BaseAutoModelClass.from_config)
+ from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
+ from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
+ from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
+ from_config.__doc__ = from_config_docstring
+ from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
+ new_class.from_config = classmethod(from_config)
+
+ if name.startswith("TF"):
+ from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
+ elif name.startswith("Flax"):
+ from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
+ else:
+ from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
+ from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
+ from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
+ from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
+ from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
+ shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
+ from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
+ from_pretrained.__doc__ = from_pretrained_docstring
+ from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
+ new_class.from_pretrained = classmethod(from_pretrained)
+ return new_class
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index b32140c7c1c1..b6bf0ad22395 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -256,8 +256,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
 if config in config_to_class
 }
 lines = [
- f"{indent}- **{model_type}** -- :class:`~transformers.{cls_name}` ({MODEL_NAMES_MAPPING[model_type]} model)"
- for model_type, cls_name in model_type_to_name.items()
+ f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
+ for model_type in
sorted(model_type_to_name.keys()) ] else: config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()} @@ -265,8 +265,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items() } lines = [ - f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{cls_name}` ({config_to_model_name[config_name]} model)" - for config_name, cls_name in config_to_name.items() + f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)" + for config_name in sorted(config_to_name.keys()) ] return "\n".join(lines) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index aecd7aa96715..ccebed05280a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -18,8 +18,6 @@ import warnings from collections import OrderedDict -from ...configuration_utils import PretrainedConfig -from ...file_utils import add_start_docstrings from ...utils import logging from ..albert.modeling_albert import ( AlbertForMaskedLM, @@ -269,9 +267,9 @@ XLNetLMHeadModel, XLNetModel, ) +from .auto_factory import auto_class_factory from .configuration_auto import ( AlbertConfig, - AutoConfig, BartConfig, BertConfig, BertGenerationConfig, @@ -320,7 +318,6 @@ XLMProphetNetConfig, XLMRobertaConfig, XLNetConfig, - replace_list_option_in_docstrings, ) @@ -684,1290 +681,84 @@ ] ) -AUTO_MODEL_PRETRAINED_DOCSTRING = r""" - The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either - passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, - by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: +AutoModel = auto_class_factory("AutoModel", MODEL_MAPPING) - List options - - The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are - deactivated). To train the model, you should first set it back in training mode with ``model.train()`` - - Args: - pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): - Can be either: - - - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under - a user or organization name, like ``dbmdz/bert-base-german-cased``. - - A path to a `directory` containing model weights saved using - :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. - - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In - this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided - as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in - a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. - model_args (additional positional arguments, `optional`): - Will be passed along to the underlying model ``__init__()`` method. - config (:class:`~transformers.PretrainedConfig`, `optional`): - Configuration for the model to use instead of an automatically loaded configuration. 
Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the `model id` string of a pretrained - model). - - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded - by supplying the save directory. - - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a - configuration JSON file named `config.json` is found in the directory. - state_dict (`Dict[str, torch.Tensor]`, `optional`): - A state dictionary to use instead of a state dictionary loaded from saved weights file. - - This option can be used if you want to create a model from a pretrained configuration but load your own - weights. In this case though, you should check if using - :func:`~transformers.PreTrainedModel.save_pretrained` and - :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. - cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): - Load the model weights from a TensorFlow checkpoint save file (see docstring of - ``pretrained_model_name_or_path`` argument). - force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (:obj:`Dict[str, str], `optional`): - A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to only look at local files (e.g., not try downloading the model). - revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any - identifier allowed by git. - kwargs (additional keyword arguments, `optional`): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or - automatically loaded: - - - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the - underlying model's ``__init__`` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class - initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of - ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute - with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration - attribute will be passed to the underlying model's ``__init__`` function. 
-""" - - -class AutoModel: - r""" - This is a generic model class that will be instantiated as one of the base model classes of the library when - created with the :meth:`~transformers.AutoModel.from_pretrained` class method or the - :meth:`~transformers.AutoModel.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModel is designed to be instantiated " - "using the `AutoModel.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModel.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the base model classes of the library from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModel.from_pretrained` to load the model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModel - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModel.from_config(config) - """ - if type(config) in MODEL_MAPPING.keys(): - return MODEL_MAPPING[type(config)](config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_MAPPING) - @add_start_docstrings( - "Instantiate one of the base model classes of the library from a pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - - Examples:: - - >>> from transformers import AutoConfig, AutoModel - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModel.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModel.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_MAPPING.keys(): - return MODEL_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}." 
- )
-
-
-class AutoModelForPreTraining:
- r"""
- This is a generic model class that will be instantiated as one of the model classes of the library---with the
- architecture used for pretraining this model---when created with the
- :meth:`~transformers.AutoModelForPreTraining.from_pretrained` class method or the
- :meth:`~transformers.AutoModelForPreTraining.from_config` class method.
-
- This class cannot be instantiated directly using ``__init__()`` (throws an error).
- """
-
- def __init__(self):
- raise EnvironmentError(
- "AutoModelForPreTraining is designed to be instantiated "
- "using the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or "
- "`AutoModelForPreTraining.from_config(config)` methods."
- )
-
- @classmethod
- @replace_list_option_in_docstrings(MODEL_FOR_PRETRAINING_MAPPING, use_model_types=False)
- def from_config(cls, config):
- r"""
- Instantiates one of the model classes of the library---with the architecture used for pretraining this
- model---from a configuration.
-
- Note:
- Loading a model from its configuration file does **not** load the model weights. It only affects the
- model's configuration. Use :meth:`~transformers.AutoModelForPreTraining.from_pretrained` to load the model
- weights.
-
- Args:
- config (:class:`~transformers.PretrainedConfig`):
- The model class to instantiate is selected based on the configuration class:
-
- List options
-
- Examples::
+AutoModelForPreTraining = auto_class_factory(
+ "AutoModelForPreTraining", MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
+)
- >>> from transformers import AutoConfig, AutoModelForPreTraining
- >>> # Download configuration from huggingface.co and cache.
- >>> config = AutoConfig.from_pretrained('bert-base-uncased')
- >>> model = AutoModelForPreTraining.from_config(config)
- """
- if type(config) in MODEL_FOR_PRETRAINING_MAPPING.keys():
- return MODEL_FOR_PRETRAINING_MAPPING[type(config)](config)
- raise ValueError(
- f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
- f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())}."
- )
+# Private on purpose; the public class will add the deprecation warnings.
+_AutoModelWithLMHead = auto_class_factory(
+ "AutoModelWithLMHead", MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling"
+)
- @classmethod
- @replace_list_option_in_docstrings(MODEL_FOR_PRETRAINING_MAPPING)
- @add_start_docstrings(
- "Instantiate one of the model classes of the library---with the architecture used for pretraining this ",
- "model---from a pretrained model.",
- AUTO_MODEL_PRETRAINED_DOCSTRING,
- )
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
- r"""
- Examples::
+AutoModelForCausalLM = auto_class_factory(
+ "AutoModelForCausalLM", MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
+)
- >>> from transformers import AutoConfig, AutoModelForPreTraining
+AutoModelForMaskedLM = auto_class_factory(
+ "AutoModelForMaskedLM", MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
+)
- >>> # Download model and configuration from huggingface.co and cache.
- >>> model = AutoModelForPreTraining.from_pretrained('bert-base-uncased') +AutoModelForSeq2SeqLM = auto_class_factory( + "AutoModelForSeq2SeqLM", + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + head_doc="sequence-to-sequence language modeling", + checkpoint_for_example="t5-base", +) - >>> # Update configuration during loading - >>> model = AutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True +AutoModelForSequenceClassification = auto_class_factory( + "AutoModelForSequenceClassification", MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, head_doc="sequence classification" +) - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) +AutoModelForQuestionAnswering = auto_class_factory( + "AutoModelForQuestionAnswering", MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering" +) - if type(config) in MODEL_FOR_PRETRAINING_MAPPING.keys(): - return MODEL_FOR_PRETRAINING_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_PRETRAINING_MAPPING.keys())}." - ) +AutoModelForTableQuestionAnswering = auto_class_factory( + "AutoModelForTableQuestionAnswering", + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + head_doc="table question answering", + checkpoint_for_example="google/tapas-base-finetuned-wtq", +) +AutoModelForTokenClassification = auto_class_factory( + "AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification" +) -class AutoModelWithLMHead: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a - language modeling head---when created with the :meth:`~transformers.AutoModelWithLMHead.from_pretrained` class - method or the :meth:`~transformers.AutoModelWithLMHead.from_config` class method. +AutoModelForMultipleChoice = auto_class_factory( + "AutoModelForMultipleChoice", MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice" +) - This class cannot be instantiated directly using ``__init__()`` (throws an error). +AutoModelForNextSentencePrediction = auto_class_factory( + "AutoModelForNextSentencePrediction", + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + head_doc="next sentence prediction", +) - .. warning:: +AutoModelForImageClassification = auto_class_factory( + "AutoModelForImageClassification", MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, head_doc="image classification" +) - This class is deprecated and will be removed in a future version. Please use - :class:`~transformers.AutoModelForCausalLM` for causal language models, - :class:`~transformers.AutoModelForMaskedLM` for masked language models and - :class:`~transformers.AutoModelForSeq2SeqLM` for encoder-decoder models. 
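For downstream code that still calls the deprecated class, the migration suggested by this warning is mechanical; a sketch, using checkpoints that appear elsewhere in this diff:

    from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForSeq2SeqLM

    # Instead of AutoModelWithLMHead.from_pretrained(...), pick the head explicitly:
    model = AutoModelForCausalLM.from_pretrained("gpt2")                # decoder-only generation
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")   # masked-token prediction
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")            # encoder-decoder generation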
- """ - - def __init__(self): - raise EnvironmentError( - "AutoModelWithLMHead is designed to be instantiated " - "using the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelWithLMHead.from_config(config)` methods." - ) +class AutoModelWithLMHead(_AutoModelWithLMHead): @classmethod - @replace_list_option_in_docstrings(MODEL_WITH_LM_HEAD_MAPPING, use_model_types=False) def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a language modeling head---from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelWithLMHead.from_pretrained` to load the model - weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelWithLMHead - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModelWithLMHead.from_config(config) - """ warnings.warn( "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " "`AutoModelForSeq2SeqLM` for encoder-decoder models.", FutureWarning, ) - if type(config) in MODEL_WITH_LM_HEAD_MAPPING.keys(): - return MODEL_WITH_LM_HEAD_MAPPING[type(config)](config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())}." - ) + return super().from_config(config) @classmethod - @replace_list_option_in_docstrings(MODEL_WITH_LM_HEAD_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a language modeling head---from a pretrained ", - "model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelWithLMHead - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelWithLMHead.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ warnings.warn( "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. 
Please use " "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " "`AutoModelForSeq2SeqLM` for encoder-decoder models.", FutureWarning, ) - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_WITH_LM_HEAD_MAPPING.keys(): - return MODEL_WITH_LM_HEAD_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())}." - ) - - -class AutoModelForCausalLM: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a causal - language modeling head---when created with the :meth:`~transformers.AutoModelForCausalLM.from_pretrained` class - method or the :meth:`~transformers.AutoModelForCausalLM.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForCausalLM is designed to be instantiated " - "using the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForCausalLM.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_CAUSAL_LM_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a causal language modeling head---from a - configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForCausalLM.from_pretrained` to load the model - weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForCausalLM - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('gpt2') - >>> model = AutoModelForCausalLM.from_config(config) - """ - if type(config) in MODEL_FOR_CAUSAL_LM_MAPPING.keys(): - return MODEL_FOR_CAUSAL_LM_MAPPING[type(config)](config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_CAUSAL_LM_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a causal language modeling head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForCausalLM - - >>> # Download model and configuration from huggingface.co and cache. 
- >>> model = AutoModelForCausalLM.from_pretrained('gpt2') - - >>> # Update configuration during loading - >>> model = AutoModelForCausalLM.from_pretrained('gpt2', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/gpt2_tf_model_config.json') - >>> model = AutoModelForCausalLM.from_pretrained('./tf_model/gpt2_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_CAUSAL_LM_MAPPING.keys(): - return MODEL_FOR_CAUSAL_LM_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_CAUSAL_LM_MAPPING.keys())}." - ) - - -class AutoModelForMaskedLM: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a masked - language modeling head---when created with the :meth:`~transformers.AutoModelForMaskedLM.from_pretrained` class - method or the :meth:`~transformers.AutoModelForMaskedLM.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForMaskedLM is designed to be instantiated " - "using the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForMaskedLM.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_MASKED_LM_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a masked language modeling head---from a - configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForMaskedLM.from_pretrained` to load the model - weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForMaskedLM - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModelForMaskedLM.from_config(config) - """ - if type(config) in MODEL_FOR_MASKED_LM_MAPPING.keys(): - return MODEL_FOR_MASKED_LM_MAPPING[type(config)](config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())}." 
- ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_MASKED_LM_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a masked language modeling head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForMaskedLM - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelForMaskedLM.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_MASKED_LM_MAPPING.keys(): - return MODEL_FOR_MASKED_LM_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_MASKED_LM_MAPPING.keys())}." - ) - - -class AutoModelForSeq2SeqLM: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a - sequence-to-sequence language modeling head---when created with the - :meth:`~transformers.AutoModelForSeq2SeqLM.from_pretrained` class method or the - :meth:`~transformers.AutoModelForSeq2SeqLM.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForSeq2SeqLM is designed to be instantiated " - "using the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForSeq2SeqLM.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a sequence-to-sequence language modeling - head---from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForSeq2SeqLM.from_pretrained` to load the model - weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForSeq2SeqLM - >>> # Download configuration from huggingface.co and cache. 
- >>> config = AutoConfig.from_pretrained('t5') - >>> model = AutoModelForSeq2SeqLM.from_config(config) - """ - if type(config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys(): - return MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a sequence-to-sequence language modeling " - "head---from a pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForSeq2SeqLM - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelForSeq2SeqLM.from_pretrained('t5-base') - - >>> # Update configuration during loading - >>> model = AutoModelForSeq2SeqLM.from_pretrained('t5-base', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/t5_tf_model_config.json') - >>> model = AutoModelForSeq2SeqLM.from_pretrained('./tf_model/t5_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys(): - return MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())}." - ) - - -class AutoModelForSequenceClassification: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a - sequence classification head---when created with the - :meth:`~transformers.AutoModelForSequenceClassification.from_pretrained` class method or the - :meth:`~transformers.AutoModelForSequenceClassification.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForSequenceClassification is designed to be instantiated " - "using the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForSequenceClassification.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a sequence classification head---from a - configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. 
Use :meth:`~transformers.AutoModelForSequenceClassification.from_pretrained` to load - the model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForSequenceClassification - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModelForSequenceClassification.from_config(config) - """ - if type(config) in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys(): - return MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)](config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a sequence classification head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForSequenceClassification - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys(): - return MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())}." - ) - - -class AutoModelForQuestionAnswering: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a - question answering head---when created with the :meth:`~transformers.AutoModeForQuestionAnswering.from_pretrained` - class method or the :meth:`~transformers.AutoModelForQuestionAnswering.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForQuestionAnswering is designed to be instantiated " - "using the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForQuestionAnswering.from_config(config)` methods." 
- ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_QUESTION_ANSWERING_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a question answering head---from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForQuestionAnswering.from_pretrained` to load the - model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForQuestionAnswering - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModelForQuestionAnswering.from_config(config) - """ - if type(config) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys(): - return MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)](config) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_QUESTION_ANSWERING_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a question answering head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForQuestionAnswering - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys(): - return MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())}." - ) - - -class AutoModelForTableQuestionAnswering: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a table - question answering head---when created with the - :meth:`~transformers.AutoModeForTableQuestionAnswering.from_pretrained` class method or the - :meth:`~transformers.AutoModelForTableQuestionAnswering.from_config` class method. 
- - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForQuestionAnswering is designed to be instantiated " - "using the `AutoModelForTableQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForTableQuestionAnswering.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a table question answering head---from a - configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForTableQuestionAnswering.from_pretrained` to load - the model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('google/tapas-base-finetuned-wtq') - >>> model = AutoModelForTableQuestionAnswering.from_config(config) - """ - if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys(): - return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)](config) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a table question answering head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering - - >>> # Download model and configuration from huggingface.co and cache. 
- >>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq') - - >>> # Update configuration during loading - >>> model = AutoModelForTableQuestionAnswering.from_pretrained('google/tapas-base-finetuned-wtq', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/tapas_tf_checkpoint.json') - >>> model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/tapas_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys(): - return MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.keys())}." - ) - - -class AutoModelForTokenClassification: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a token - classification head---when created with the :meth:`~transformers.AutoModelForTokenClassification.from_pretrained` - class method or the :meth:`~transformers.AutoModelForTokenClassification.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForTokenClassification is designed to be instantiated " - "using the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForTokenClassification.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a token classification head---from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForTokenClassification.from_pretrained` to load - the model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForTokenClassification - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModelForTokenClassification.from_config(config) - """ - if type(config) in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys(): - return MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)](config) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())}." 
- ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a token classification head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForTokenClassification - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys(): - return MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())}." - ) - - -class AutoModelForMultipleChoice: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a - multiple choice classification head---when created with the - :meth:`~transformers.AutoModelForMultipleChoice.from_pretrained` class method or the - :meth:`~transformers.AutoModelForMultipleChoice.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForMultipleChoice is designed to be instantiated " - "using the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForMultipleChoice.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_MULTIPLE_CHOICE_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a multiple choice classification head---from a - configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForMultipleChoice.from_pretrained` to load the - model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForMultipleChoice - >>> # Download configuration from huggingface.co and cache. 
- >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModelForMultipleChoice.from_config(config) - """ - if type(config) in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys(): - return MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)](config) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_MULTIPLE_CHOICE_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a multiple choice classification head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForMultipleChoice - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelForMultipleChoice.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModelForMultipleChoice.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelForMultipleChoice.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys(): - return MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys())}." - ) - - -class AutoModelForNextSentencePrediction: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a next - sentence prediction head---when created with the - :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` class method or the - :meth:`~transformers.AutoModelForNextSentencePrediction.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForNextSentencePrediction is designed to be instantiated " - "using the `AutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForNextSentencePrediction.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with a multiple choice classification head---from a - configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. 
Use :meth:`~transformers.AutoModelForNextSentencePrediction.from_pretrained` to load - the model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = AutoModelForNextSentencePrediction.from_config(config) - """ - if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys(): - return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with a multiple choice classification head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForNextSentencePrediction - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = AutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') - >>> model = AutoModelForNextSentencePrediction.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys(): - return MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys())}." - ) - - -class AutoModelForImageClassification: - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with an image - classification head---when created with the :meth:`~transformers.AutoModelForImageClassification.from_pretrained` - class method or the :meth:`~transformers.AutoModelForImageClassification.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). 
- """ - - def __init__(self): - raise EnvironmentError( - "AutoModelForImageClassification is designed to be instantiated " - "using the `AutoModelForImageClassification.from_pretrained(pretrained_model_name_or_path)` or " - "`AutoModelForImageClassification.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with an image classification head---from a - configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.AutoModelForImageClassification.from_pretrained` to load - the model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, AutoModelForImageClassification - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('google/vit_base_patch16_224') - >>> model = AutoModelForImageClassification.from_config(config) - """ - if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys(): - return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)](config) - raise ValueError( - "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" - "Model type should be one of {}.".format( - config.__class__, - cls.__name__, - ", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()), - ) - ) - - @classmethod - @replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with an image classification head---from a " - "pretrained model.", - AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, AutoModelForImageClassification - - >>> # Download model and configuration from huggingface.co and cache. 
- >>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224') - - >>> # Update configuration during loading - >>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_json_file('./tf_model/vit_tf_model_config.json') - >>> model = AutoModelForImageClassification.from_pretrained('./tf_model/vit_tf_checkpoint.ckpt.index', from_tf=True, config=config) - """ - config = kwargs.pop("config", None) - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys(): - return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - "Unrecognized configuration class {} for this kind of AutoModel: {}.\n" - "Model type should be one of {}.".format( - config.__class__, - cls.__name__, - ", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()), - ) - ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index f91cc496e6b6..042612d0a529 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -17,11 +17,20 @@ from collections import OrderedDict -from ...configuration_utils import PretrainedConfig from ...utils import logging -from ..bert.modeling_flax_bert import FlaxBertModel +from ..bert.modeling_flax_bert import ( + FlaxBertForMaskedLM, + FlaxBertForMultipleChoice, + FlaxBertForNextSentencePrediction, + FlaxBertForPreTraining, + FlaxBertForQuestionAnswering, + FlaxBertForSequenceClassification, + FlaxBertForTokenClassification, + FlaxBertModel, +) from ..roberta.modeling_flax_roberta import FlaxRobertaModel -from .configuration_auto import AutoConfig, BertConfig, RobertaConfig +from .auto_factory import auto_class_factory +from .configuration_auto import BertConfig, RobertaConfig logger = logging.get_logger(__name__) @@ -29,140 +38,90 @@ FLAX_MODEL_MAPPING = OrderedDict( [ + # Base model mapping (RobertaConfig, FlaxRobertaModel), (BertConfig, FlaxBertModel), ] ) +FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( + [ + # Model for pre-training mapping + (BertConfig, FlaxBertForPreTraining), + ] +) -class FlaxAutoModel(object): - r""" - :class:`~transformers.FlaxAutoModel` is a generic model class that will be instantiated as one of the base model - classes of the library when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or the - `FlaxAutoModel.from_config(config)` class methods. - - This class cannot be instantiated using `__init__()` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "FlaxAutoModel is designed to be instantiated " - "using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or " - "`FlaxAutoModel.from_config(config)` methods." 
- ) +FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( + [ + # Model for Masked LM mapping + (BertConfig, FlaxBertForMaskedLM), + ] +) + +FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( + [ + # Model for Sequence Classification mapping + (BertConfig, FlaxBertForSequenceClassification), + ] +) - @classmethod - def from_config(cls, config): - r""" - Instantiates one of the base model classes of the library from a configuration. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - - isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model) - - isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model - - Examples:: - - config = BertConfig.from_pretrained('bert-base-uncased') - # Download configuration from huggingface.co and cache. - model = FlaxAutoModel.from_config(config) - # E.g. model was saved using `save_pretrained('./test/saved_model/')` - """ - for config_class, model_class in FLAX_MODEL_MAPPING.items(): - if isinstance(config, config_class): - return model_class(config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} " - f"for this kind of FlaxAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}." - ) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Instantiates one of the base model classes of the library from a pre-trained model configuration. - - The `from_pretrained()` method takes care of returning the correct model class instance based on the - `model_type` property of the config object, or when it's missing, falling back to using pattern matching on the - `pretrained_model_name_or_path` string. - - The base model class to instantiate is selected as the first pattern matching in the - `pretrained_model_name_or_path` string (in the following order): - - - contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model) - - contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model) - - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) To - train the model, you should first set it back in training mode with `model.train()` - - Args: - pretrained_model_name_or_path: either: - - - a string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. Valid - model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or - organization name, like ``dbmdz/bert-base-german-cased``. - - a path to a `directory` containing model weights saved using - :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. - - a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this - case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` - argument. - - model_args: (`optional`) Sequence of positional arguments: - All remaining positional arguments will be passed to the underlying model's ``__init__`` method - - config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: - Configuration for the model to use instead of an automatically loaded configuration. 
Configuration can - be automatically loaded when: - - - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a - pretrained model), or - - the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded - by supplying the save directory. - - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a - configuration JSON file named `config.json` is found in the directory. - - cache_dir: (`optional`) string: - Path to a directory in which a downloaded pre-trained model configuration should be cached if the - standard cache should not be used. - - force_download: (`optional`) boolean, default False: - Force to (re-)download the model weights and configuration files and override the cached versions if - they exists. - - resume_download: (`optional`) boolean, default False: - Do not delete incompletely received file. Attempt to resume the download if such a file exists. - - proxies: (`optional`) dict, default None: - A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request. - - output_loading_info: (`optional`) boolean: - Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error - messages. - - kwargs: (`optional`) Remaining dictionary of keyword arguments: - These arguments will be passed to the configuration and the model. - - Examples:: - - model = FlaxAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from huggingface.co and cache. - model = FlaxAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` - assert model.config.output_attention == True - - """ - config = kwargs.pop("config", None) - if not isinstance(config, PretrainedConfig): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - for config_class, model_class in FLAX_MODEL_MAPPING.items(): - if isinstance(config, config_class): - return model_class.from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, _from_auto=True, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} " - f"for this kind of FlaxAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}" - ) +FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( + [ + # Model for Question Answering mapping + (BertConfig, FlaxBertForQuestionAnswering), + ] +) + +FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( + [ + # Model for Token Classification mapping + (BertConfig, FlaxBertForTokenClassification), + ] +) + +FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict( + [ + # Model for Multiple Choice mapping + (BertConfig, FlaxBertForMultipleChoice), + ] +) + +FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( + [ + (BertConfig, FlaxBertForNextSentencePrediction), + ] +) + +FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING) + +FlaxAutoModelForPreTraining = auto_class_factory( + "FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining" +) + +FlaxAutoModelForMaskedLM = auto_class_factory( + "FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling" +) + +FlaxAutoModelForSequenceClassification = auto_class_factory( + "AFlaxutoModelForSequenceClassification", + 
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + head_doc="sequence classification", +) + +FlaxAutoModelForQuestionAnswering = auto_class_factory( + "FlaxAutoModelForQuestionAnswering", FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering" +) + +FlaxAutoModelForTokenClassification = auto_class_factory( + "FlaxAutoModelForTokenClassification", FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification" +) + +FlaxAutoModelForMultipleChoice = auto_class_factory( + "FlaxAutoModelForMultipleChoice", FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice" +) + +FlaxAutoModelForNextSentencePrediction = auto_class_factory( + "FlaxAutoModelForNextSentencePrediction", + FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + head_doc="next sentence prediction", +) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 62df0925c72f..0abb08c8902c 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -18,8 +18,6 @@ import warnings from collections import OrderedDict -from ...configuration_utils import PretrainedConfig -from ...file_utils import add_start_docstrings from ...utils import logging # Add modeling imports here @@ -179,9 +177,9 @@ TFXLNetLMHeadModel, TFXLNetModel, ) +from .auto_factory import auto_class_factory from .configuration_auto import ( AlbertConfig, - AutoConfig, BartConfig, BertConfig, BlenderbotConfig, @@ -212,7 +210,6 @@ XLMConfig, XLMRobertaConfig, XLNetConfig, - replace_list_option_in_docstrings, ) @@ -465,1094 +462,74 @@ ) -TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r""" +TFAutoModel = auto_class_factory("TFAutoModel", TF_MODEL_MAPPING) - The model class to instantiate is selected based on the :obj:`model_type` property of the config object (either - passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's missing, - by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`: - - List options - - The model is set in evaluation mode by default using ``model.eval()`` (so for instance, dropout modules are - deactivated). To train the model, you should first set it back in training mode with ``model.train()`` - - Args: - pretrained_model_name_or_path: - Can be either: - - - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under - a user or organization name, like ``dbmdz/bert-base-german-cased``. - - A path to a `directory` containing model weights saved using - :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. - - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In - this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided - as ``config`` argument. This loading path is slower than converting the PyTorch model in a - TensorFlow model using the provided conversion scripts and loading the TensorFlow model - afterwards. - model_args (additional positional arguments, `optional`): - Will be passed along to the underlying model ``__init__()`` method. - config (:class:`~transformers.PretrainedConfig`, `optional`): - Configuration for the model to use instead of an automatically loaded configuration. Configuration can
Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the `model id` string of a pretrained - model). - - The model was saved using :meth:`~transformers.PreTrainedModel.save_pretrained` and is reloaded - by suppyling the save directory. - - The model is loaded by suppyling a local directory as ``pretrained_model_name_or_path`` and a - configuration JSON file named `config.json` is found in the directory. - state_dict (`Dict[str, torch.Tensor]`, `optional`): - A state dictionary to use instead of a state dictionary loaded from saved weights file. - - This option can be used if you want to create a model from a pretrained configuration but load your own - weights. In this case though, you should check if using - :func:`~transformers.PreTrainedModel.save_pretrained` and - :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. - cache_dir (:obj:`str`, `optional`): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): - Load the model weights from a TensorFlow checkpoint save file (see docstring of - ``pretrained_model_name_or_path`` argument). - force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (:obj:`Dict[str, str], `optional`): - A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to only look at local files (e.g., not try downloading the model). - revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any - identifier allowed by git. - kwargs (additional keyword arguments, `optional`): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or - automatically loaded: - - - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the - underlying model's ``__init__`` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class - initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of - ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute - with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration - attribute will be passed to the underlying model's ``__init__`` function. 
-""" - - -class TFAutoModel(object): - r""" - This is a generic model class that will be instantiated as one of the base model classes of the library when - created with the when created with the :meth:`~transformers.TFAutoModel.from_pretrained` class method or the - :meth:`~transformers.TFAutoModel.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "TFAutoModel is designed to be instantiated " - "using the `TFAutoModel.from_pretrained(pretrained_model_name_or_path)` or " - "`TFAutoModel.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(TF_MODEL_MAPPING, use_model_types=False) - def from_config(cls, config, **kwargs): - r""" - Instantiates one of the base model classes of the library from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.TFAutoModel.from_pretrained` to load the model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, TFAutoModel - >>> # Download configuration from huggingface.co and cache. - >>> config = TFAutoConfig.from_pretrained('bert-base-uncased') - >>> model = TFAutoModel.from_config(config) - """ - if type(config) in TF_MODEL_MAPPING.keys(): - return TF_MODEL_MAPPING[type(config)](config, **kwargs) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(TF_MODEL_MAPPING) - @add_start_docstrings( - "Instantiate one of the base model classes of the library from a pretrained model.", - TF_AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - - Examples:: - - >>> from transformers import AutoConfig, AutoModel - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = TFAutoModel.from_pretrained('bert-base-uncased') - - >>> # Update configuration during loading - >>> model = TFAutoModel.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) - >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json') - >>> model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) - - if type(config) in TF_MODEL_MAPPING.keys(): - return TF_MODEL_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_MAPPING.keys())}." 
- ) - - -class TFAutoModelForPreTraining(object): - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with the - architecture used for pretraining this model---when created with the - :meth:`~transformers.TFAutoModelForPreTraining.from_pretrained` class method or the - :meth:`~transformers.TFAutoModelForPreTraining.from_config` class method. - - This class cannot be instantiated directly using ``__init__()`` (throws an error). - """ - - def __init__(self): - raise EnvironmentError( - "TFAutoModelForPreTraining is designed to be instantiated " - "using the `TFAutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)` or " - "`TFAutoModelForPreTraining.from_config(config)` methods." - ) - - @classmethod - @replace_list_option_in_docstrings(TF_MODEL_FOR_PRETRAINING_MAPPING, use_model_types=False) - def from_config(cls, config): - r""" - Instantiates one of the model classes of the library---with the architecture used for pretraining this - model---from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use :meth:`~transformers.TFAutoModelForPreTraining.from_pretrained` to load the - model weights. - - Args: - config (:class:`~transformers.PretrainedConfig`): - The model class to instantiate is selected based on the configuration class: - - List options - - Examples:: - - >>> from transformers import AutoConfig, TFAutoModelForPreTraining - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained('bert-base-uncased') - >>> model = TFAutoModelForPreTraining.from_config(config) - """ - if type(config) in TF_MODEL_FOR_PRETRAINING_MAPPING.keys(): - return TF_MODEL_FOR_PRETRAINING_MAPPING[type(config)](config) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_PRETRAINING_MAPPING.keys())}." - ) - - @classmethod - @replace_list_option_in_docstrings(TF_MODEL_FOR_PRETRAINING_MAPPING) - @add_start_docstrings( - "Instantiate one of the model classes of the library---with the architecture used for pretraining this ", - "model---from a pretrained model.", - TF_AUTO_MODEL_PRETRAINED_DOCSTRING, - ) - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - r""" - Examples:: - - >>> from transformers import AutoConfig, TFAutoModelForPreTraining +TFAutoModelForPreTraining = auto_class_factory( + "TFAutoModelForPreTraining", TF_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining" +) - >>> # Download model and configuration from huggingface.co and cache. - >>> model = TFAutoModelForPreTraining.from_pretrained('bert-base-uncased') +# Private on purpose, the public class will add the deprecation warnings.
+_TFAutoModelWithLMHead = auto_class_factory( + "TFAutoModelWithLMHead", TF_MODEL_WITH_LM_HEAD_MAPPING, head_doc="language modeling" +) - >>> # Update configuration during loading - >>> model = TFAutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True) - >>> model.config.output_attentions - True +TFAutoModelForCausalLM = auto_class_factory( + "TFAutoModelForCausalLM", TF_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling" +) - >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) - >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json') - >>> model = TFAutoModelForPreTraining.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config) - """ - config = kwargs.pop("config", None) - kwargs["_from_auto"] = True - if not isinstance(config, PretrainedConfig): - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs - ) +TFAutoModelForMaskedLM = auto_class_factory( + "TFAutoModelForMaskedLM", TF_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling" +) - if type(config) in TF_MODEL_FOR_PRETRAINING_MAPPING.keys(): - return TF_MODEL_FOR_PRETRAINING_MAPPING[type(config)].from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_PRETRAINING_MAPPING.keys())}." - ) +TFAutoModelForSeq2SeqLM = auto_class_factory( + "TFAutoModelForSeq2SeqLM", + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + head_doc="sequence-to-sequence language modeling", + checkpoint_for_example="t5-base", +) +TFAutoModelForSequenceClassification = auto_class_factory( + "TFAutoModelForSequenceClassification", + TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + head_doc="sequence classification", +) -class TFAutoModelWithLMHead(object): - r""" - This is a generic model class that will be instantiated as one of the model classes of the library---with a - language modeling head---when created with the :meth:`~transformers.TFAutoModelWithLMHead.from_pretrained` class - method or the :meth:`~transformers.TFAutoModelWithLMHead.from_config` class method. +TFAutoModelForQuestionAnswering = auto_class_factory( + "TFAutoModelForQuestionAnswering", TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering" +) - This class cannot be instantiated directly using ``__init__()`` (throws an error). +TFAutoModelForTokenClassification = auto_class_factory( + "TFAutoModelForTokenClassification", TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification" +) - .. warning:: +TFAutoModelForMultipleChoice = auto_class_factory( + "TFAutoModelForMultipleChoice", TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice" +) - This class is deprecated and will be removed in a future version. Please use - :class:`~transformers.TFAutoModelForCausalLM` for causal language models, - :class:`~transformers.TFAutoModelForMaskedLM` for masked language models and - :class:`~transformers.TFAutoModelForSeq2SeqLM` for encoder-decoder models. 
-    """
+TFAutoModelForNextSentencePrediction = auto_class_factory(
+    "TFAutoModelForNextSentencePrediction",
+    TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
+    head_doc="next sentence prediction",
+)
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelWithLMHead is designed to be instantiated "
-            "using the `TFAutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelWithLMHead.from_config(config)` methods."
-        )
+class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
     @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_WITH_LM_HEAD_MAPPING, use_model_types=False)
     def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a language modeling head---from a configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelWithLMHead.from_pretrained` to load the model
-            weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelWithLMHead
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
-            >>> model = TFAutoModelWithLMHead.from_config(config)
-        """
         warnings.warn(
             "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
-            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models "
-            "and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
+            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
+            "`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
             FutureWarning,
         )
-        if type(config) in TF_MODEL_WITH_LM_HEAD_MAPPING.keys():
-            return TF_MODEL_WITH_LM_HEAD_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())}."
-        )
+        return super().from_config(config)

     @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_WITH_LM_HEAD_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a language modeling head---from a pretrained ",
-        "model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelWithLMHead
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json')
-            >>> model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
-        """
         warnings.warn(
             "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
-            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models "
-            "and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
+            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
+            "`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
             FutureWarning,
         )
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_WITH_LM_HEAD_MAPPING.keys():
-            return TF_MODEL_WITH_LM_HEAD_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForCausalLM:
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a causal
-    language modeling head---when created with the :meth:`~transformers.TFAutoModelForCausalLM.from_pretrained` class
-    method or the :meth:`~transformers.TFAutoModelForCausalLM.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForCausalLM is designed to be instantiated "
-            "using the `TFAutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForCausalLM.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_CAUSAL_LM_MAPPING, use_model_types=False)
-    def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a causal language modeling head---from a
-        configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForCausalLM.from_pretrained` to load the model
-            weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForCausalLM
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('gpt2')
-            >>> model = TFAutoModelForCausalLM.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys():
-            return TF_MODEL_FOR_CAUSAL_LM_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a causal language modeling head---from a "
-        "pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForCausalLM
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForCausalLM.from_pretrained('gpt2')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForCausalLM.from_pretrained('gpt2', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/gpt2_pt_model_config.json')
-            >>> model = TFAutoModelForCausalLM.from_pretrained('./pt_model/gpt2_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys():
-            return TF_MODEL_FOR_CAUSAL_LM_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForMaskedLM:
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a masked
-    language modeling head---when created with the :meth:`~transformers.TFAutoModelForMaskedLM.from_pretrained` class
-    method or the :meth:`~transformers.TFAutoModelForMaskedLM.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForMaskedLM is designed to be instantiated "
-            "using the `TFAutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForMaskedLM.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_MASKED_LM_MAPPING, use_model_types=False)
-    def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a masked language modeling head---from a
-        configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForMaskedLM.from_pretrained` to load the model
-            weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForMaskedLM
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
-            >>> model = TFAutoModelForMaskedLM.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_MASKED_LM_MAPPING.keys():
-            return TF_MODEL_FOR_MASKED_LM_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_MASKED_LM_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_MASKED_LM_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a masked language modeling head---from a "
-        "pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForMaskedLM
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForMaskedLM.from_pretrained('bert-base-uncased')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForMaskedLM.from_pretrained('bert-base-uncased', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json')
-            >>> model = TFAutoModelForMaskedLM.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_MASKED_LM_MAPPING.keys():
-            return TF_MODEL_FOR_MASKED_LM_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_MASKED_LM_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForSeq2SeqLM:
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a
-    sequence-to-sequence language modeling head---when created with the
-    :meth:`~transformers.TFAutoModelForSeq2SeqLM.from_pretrained` class method or the
-    :meth:`~transformers.TFAutoModelForSeq2SeqLM.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForSeq2SeqLM is designed to be instantiated "
-            "using the `TFAutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForSeq2SeqLM.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, use_model_types=False)
-    def from_config(cls, config, **kwargs):
-        r"""
-        Instantiates one of the model classes of the library---with a sequence-to-sequence language modeling
-        head---from a configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForSeq2SeqLM.from_pretrained` to load the model
-            weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('t5')
-            >>> model = TFAutoModelForSeq2SeqLM.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
-            return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config, **kwargs)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, use_model_types=False)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a sequence-to-sequence language modeling "
-        "head---from a pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForSeq2SeqLM.from_pretrained('t5-base')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForSeq2SeqLM.from_pretrained('t5-base', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/t5_pt_model_config.json')
-            >>> model = TFAutoModelForSeq2SeqLM.from_pretrained('./pt_model/t5_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
-            return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForSequenceClassification(object):
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a
-    sequence classification head---when created with the
-    :meth:`~transformers.TFAutoModelForSequenceClassification.from_pretrained` class method or the
-    :meth:`~transformers.TFAutoModelForSequenceClassification.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForSequenceClassification is designed to be instantiated "
-            "using the `TFAutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForSequenceClassification.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, use_model_types=False)
-    def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a sequence classification head---from a
-        configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForSequenceClassification.from_pretrained` to
-            load the model weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
-            >>> model = TFAutoModelForSequenceClassification.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys():
-            return TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a sequence classification head---from a "
-        "pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json')
-            >>> model = TFAutoModelForSequenceClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys():
-            return TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForQuestionAnswering(object):
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a
-    question answering head---when created with the
-    :meth:`~transformers.TFAutoModeForQuestionAnswering.from_pretrained` class method or the
-    :meth:`~transformers.TFAutoModelForQuestionAnswering.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForQuestionAnswering is designed to be instantiated "
-            "using the `TFAutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForQuestionAnswering.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, use_model_types=False)
-    def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a question answering head---from a configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForQuestionAnswering.from_pretrained` to load
-            the model weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForQuestionAnswering
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
-            >>> model = TFAutoModelForQuestionAnswering.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys():
-            return TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a question answering head---from a "
-        "pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForQuestionAnswering
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json')
-            >>> model = TFAutoModelForQuestionAnswering.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys():
-            return TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForTokenClassification:
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a token
-    classification head---when created with the :meth:`~transformers.TFAutoModelForTokenClassification.from_pretrained`
-    class method or the :meth:`~transformers.TFAutoModelForTokenClassification.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForTokenClassification is designed to be instantiated "
-            "using the `TFAutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForTokenClassification.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, use_model_types=False)
-    def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a token classification head---from a configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForTokenClassification.from_pretrained` to load
-            the model weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForTokenClassification
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
-            >>> model = TFAutoModelForTokenClassification.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys():
-            return TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a token classification head---from a "
-        "pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForTokenClassification
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForTokenClassification.from_pretrained('bert-base-uncased')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json')
-            >>> model = TFAutoModelForTokenClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys():
-            return TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForMultipleChoice:
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a
-    multiple choice classification head---when created with the
-    :meth:`~transformers.TFAutoModelForMultipleChoice.from_pretrained` class method or the
-    :meth:`~transformers.TFAutoModelForMultipleChoice.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForMultipleChoice is designed to be instantiated "
-            "using the `TFAutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForMultipleChoice.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, use_model_types=False)
-    def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a multiple choice classification head---from a
-        configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForMultipleChoice.from_pretrained` to load the
-            model weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
-            >>> model = TFAutoModelForMultipleChoice.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys():
-            return TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a multiple choice classification head---from a "
-        "pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForMultipleChoice.from_pretrained('bert-base-uncased')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForMultipleChoice.from_pretrained('bert-base-uncased', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json')
-            >>> model = TFAutoModelForMultipleChoice.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys():
-            return TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys())}."
-        )
-
-
-class TFAutoModelForNextSentencePrediction:
-    r"""
-    This is a generic model class that will be instantiated as one of the model classes of the library---with a next
-    sentence prediction head---when created with the
-    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` class method or the
-    :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_config` class method.
-
-    This class cannot be instantiated directly using ``__init__()`` (throws an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "TFAutoModelForNextSentencePrediction is designed to be instantiated "
-            "using the `TFAutoModelForNextSentencePrediction.from_pretrained(pretrained_model_name_or_path)` or "
-            "`TFAutoModelForNextSentencePrediction.from_config(config)` methods."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, use_model_types=False)
-    def from_config(cls, config):
-        r"""
-        Instantiates one of the model classes of the library---with a next sentence prediction head---from a
-        configuration.
-
-        Note:
-            Loading a model from its configuration file does **not** load the model weights. It only affects the
-            model's configuration. Use :meth:`~transformers.TFAutoModelForNextSentencePrediction.from_pretrained` to
-            load the model weights.
-
-        Args:
-            config (:class:`~transformers.PretrainedConfig`):
-                The model class to instantiate is selected based on the configuration class:
-
-                List options
-
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
-            >>> # Download configuration from huggingface.co and cache.
-            >>> config = AutoConfig.from_pretrained('bert-base-uncased')
-            >>> model = TFAutoModelForNextSentencePrediction.from_config(config)
-        """
-        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
-            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)](config)
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys())}."
-        )
-
-    @classmethod
-    @replace_list_option_in_docstrings(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING)
-    @add_start_docstrings(
-        "Instantiate one of the model classes of the library---with a next sentence prediction head---from a "
-        "pretrained model.",
-        TF_AUTO_MODEL_PRETRAINED_DOCSTRING,
-    )
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Examples::
-
-            >>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
-
-            >>> # Download model and configuration from huggingface.co and cache.
-            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased')
-
-            >>> # Update configuration during loading
-            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('bert-base-uncased', output_attentions=True)
-            >>> model.config.output_attentions
-            True
-
-            >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
-            >>> config = AutoConfig.from_pretrained('./pt_model/bert_pt_model_config.json')
-            >>> model = TFAutoModelForNextSentencePrediction.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
-        """
-        config = kwargs.pop("config", None)
-        kwargs["_from_auto"] = True
-        if not isinstance(config, PretrainedConfig):
-            config, kwargs = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
-            )
-
-        if type(config) in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys():
-            return TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
-        raise ValueError(
-            f"Unrecognized configuration class {config.__class__} for this kind of TFAutoModel: {cls.__name__}.\n"
-            f"Model type should be one of {', '.join(c.__name__ for c in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys())}."
-        )
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py
index deea31820fbc..8649d1c5e53f 100644
--- a/src/transformers/utils/dummy_flax_objects.py
+++ b/src/transformers/utils/dummy_flax_objects.py
@@ -11,6 +11,27 @@ def from_pretrained(self, *args, **kwargs):
         requires_flax(self)
 
 
+FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
+
+
+FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = None
+
+
+FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = None
+
+
+FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
+
+
+FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
+
+
+FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
+
+
+FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
+
+
 FLAX_MODEL_MAPPING = None
 
 
@@ -23,6 +44,69 @@ def from_pretrained(self, *args, **kwargs):
         requires_flax(self)
 
 
+class FlaxAutoModelForMaskedLM:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
+class FlaxAutoModelForMultipleChoice:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
+class FlaxAutoModelForNextSentencePrediction:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
+class FlaxAutoModelForPreTraining:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
+class FlaxAutoModelForQuestionAnswering:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
+class FlaxAutoModelForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
+class FlaxAutoModelForTokenClassification:
+    def __init__(self, *args, **kwargs):
+        requires_flax(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_flax(self)
+
+
 class FlaxBertForMaskedLM:
     def __init__(self, *args, **kwargs):
         requires_flax(self)
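
For context (an illustration appended after the patch, not part of it): every `TFAutoModelFor...` class above is now generated by `auto_class_factory`, so the public API is unchanged, while the deprecated `TFAutoModelWithLMHead` wrapper only emits a `FutureWarning` before delegating to the generated class. A minimal usage sketch, assuming TensorFlow and a transformers build that includes this change:

    from transformers import AutoConfig, TFAutoModelForCausalLM, TFAutoModelWithLMHead

    # The factory-generated class still dispatches on the config type.
    model = TFAutoModelForCausalLM.from_pretrained("gpt2")

    # from_config builds the architecture without loading any weights.
    config = AutoConfig.from_pretrained("gpt2")
    model = TFAutoModelForCausalLM.from_config(config)

    # The deprecated class keeps working, but raises a FutureWarning first
    # and then defers to the private factory-generated implementation.
    model = TFAutoModelWithLMHead.from_pretrained("gpt2")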