Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Base Task Class #671

Merged
merged 7 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 87 additions & 161 deletions keras_nlp/models/bert/bert_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""BERT classification model."""

import copy
import os

from tensorflow import keras

Expand All @@ -23,15 +22,12 @@
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
from keras_nlp.models.bert.bert_presets import backbone_presets
from keras_nlp.models.bert.bert_presets import classifier_presets
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.python_utils import format_docstring

PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets))


@keras.utils.register_keras_serializable(package="keras_nlp")
class BertClassifier(PipelineModel):
class BertClassifier(Task):
"""An end-to-end BERT model for classification tasks

This model attaches a classification head to a `keras_nlp.models.BertBackbone`
Expand All @@ -58,8 +54,9 @@ class BertClassifier(PipelineModel):

Examples:

Example usage.
```python
# Call classifier on the inputs.
# Define the preprocessed inputs.
preprocessed_features = {
"token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
"segment_ids": tf.constant(
Expand Down Expand Up @@ -95,6 +92,73 @@ class BertClassifier(PipelineModel):
# Access backbone programmatically (e.g., to change `trainable`)
classifier.backbone.trainable = False
```

Raw string inputs.
```python
# Create a dataset with raw string features in an `(x, y)` format.
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

# Create a BertClassifier and fit your data.
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=4,
)
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
classifier.fit(x=features, y=labels, batch_size=2)
```

Raw string inputs with customized preprocessing.
```python
# Create a dataset with raw string features in an `(x, y)` format.
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

# Use a shorter sequence length.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
"bert_base_en_uncased",
sequence_length=128,
)

# Create a BertClassifier and fit your data.
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=4,
preprocessor=preprocessor,
)
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
classifier.fit(x=features, y=labels, batch_size=2)
```

Preprocessed inputs.
```python
# Create a dataset with preprocessed features in an `(x, y)` format.
preprocessed_features = {
"token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
"segment_ids": tf.constant(
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
),
}
labels = [0, 3]

# Create a BERT classifier and fit your data.
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=4,
preprocessor=None,
)
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
```
"""

def __init__(
Expand Down Expand Up @@ -124,164 +188,26 @@ def __init__(
self._backbone = backbone
self._preprocessor = preprocessor
self.num_classes = num_classes

def preprocess_samples(self, x, y=None, sample_weight=None):
return self.preprocessor(x, y=y, sample_weight=sample_weight)

@property
def backbone(self):
"""A `keras_nlp.models.BertBackbone` instance providing the encoder
submodel.
"""
return self._backbone

@property
def preprocessor(self):
"""A `keras_nlp.models.BertPreprocessor` for preprocessing inputs."""
return self._preprocessor
self.dropout = dropout

def get_config(self):
return {
"backbone": keras.layers.serialize(self.backbone),
"preprocessor": keras.layers.serialize(self.preprocessor),
"num_classes": self.num_classes,
"name": self.name,
"trainable": self.trainable,
}
config = super().get_config()
config.update(
{
"num_classes": self.num_classes,
"dropout": self.dropout,
}
)
return config

@classproperty
def backbone_cls(cls):
return BertBackbone

@classmethod
def from_config(cls, config):
if "backbone" in config and isinstance(config["backbone"], dict):
config["backbone"] = keras.layers.deserialize(config["backbone"])
if "preprocessor" in config and isinstance(
config["preprocessor"], dict
):
config["preprocessor"] = keras.layers.deserialize(
config["preprocessor"]
)
return cls(**config)
@classproperty
def preprocessor_cls(cls):
return BertPreprocessor

@classproperty
def presets(cls):
return copy.deepcopy({**backbone_presets, **classifier_presets})

@classmethod
@format_docstring(names=PRESET_NAMES)
def from_preset(
cls,
preset,
load_weights=True,
**kwargs,
):
"""Create a classification model from a preset architecture and weights.

By default, this method will automatically create a `preprocessor`
layer to preprocess raw inputs during `fit()`, `predict()`, and
`evaluate()`. If you would like to disable this behavior, pass
`preprocessor=None`.

Args:
preset: string. Must be one of {{names}}.
load_weights: Whether to load pre-trained weights into model.
Defaults to `True`.

Examples:

Raw string inputs.
```python
# Create a dataset with raw string features in an `(x, y)` format.
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

# Create a BertClassifier and fit your data.
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=4,
)
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
classifier.fit(x=features, y=labels, batch_size=2)
```

Raw string inputs with customized preprocessing.
jbischof marked this conversation as resolved.
Show resolved Hide resolved
```python
# Create a dataset with raw string features in an `(x, y)` format.
features = ["The quick brown fox jumped.", "I forgot my homework."]
labels = [0, 3]

# Use a shorter sequence length.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
"bert_base_en_uncased",
sequence_length=128,
)

# Create a BertClassifier and fit your data.
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=4,
preprocessor=preprocessor,
)
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
classifier.fit(x=features, y=labels, batch_size=2)
```

Preprocessed inputs.
```python
# Create a dataset with preprocessed features in an `(x, y)` format.
preprocessed_features = {
"token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
"segment_ids": tf.constant(
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
),
}
labels = [0, 3]

# Create a BERT classifier and fit your data.
classifier = keras_nlp.models.BertClassifier.from_preset(
"bert_base_en_uncased",
num_classes=4,
preprocessor=None,
)
classifier.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
```
"""
if preset not in cls.presets:
raise ValueError(
"`preset` must be one of "
f"""{", ".join(cls.presets)}. Received: {preset}."""
)

if "preprocessor" not in kwargs:
kwargs["preprocessor"] = BertPreprocessor.from_preset(preset)

# Check if preset is backbone-only model
if preset in BertBackbone.presets:
backbone = BertBackbone.from_preset(preset, load_weights)
return cls(backbone, **kwargs)

# Otherwise must be one of class presets
metadata = cls.presets[preset]
config = metadata["config"]
model = cls.from_config({**config, **kwargs})

if not load_weights:
return model

weights = keras.utils.get_file(
"model.h5",
metadata["weights_url"],
cache_subdir=os.path.join("models", preset),
file_hash=metadata["weights_hash"],
)

model.load_weights(weights)
return model
Loading