From eb764802d8accd950bb9bf4f0355f6ac1524024d Mon Sep 17 00:00:00 2001
From: Jonathan Bischof
Date: Wed, 4 Jan 2023 17:15:05 -0800
Subject: [PATCH] Create `Backbone` base class (#621)

* Create `Backbone` base class

* Format

* Throw exception if no presets.

* Address comments 1

* Address comments 2

* Extend to RoBERTa

* Format

* Attach `from_preset` to subclass and extend to all models

* format

* Fix typos in docstring

* More docstring typos
---
 keras_nlp/models/backbone.py                  | 89 +++++++++++++++++++
 keras_nlp/models/bert/bert_backbone.py        | 81 ++++------------
 .../models/deberta_v3/deberta_v3_backbone.py  | 82 ++++------------
 .../distil_bert/distil_bert_backbone.py       | 83 ++++------------
 keras_nlp/models/gpt2/gpt2_backbone.py        | 78 ++++------------
 keras_nlp/models/roberta/roberta_backbone.py  | 81 ++++------------
 .../xlm_roberta/xlm_roberta_backbone.py       | 81 +++++------------
 7 files changed, 196 insertions(+), 379 deletions(-)
 create mode 100644 keras_nlp/models/backbone.py

diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
new file mode 100644
index 0000000000..510235a949
--- /dev/null
+++ b/keras_nlp/models/backbone.py
@@ -0,0 +1,89 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base class for Backbone models."""
+
+import os
+
+from tensorflow import keras
+
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class Backbone(keras.Model):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
+
+    @classproperty
+    def presets(cls):
+        return {}
+
+    @classmethod
+    def from_preset(
+        cls,
+        preset,
+        load_weights=True,
+        **kwargs,
+    ):
+        """Instantiate {{model_name}} model from preset architecture and weights.
+
+        Args:
+            preset: string. Must be one of "{{preset_names}}".
+            load_weights: Whether to load pre-trained weights into model.
+                Defaults to `True`.
+
+        Examples:
+        ```python
+        # Load architecture and weights from preset
+        model = {{model_name}}.from_preset("{{example_preset_name}}")
+
+        # Load randomly initialized model from preset architecture
+        model = {{model_name}}.from_preset(
+            "{{example_preset_name}}",
+            load_weights=False
+        )
+        ```
+        """
+
+        if not cls.presets:
+            raise NotImplementedError(
+                "No presets have been created for this class."
+            )
+
+        if preset not in cls.presets:
+            raise ValueError(
+                "`preset` must be one of "
+                f"""{", ".join(cls.presets)}. Received: {preset}."""
+            )
+        metadata = cls.presets[preset]
+        config = metadata["config"]
+        model = cls.from_config({**config, **kwargs})
+
+        if not load_weights:
+            return model
+
+        weights = keras.utils.get_file(
+            "model.h5",
+            metadata["weights_url"],
+            cache_subdir=os.path.join("models", preset),
+            file_hash=metadata["weights_hash"],
+        )
+
+        model.load_weights(weights)
+        return model
diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py
index 2659b7cbdc..8339d0f893 100644
--- a/keras_nlp/models/bert/bert_backbone.py
+++ b/keras_nlp/models/bert/bert_backbone.py
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""BERT backbone models."""
+"""BERT backbone model."""
 
 import copy
-import os
 
 import tensorflow as tf
 from tensorflow import keras
 
 from keras_nlp.layers.position_embedding import PositionEmbedding
 from keras_nlp.layers.transformer_encoder import TransformerEncoder
+from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.bert.bert_presets import backbone_presets
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.python_utils import format_docstring
@@ -32,7 +32,7 @@ def bert_kernel_initializer(stddev=0.02):
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")
-class BertBackbone(keras.Model):
+class BertBackbone(Backbone):
     """BERT encoder network.
 
     This class implements a bi-directional Transformer-based encoder as
@@ -76,7 +76,11 @@ class BertBackbone(keras.Model):
         ),
     }
 
-    # Randomly initialized BERT encoder
+    # Pretrained BERT encoder
+    model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")
+    output = model(input_data)
+
+    # Randomly initialized BERT encoder with a custom config
     model = keras_nlp.models.BertBackbone(
         vocabulary_size=30552,
         num_layers=12,
@@ -212,71 +216,18 @@ def get_config(self):
             "trainable": self.trainable,
         }
 
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @classmethod
-    @format_docstring(names=", ".join(backbone_presets))
-    def from_preset(
-        cls,
-        preset,
-        load_weights=True,
-        **kwargs,
-    ):
-        """Instantiate BERT model from preset architecture and weights.
-
-        Args:
-            preset: string. Must be one of {{names}}.
-            load_weights: Whether to load pre-trained weights into model.
-                Defaults to `True`.
+    def from_preset(cls, preset, load_weights=True, **kwargs):
+        return super().from_preset(preset, load_weights, **kwargs)
 
-        Examples:
-        ```python
-        input_data = {
-            "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
-            "segment_ids": tf.constant(
-                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
-            ),
-            "padding_mask": tf.constant(
-                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
-            ),
-        }
-
-        # Load architecture and weights from preset
-        model = BertBackbone.from_preset("bert_base_en_uncased")
-        output = model(input_data)
-
-        # Load randomly initialized model from preset architecture
-        model = BertBackbone.from_preset(
-            "bert_base_en_uncased",
-            load_weights=False
-        )
-        output = model(input_data)
-        ```
-        """
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. Received: {preset}."""
-            )
-        metadata = cls.presets[preset]
-        config = metadata["config"]
-        model = cls.from_config({**config, **kwargs})
-
-        if not load_weights:
-            return model
-
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
-        model.load_weights(weights)
-        return model
+
+BertBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
+format_docstring(
+    model_name=BertBackbone.__name__,
+    example_preset_name="bert_base_en_uncased",
+    preset_names='", "'.join(BertBackbone.presets),
+)(BertBackbone.from_preset.__func__)
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
index 6438691c56..585a15a57b 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -15,11 +15,11 @@
 """DeBERTa backbone model."""
 
 import copy
-import os
 
 import tensorflow as tf
 from tensorflow import keras
 
+from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets
 from keras_nlp.models.deberta_v3.disentangled_attention_encoder import (
     DisentangledAttentionEncoder,
 )
@@ -34,7 +34,7 @@ def deberta_kernel_initializer(stddev=0.02):
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")
-class DebertaV3Backbone(keras.Model):
+class DebertaV3Backbone(Backbone):
     """DeBERTa encoder network.
 
     This network implements a bi-directional Transformer-based encoder as
@@ -45,7 +45,7 @@ class DebertaV3Backbone(keras.Model):
 
     The default constructor gives a fully customizable, randomly initialized
     DeBERTa encoder with any number of layers, heads, and embedding
-    dimensions. To load preset architectures and weights, use the `from_presets`
+    dimensions. To load preset architectures and weights, use the `from_preset`
     constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
@@ -76,7 +76,13 @@ class DebertaV3Backbone(keras.Model):
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)),
     }
 
-    # Randomly initialized DeBERTa model
+    # Pretrained DeBERTa encoder
+    model = keras_nlp.models.DebertaV3Backbone.from_preset(
+        "deberta_base_en",
+    )
+    output = model(input_data)
+
+    # Randomly initialized DeBERTa encoder with custom config
     model = keras_nlp.models.DebertaV3Backbone(
         vocabulary_size=128100,
         num_layers=12,
@@ -86,7 +92,6 @@ class DebertaV3Backbone(keras.Model):
         max_sequence_length=512,
         bucket_size=256,
     )
-    # Call the model on the input data.
     output = model(input_data)
     ```
 
@@ -194,69 +199,18 @@ def get_config(self):
             "trainable": self.trainable,
         }
 
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @classmethod
-    @format_docstring(names=", ".join(backbone_presets))
-    def from_preset(
-        cls,
-        preset,
-        load_weights=True,
-        **kwargs,
-    ):
-        """Instantiate DeBERTa model from preset architecture and weights.
-
-        Args:
-            preset: string. Must be one of {{names}}.
-            load_weights: Whether to load pre-trained weights into model.
-                Defaults to `True`.
+    def from_preset(cls, preset, load_weights=True, **kwargs):
+        return super().from_preset(preset, load_weights, **kwargs)
 
-        Examples:
-        ```python
-        input_data = {
-            "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
-            "padding_mask": tf.constant(
-                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
-            ),
-        }
-
-        # Load architecture and weights from preset
-        model = keras_nlp.models.DebertaV3Backbone.from_preset(
-            "deberta_base_en",
-        )
-        output = model(input_data)
-
-        # Load randomly initialized model from preset architecture
-        model = keras_nlp.models.DebertaV3Backbone.from_preset(
-            "deberta_base_en", load_weights=False
-        )
-        output = model(input_data)
-        ```
-        """
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. Received: {preset}."""
-            )
-        metadata = cls.presets[preset]
-        config = metadata["config"]
-        model = cls.from_config({**config, **kwargs})
-
-        if not load_weights:
-            return model
-
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
-        model.load_weights(weights)
-        return model
+
+DebertaV3Backbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
+format_docstring(
+    model_name=DebertaV3Backbone.__name__,
+    example_preset_name="deberta_base_en",
+    preset_names='", "'.join(DebertaV3Backbone.presets),
+)(DebertaV3Backbone.from_preset.__func__)
diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py
index 3c32589d3c..320010fd6a 100644
--- a/keras_nlp/models/distil_bert/distil_bert_backbone.py
+++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""DistilBERT backbone models."""
+"""DistilBERT backbone model."""
 
 import copy
-import os
 
 import tensorflow as tf
 from tensorflow import keras
@@ -24,6 +23,7 @@
     TokenAndPositionEmbedding,
 )
 from keras_nlp.layers.transformer_encoder import TransformerEncoder
+from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.python_utils import format_docstring
@@ -34,7 +34,7 @@ def distilbert_kernel_initializer(stddev=0.02):
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")
-class DistilBertBackbone(keras.Model):
+class DistilBertBackbone(Backbone):
     """DistilBERT encoder network.
 
     This network implements a bi-directional Transformer-based encoder as
@@ -45,7 +45,7 @@ class DistilBertBackbone(keras.Model):
 
     The default constructor gives a fully customizable, randomly initialized
     DistilBERT encoder with any number of layers, heads, and embedding
-    dimensions. To load preset architectures and weights, use the `from_presets`
+    dimensions. To load preset architectures and weights, use the `from_preset`
     constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
@@ -76,7 +76,13 @@ class DistilBertBackbone(keras.Model):
         ),
     }
 
-    # Randomly initialized DistilBERT encoder
+    # Pretrained DistilBERT encoder
+    model = keras_nlp.models.DistilBertBackbone.from_preset(
+        "distil_bert_base_en_uncased"
+    )
+    output = model(input_data)
+
+    # Randomly initialized DistilBERT encoder with custom config
     model = keras_nlp.models.DistilBertBackbone(
         vocabulary_size=30552,
         num_layers=6,
@@ -173,69 +179,18 @@ def get_config(self):
             "trainable": self.trainable,
         }
 
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @classmethod
-    @format_docstring(names=", ".join(backbone_presets))
-    def from_preset(
-        cls,
-        preset,
-        load_weights=True,
-        **kwargs,
-    ):
-        """Instantiate a DistilBERT model from preset architecture and weights.
-
-        Args:
-            preset: string. Must be one of {{names}}.
-            load_weights: Whether to load pre-trained weights into model.
-                Defaults to `True`.
-
-        Examples:
-        ```python
-        input_data = {
-            "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
-            "padding_mask": tf.constant(
-                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
-            ),
-        }
-
-        # Load architecture and weights from preset
-        model = keras_nlp.models.DistilBertBackbone.from_preset(
-            "distil_bert_base_en_uncased"
-        )
-        output = model(input_data)
+    def from_preset(cls, preset, load_weights=True, **kwargs):
+        return super().from_preset(preset, load_weights, **kwargs)
 
-        # Load randomly initialized model from preset architecture
-        model = keras_nlp.models.DistilBertBackbone.from_preset(
-            "distil_bert_base_en_uncased", load_weights=False
-        )
-        output = model(input_data)
-        ```
-        """
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. Received: {preset}."""
-            )
-        metadata = cls.presets[preset]
-        config = metadata["config"]
-        model = cls.from_config({**config, **kwargs})
-
-        if not load_weights:
-            return model
-
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
-        model.load_weights(weights)
-        return model
+
+DistilBertBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
+format_docstring(
+    model_name=DistilBertBackbone.__name__,
+    example_preset_name="distil_bert_base_en_uncased",
+    preset_names='", "'.join(DistilBertBackbone.presets),
+)(DistilBertBackbone.from_preset.__func__)
diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py
index 8738364319..36ec03c538 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone.py
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""GPT-2 backbone models.""" +"""GPT-2 backbone model.""" import copy -import os import tensorflow as tf from tensorflow import keras from keras_nlp.layers import PositionEmbedding from keras_nlp.layers import TransformerDecoder +from keras_nlp.models.backbone import Backbone from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.python_utils import format_docstring @@ -32,7 +32,7 @@ def _gpt_2_kernel_initializer(stddev=0.02): @keras.utils.register_keras_serializable(package="keras_nlp") -class GPT2Backbone(keras.Model): +class GPT2Backbone(Backbone): """GPT-2 core network with hyperparameters. This network implements a Transformer-based decoder network, @@ -42,7 +42,7 @@ class GPT2Backbone(keras.Model): The default constructor gives a fully customizable, randomly initialized GPT-2 model with any number of layers, heads, and embedding - dimensions. To load preset architectures and weights, use the `from_presets` + dimensions. To load preset architectures and weights, use the `from_preset` constructor. Disclaimer: Pre-trained models are provided on an "as is" basis, without @@ -73,7 +73,11 @@ class GPT2Backbone(keras.Model): ), } - # Randomly initialized GPT-2 decoder + # Pretrained GPT-2 decoder + model = GPT2Backbone.from_preset("gpt2_base_en") + output = model(input_data) + + # Randomly initialized GPT-2 decoder with custom config model = keras_nlp.models.GPT2Backbone( vocabulary_size=50257, num_layers=12, @@ -182,66 +186,18 @@ def get_config(self): "trainable": self.trainable, } - @classmethod - def from_config(cls, config): - return cls(**config) - @classproperty def presets(cls): return copy.deepcopy(backbone_presets) @classmethod - @format_docstring(names=", ".join(backbone_presets)) - def from_preset( - cls, - preset, - load_weights=True, - **kwargs, - ): - """Instantiate GPT-2 model from preset architecture and weights. - - Args: - preset: string. Must be one of {{names}}. - load_weights: Whether to load pre-trained weights into model. - Defaults to `True`. - - Examples: - ```python - input_data = { - "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64), - "padding_mask": tf.constant( - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) - ), - } + def from_preset(cls, preset, load_weights=True, **kwargs): + return super().from_preset(preset, load_weights, **kwargs) - # Load architecture and weights from preset - model = GPT2Backbone.from_preset("gpt2_base_en") - output = model(input_data) - - # Load randomly initialized model from preset architecture - model = GPT2Backbone.from_preset("gpt2_base", load_weights=False) - output = model(input_data) - ``` - """ - - if preset not in cls.presets: - raise ValueError( - "`preset` must be one of " - f"""{", ".join(cls.presets)}. 
-            )
-        metadata = cls.presets[preset]
-        config = metadata["config"]
-        model = cls.from_config({**config, **kwargs})
-
-        if not load_weights:
-            return model
-
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
-        model.load_weights(weights)
-        return model
+
+GPT2Backbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
+format_docstring(
+    model_name=GPT2Backbone.__name__,
+    example_preset_name="gpt2_base_en",
+    preset_names='", "'.join(GPT2Backbone.presets),
+)(GPT2Backbone.from_preset.__func__)
diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py
index 99a322d0d2..28a823c9a4 100644
--- a/keras_nlp/models/roberta/roberta_backbone.py
+++ b/keras_nlp/models/roberta/roberta_backbone.py
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""RoBERTa backbone models."""
+"""RoBERTa backbone model."""
 
 import copy
-import os
 
 import tensorflow as tf
 from tensorflow import keras
 
 from keras_nlp.layers import TokenAndPositionEmbedding
 from keras_nlp.layers import TransformerEncoder
+from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.roberta.roberta_presets import backbone_presets
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.python_utils import format_docstring
@@ -32,7 +32,7 @@ def roberta_kernel_initializer(stddev=0.02):
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")
-class RobertaBackbone(keras.Model):
+class RobertaBackbone(Backbone):
     """RoBERTa encoder.
 
     This network implements a bi-directional Transformer-based encoder as
@@ -42,7 +42,7 @@ class RobertaBackbone(keras.Model):
 
     The default constructor gives a fully customizable, randomly initialized
     RoBERTa encoder with any number of layers, heads, and embedding
-    dimensions. To load preset architectures and weights, use the `from_presets`
+    dimensions. To load preset architectures and weights, use the `from_preset`
     constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
@@ -71,7 +71,11 @@ class RobertaBackbone(keras.Model):
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)),
     }
 
-    # Randomly initialized RoBERTa model
+    # Pretrained RoBERTa encoder
+    model = keras_nlp.models.RobertaBackbone.from_preset("roberta_base_en")
+    output = model(input_data)
+
+    # Randomly initialized RoBERTa model with custom config
     model = keras_nlp.models.RobertaBackbone(
         vocabulary_size=50265,
         num_layers=12,
@@ -80,8 +84,6 @@ class RobertaBackbone(keras.Model):
         intermediate_dim=3072,
         max_sequence_length=12
     )
-
-    # Call the model on the input data.
     output = model(input_data)
     ```
     """
@@ -171,67 +173,18 @@ def get_config(self):
             "trainable": self.trainable,
         }
 
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
-
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @classmethod
-    @format_docstring(names=", ".join(backbone_presets))
-    def from_preset(
-        cls,
-        preset,
-        load_weights=True,
-        **kwargs,
-    ):
-        """Instantiate RoBERTa model from preset architecture and weights.
-
-        Args:
-            preset: string. Must be one of {{names}}.
-            load_weights: Whether to load pre-trained weights into model.
-                Defaults to `True`.
-
-        Examples:
-        ```python
-        input_data = {
-            "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
-            "padding_mask": tf.constant(
-                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
-            ),
-        }
-
-        # Load architecture and weights from preset
-        model = keras_nlp.models.RobertaBackbone.from_preset("roberta_base_en")
-        output = model(input_data)
+    def from_preset(cls, preset, load_weights=True, **kwargs):
+        return super().from_preset(preset, load_weights, **kwargs)
 
-        # Load randomly initialized model from preset architecture
-        model = keras_nlp.models.RobertaBackbone.from_preset(
-            "roberta_base_en", load_weights=False
-        )
-        output = model(input_data)
-        ```
-        """
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. Received: {preset}."""
-            )
-        metadata = cls.presets[preset]
-        config = metadata["config"]
-        model = cls.from_config({**config, **kwargs})
-
-        if not load_weights:
-            return model
-
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
-        model.load_weights(weights)
-        return model
+
+RobertaBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
+format_docstring(
+    model_name=RobertaBackbone.__name__,
+    example_preset_name="roberta_base_en",
+    preset_names='", "'.join(RobertaBackbone.presets),
+)(RobertaBackbone.from_preset.__func__)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py
index 820e5bf10d..741ed02ec1 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""XLM-RoBERTa backbone models."""
+"""XLM-RoBERTa backbone model."""
 
 import copy
-import os
 
 from tensorflow import keras
 
+from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.roberta import roberta_backbone
 from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets
 from keras_nlp.utils.python_utils import classproperty
@@ -37,7 +37,7 @@ class XLMRobertaBackbone(roberta_backbone.RobertaBackbone):
 
     The default constructor gives a fully customizable, randomly initialized
     RoBERTa encoder with any number of layers, heads, and embedding
-    dimensions. To load preset architectures and weights, use the `from_presets`
+    dimensions. To load preset architectures and weights, use the `from_preset`
     constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
@@ -66,7 +66,13 @@ class XLMRobertaBackbone(roberta_backbone.RobertaBackbone):
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)),
     }
 
-    # Randomly initialized XLM-R model
+    # Pretrained XLM-R encoder
+    model = keras_nlp.models.XLMRobertaBackbone.from_preset(
+        "xlm_roberta_base_multi",
+    )
+    output = model(input_data)
+
+    # Randomly initialized XLM-R model with custom config
     model = keras_nlp.models.XLMRobertaBackbone(
         vocabulary_size=250002,
         num_layers=12,
@@ -86,60 +92,13 @@ def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @classmethod
-    @format_docstring(names=", ".join(backbone_presets))
-    def from_preset(
-        cls,
-        preset,
-        load_weights=True,
-        **kwargs,
-    ):
-        """Instantiate XLM-RoBERTa model from preset architecture and weights.
-
-        Args:
-            preset: string. Must be one of {{names}}.
-            load_weights: Whether to load pre-trained weights into model.
-                Defaults to `True`.
-
-        Examples:
-        ```python
-        input_data = {
-            "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
-            "padding_mask": tf.constant(
-                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
-            ),
-        }
-
-        # Load architecture and weights from preset
-        model = keras_nlp.models.XLMRobertaBackbone.from_preset(
-            "xlm_roberta_base_multi",
-        )
-        output = model(input_data)
-
-        # Load randomly initialized model from preset architecture
-        model = keras_nlp.models.XLMRobertaBackbone.from_preset(
-            "xlm_roberta_base_multi", load_weights=False
-        )
-        output = model(input_data)
-        ```
-        """
-        if preset not in cls.presets:
-            raise ValueError(
-                "`preset` must be one of "
-                f"""{", ".join(cls.presets)}. Received: {preset}."""
-            )
-        metadata = cls.presets[preset]
-        config = metadata["config"]
-        model = cls.from_config({**config, **kwargs})
-
-        if not load_weights:
-            return model
-
-        weights = keras.utils.get_file(
-            "model.h5",
-            metadata["weights_url"],
-            cache_subdir=os.path.join("models", preset),
-            file_hash=metadata["weights_hash"],
-        )
-
-        model.load_weights(weights)
-        return model
+    def from_preset(cls, preset, load_weights=True, **kwargs):
+        return super().from_preset(preset, load_weights, **kwargs)
+
+
+XLMRobertaBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
+format_docstring(
+    model_name=XLMRobertaBackbone.__name__,
+    example_preset_name="xlm_roberta_base_multi",
+    preset_names='", "'.join(XLMRobertaBackbone.presets),
+)(XLMRobertaBackbone.from_preset.__func__)
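Usage sketch for the unified `from_preset` entry point this patch introduces. This is illustrative, not part of the patch itself: it assumes a KerasNLP build that includes this change and network access for the weight download, and it uses `bert_base_en_uncased`, the preset named in the patch's own docstrings.

```python
import keras_nlp

# Architecture plus pretrained weights, resolved from the preset registry;
# the weights file is cached under ~/.keras/models/<preset>/.
model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")

# Same architecture, randomly initialized: skips the weight download entirely.
random_model = keras_nlp.models.BertBackbone.from_preset(
    "bert_base_en_uncased",
    load_weights=False,
)

# Because `Backbone.from_preset` builds the model via
# `cls.from_config({**config, **kwargs})`, preset config values can be
# overridden with keyword arguments, e.g. a shorter sequence length
# (only sensible together with `load_weights=False`).
small_model = keras_nlp.models.BertBackbone.from_preset(
    "bert_base_en_uncased",
    load_weights=False,
    max_sequence_length=128,
)
```

Per the checks in `Backbone.from_preset`, calling it on a subclass whose `presets` classproperty is empty raises `NotImplementedError`, and an unknown preset name raises `ValueError`; each subclass keeps a thin `from_preset` override so the templated docstring can advertise that model's own preset names.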