Skip to content

Commit

Permalink
Create Backbone base class (#621)
Browse files Browse the repository at this point in the history
* Create `Backbone` base class

* Format

* Throw exception if no presets.

* Address comments 1

* Address comments 2

* Extend to RoBERTa

* Format

* Attach `from_preset` to subclass and extend to all models

* format

* Fix typos in docstring

* More docstring typos
  • Loading branch information
jbischof authored Jan 5, 2023
1 parent 802c7ef commit eb76480
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 379 deletions.
89 changes: 89 additions & 0 deletions keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for Backbone models."""

import os

from tensorflow import keras

from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class Backbone(keras.Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def from_config(cls, config):
return cls(**config)

@classproperty
def presets(cls):
return {}

@classmethod
def from_preset(
cls,
preset,
load_weights=True,
**kwargs,
):
"""Instantiate {{model_name}} model from preset architecture and weights.
Args:
preset: string. Must be one of "{{preset_names}}".
load_weights: Whether to load pre-trained weights into model.
Defaults to `True`.
Examples:
```python
# Load architecture and weights from preset
model = {{model_name}}.from_preset("{{example_preset_name}}")
# Load randomly initialized model from preset architecture
model = {{model_name}}.from_preset(
"{{example_preset_name}}",
load_weights=False
)
```
"""

if not cls.presets:
raise NotImplementedError(
"No presets have been created for this class."
)

if preset not in cls.presets:
raise ValueError(
"`preset` must be one of "
f"""{", ".join(cls.presets)}. Received: {preset}."""
)
metadata = cls.presets[preset]
config = metadata["config"]
model = cls.from_config({**config, **kwargs})

if not load_weights:
return model

weights = keras.utils.get_file(
"model.h5",
metadata["weights_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["weights_hash"],
)

model.load_weights(weights)
return model
81 changes: 16 additions & 65 deletions keras_nlp/models/bert/bert_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT backbone models."""
"""BERT backbone model."""

import copy
import os

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.bert.bert_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring
Expand All @@ -32,7 +32,7 @@ def bert_kernel_initializer(stddev=0.02):


@keras.utils.register_keras_serializable(package="keras_nlp")
class BertBackbone(keras.Model):
class BertBackbone(Backbone):
"""BERT encoder network.
This class implements a bi-directional Transformer-based encoder as
Expand Down Expand Up @@ -76,7 +76,11 @@ class BertBackbone(keras.Model):
),
}
# Randomly initialized BERT encoder
# Pretrained BERT encoder
model = keras_nlp.models.BertBackbone.from_preset("base_base_en_uncased")
output = model(input_data)
# Randomly initialized BERT encoder with a custom config
model = keras_nlp.models.BertBackbone(
vocabulary_size=30552,
num_layers=12,
Expand Down Expand Up @@ -212,71 +216,18 @@ def get_config(self):
"trainable": self.trainable,
}

@classmethod
def from_config(cls, config):
return cls(**config)

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classmethod
@format_docstring(names=", ".join(backbone_presets))
def from_preset(
cls,
preset,
load_weights=True,
**kwargs,
):
"""Instantiate BERT model from preset architecture and weights.
Args:
preset: string. Must be one of {{names}}.
load_weights: Whether to load pre-trained weights into model.
Defaults to `True`.
def from_preset(cls, preset, load_weights=True, **kwargs):
return super().from_preset(preset, load_weights, **kwargs)

Examples:
```python
input_data = {
"token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
"segment_ids": tf.constant(
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
),
"padding_mask": tf.constant(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
),
}
# Load architecture and weights from preset
model = BertBackbone.from_preset("bert_base_en_uncased")
output = model(input_data)
# Load randomly initialized model from preset architecture
model = BertBackbone.from_preset(
"bert_base_en_uncased",
load_weights=False
)
output = model(input_data)
```
"""
if preset not in cls.presets:
raise ValueError(
"`preset` must be one of "
f"""{", ".join(cls.presets)}. Received: {preset}."""
)
metadata = cls.presets[preset]
config = metadata["config"]
model = cls.from_config({**config, **kwargs})

if not load_weights:
return model

weights = keras.utils.get_file(
"model.h5",
metadata["weights_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["weights_hash"],
)

model.load_weights(weights)
return model
BertBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
format_docstring(
model_name=BertBackbone.__name__,
example_preset_name="bert_base_en_uncased",
preset_names='", "'.join(BertBackbone.presets),
)(BertBackbone.from_preset.__func__)
82 changes: 18 additions & 64 deletions keras_nlp/models/deberta_v3/deberta_v3_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
"""DeBERTa backbone model."""

import copy
import os

import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.backbone import Backbone
from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets
from keras_nlp.models.deberta_v3.disentangled_attention_encoder import (
DisentangledAttentionEncoder,
Expand All @@ -34,7 +34,7 @@ def deberta_kernel_initializer(stddev=0.02):


@keras.utils.register_keras_serializable(package="keras_nlp")
class DebertaV3Backbone(keras.Model):
class DebertaV3Backbone(Backbone):
"""DeBERTa encoder network.
This network implements a bi-directional Transformer-based encoder as
Expand All @@ -45,7 +45,7 @@ class DebertaV3Backbone(keras.Model):
The default constructor gives a fully customizable, randomly initialized
DeBERTa encoder with any number of layers, heads, and embedding
dimensions. To load preset architectures and weights, use the `from_presets`
dimensions. To load preset architectures and weights, use the `from_preset`
constructor.
Disclaimer: Pre-trained models are provided on an "as is" basis, without
Expand Down Expand Up @@ -76,7 +76,13 @@ class DebertaV3Backbone(keras.Model):
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)),
}
# Randomly initialized DeBERTa model
# Pretrained DeBERTa encoder
model = keras_nlp.models.DebertaV3Backbone.from_preset(
"deberta_base_en",
)
output = model(input_data)
# Randomly initialized DeBERTa encoder with custom config
model = keras_nlp.models.DebertaV3Backbone(
vocabulary_size=128100,
num_layers=12,
Expand All @@ -86,7 +92,6 @@ class DebertaV3Backbone(keras.Model):
max_sequence_length=512,
bucket_size=256,
)
# Call the model on the input data.
output = model(input_data)
```
Expand Down Expand Up @@ -194,69 +199,18 @@ def get_config(self):
"trainable": self.trainable,
}

@classmethod
def from_config(cls, config):
return cls(**config)

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classmethod
@format_docstring(names=", ".join(backbone_presets))
def from_preset(
cls,
preset,
load_weights=True,
**kwargs,
):
"""Instantiate DeBERTa model from preset architecture and weights.
Args:
preset: string. Must be one of {{names}}.
load_weights: Whether to load pre-trained weights into model.
Defaults to `True`.
def from_preset(cls, preset, load_weights=True, **kwargs):
return super().from_preset(preset, load_weights, **kwargs)

Examples:
```python
input_data = {
"token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
"padding_mask": tf.constant(
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
),
}
# Load architecture and weights from preset
model = keras_nlp.models.DebertaV3Backbone.from_preset(
"deberta_base_en",
)
output = model(input_data)
# Load randomly initialized model from preset architecture
model = keras_nlp.models.DebertaV3Backbone.from_preset(
"deberta_base_en", load_weights=False
)
output = model(input_data)
```
"""
if preset not in cls.presets:
raise ValueError(
"`preset` must be one of "
f"""{", ".join(cls.presets)}. Received: {preset}."""
)
metadata = cls.presets[preset]
config = metadata["config"]
model = cls.from_config({**config, **kwargs})

if not load_weights:
return model

weights = keras.utils.get_file(
"model.h5",
metadata["weights_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["weights_hash"],
)

model.load_weights(weights)
return model
DebertaV3Backbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
format_docstring(
model_name=DebertaV3Backbone.__name__,
example_preset_name="deberta_base_en",
preset_names='", "'.join(DebertaV3Backbone.presets),
)(DebertaV3Backbone.from_preset.__func__)
Loading

0 comments on commit eb76480

Please sign in to comment.