Add a RoBERTa masked language model task
mattdangerw committed Jan 25, 2023
1 parent 73475c6 commit 88b5200
Showing 12 changed files with 927 additions and 171 deletions.
4 changes: 4 additions & 0 deletions keras_nlp/models/__init__.py
@@ -35,6 +35,10 @@
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_masked_lm import RobertaMaskedLM
from keras_nlp.models.roberta.roberta_masked_lm_preprocessor import (
RobertaMaskedLMPreprocessor,
)
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone
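As a quick orientation, here is a minimal sketch (not part of the diff) of how the two newly exported symbols are reached from the public API; it assumes a KerasNLP build that includes this commit and access to the `roberta_base_en` preset.

```python
import keras_nlp

# Both classes are now exported under `keras_nlp.models`.
preprocessor = keras_nlp.models.RobertaMaskedLMPreprocessor.from_preset(
    "roberta_base_en"
)
masked_lm = keras_nlp.models.RobertaMaskedLM.from_preset("roberta_base_en")
```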
1 change: 1 addition & 0 deletions keras_nlp/models/roberta/roberta_classifier_test.py
@@ -48,6 +48,7 @@ def setUp(self):
"Ġis": 9,
"Ġthe": 10,
"Ġbest": 11,
"<mask>": 12,
}

merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"]
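The new `"<mask>"` vocabulary entry matters because the masked LM preprocessor introduced in this commit looks up the tokenizer's mask token id. A small illustrative sketch (not from the diff; it reuses the toy vocabulary from the preprocessor docstring below and assumes the tokenizer exposes `mask_token_id`, which the new preprocessor relies on):

```python
import keras_nlp

# Toy vocabulary with an explicit "<mask>" entry (id 3 here).
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]
tokenizer = keras_nlp.models.RobertaTokenizer(vocabulary=vocab, merges=merges)
# The masked LM preprocessor reads this id when generating masks.
print(tokenizer.mask_token_id)  # expected: 3
```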
153 changes: 153 additions & 0 deletions keras_nlp/models/roberta/roberta_masked_lm.py
@@ -0,0 +1,153 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RoBERTa classification model."""

import copy

from tensorflow import keras

from keras_nlp.layers.masked_lm_head import MaskedLMHead
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer
from keras_nlp.models.roberta.roberta_masked_lm_preprocessor import (
RobertaMaskedLMPreprocessor,
)
from keras_nlp.models.roberta.roberta_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class RobertaMaskedLM(Task):
"""An end-to-end RoBERTa model for the masked language modeling task.
This model will train RoBERTa on a masked language modeling task.
The model will predict labels for a number of masked tokens in the
input data. For usage of this model with pre-trained weights, see the
`from_preset()` method.
This model can optionally be configured with a `preprocessor` layer, in
which case inputs can be raw string features during `fit()`, `predict()`,
and `evaluate()`. Inputs will be tokenized and dynamically masked during
training and evaluation. This is done by default when creating the model
with `from_preset()`.
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/facebookresearch/fairseq).
Args:
backbone: A `keras_nlp.models.RobertaBackbone` instance.
preprocessor: A `keras_nlp.models.RobertaMaskedLMPreprocessor` or
`None`. If `None`, this model will not apply preprocessing, and
inputs should be preprocessed before calling the model.
Example usage:
Raw string inputs and pretrained backbone.
```python
# Create a dataset with raw string features. Labels are inferred.
features = ["The quick brown fox jumped.", "I forgot my homework."]
# Create a RobertaMaskedLM with a pretrained backbone and further train
# on an MLM task.
masked_lm = keras_nlp.models.RobertaMaskedLM.from_preset(
"roberta_base_en",
)
masked_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(x=features, batch_size=2)
```

Preprocessed inputs and custom backbone.
```python
# Create a preprocessed dataset where 0 is the mask token.
preprocessed_features = {
"token_ids": tf.constant(
[[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)
),
"mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2))
}
# Labels are the original masked values.
labels = [[3, 5]] * 2
# Randomly initialize a RoBERTa encoder
backbone = keras_nlp.models.RobertaBackbone(
vocabulary_size=50265,
num_layers=12,
num_heads=12,
hidden_dim=768,
intermediate_dim=3072,
max_sequence_length=12
)
# Create a RoBERTa masked_lm and fit the data.
masked_lm = keras_nlp.models.RobertaMaskedLM(
backbone,
preprocessor=None,
)
masked_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2)
```
"""

def __init__(
self,
backbone,
preprocessor=None,
**kwargs,
):
inputs = {
**backbone.input,
"mask_positions": keras.Input(
shape=(None,), dtype="int32", name="mask_positions"
),
}
backbone_outputs = backbone(backbone.input)
outputs = MaskedLMHead(
vocabulary_size=backbone.vocabulary_size,
embedding_weights=backbone.token_embedding.embeddings,
intermediate_activation="gelu",
kernel_initializer=roberta_kernel_initializer(),
name="mlm_head",
)(backbone_outputs, inputs["mask_positions"])

# Instantiate using Functional API Model constructor
super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs,
)
# All references to `self` below this line
self._backbone = backbone
self._preprocessor = preprocessor

@classproperty
def backbone_cls(cls):
return RobertaBackbone

@classproperty
def preprocessor_cls(cls):
return RobertaMaskedLMPreprocessor

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
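For readers skimming the constructor above, a minimal sketch (illustrative only; a tiny randomly initialized backbone rather than a preset) of the input/output contract: `mask_positions` selects which sequence positions the `MaskedLMHead` scores, so the model returns one row of vocabulary logits per masked position.

```python
import tensorflow as tf
import keras_nlp

# Tiny randomly initialized backbone; the real preset uses 12 layers / 768 dims.
backbone = keras_nlp.models.RobertaBackbone(
    vocabulary_size=50265,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=12,
)
masked_lm = keras_nlp.models.RobertaMaskedLM(backbone, preprocessor=None)
features = {
    "token_ids": tf.constant([[1, 2, 0, 4, 0, 6, 7, 8]] * 2),
    "padding_mask": tf.ones((2, 8), dtype="int32"),
    "mask_positions": tf.constant([[2, 4]] * 2),
}
logits = masked_lm.predict(features)
print(logits.shape)  # (2, 2, 50265): batch, masks per example, vocabulary size
```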
184 changes: 184 additions & 0 deletions keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py
@@ -0,0 +1,184 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RoBERTa masked language model preprocessor layer."""

import copy

from tensorflow import keras

from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.models.roberta.roberta_multi_segment_packer import (
RobertaMultiSegmentPacker,
)
from keras_nlp.models.roberta.roberta_presets import backbone_presets
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class RobertaMaskedLMPreprocessor(Preprocessor):
"""RoBERTa preprocessing for the masked language modeling task.
This preprocessing layer will prepare inputs for a masked language modeling
task. It is primarily intended for use with the
`keras_nlp.models.RobertaMaskedLM` task model. Preprocessing will occur in
multiple steps.
- Tokenize any number of input segments using the `tokenizer`.
- Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and
`"<pad>"` tokens, i.e., adding a single `"<s>"` at the start of the
entire sequence, `"</s></s>"` between each segment,
and a `"</s>"` at the end of the entire sequence.
- Randomly select non-special tokens to mask, controlled by
`mask_selection_rate`.
- Construct a `(x, y, sample_weight)` tuple suitable for training with a
`keras_nlp.models.RobertaMaskedLM` task model.
Args:
tokenizer: A `keras_nlp.models.RobertaTokenizer` instance.
sequence_length: The length of the packed inputs.
mask_selection_rate: The probability an input token will be dynamically
masked.
mask_selection_length: The maximum number of masked tokens supported
by the layer.
truncate: string. The algorithm to truncate a list of batched segments
to fit within `sequence_length`. The value can be either
`round_robin` or `waterfall`:
- `"round_robin"`: Available space is assigned one token at a
time in a round-robin fashion to the inputs that still need
some, until the limit is reached.
- `"waterfall"`: The allocation of the budget is done using a
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.
Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.RobertaMaskedLMPreprocessor.from_preset(
"roberta_base_en"
)
# Tokenize and mask a single sentence.
sentence = tf.constant("The quick brown fox jumped.")
preprocessor(sentence)
# Tokenize and mask a batch of sentences.
sentences = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
preprocessor(sentences)
# Tokenize and mask a dataset of sentences.
features = tf.constant(
["The quick brown fox jumped.", "Call me Ishmael."]
)
ds = tf.data.Dataset.from_tensor_slices((features))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Alternatively, you can create a preprocessor from your own vocabulary.
# The usage is exactly the same as above.
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]
tokenizer = keras_nlp.models.RobertaTokenizer(
vocabulary=vocab,
merges=merges,
)
preprocessor = keras_nlp.models.RobertaMaskedLMPreprocessor(
tokenizer=tokenizer,
sequence_length=8,
)
preprocessor("a quick fox")
```
"""

def __init__(
self,
tokenizer,
sequence_length=512,
mask_selection_rate=0.15,
mask_selection_length=96,
truncate="round_robin",
**kwargs,
):
super().__init__(**kwargs)

self._tokenizer = tokenizer
self.packer = RobertaMultiSegmentPacker(
start_value=tokenizer.start_token_id,
end_value=tokenizer.end_token_id,
pad_value=tokenizer.pad_token_id,
truncate=truncate,
sequence_length=sequence_length,
)
self.masker = MaskedLMMaskGenerator(
mask_selection_rate=mask_selection_rate,
mask_selection_length=mask_selection_length,
vocabulary_size=tokenizer.vocabulary_size(),
mask_token_id=tokenizer.mask_token_id,
unselectable_token_ids=[
tokenizer.start_token_id,
tokenizer.end_token_id,
tokenizer.pad_token_id,
],
)

def get_config(self):
config = super().get_config()
config.update(
{
"sequence_length": self.packer.sequence_length,
"mask_selection_rate": self.masker.mask_selection_rate,
"mask_selection_length": self.masker.mask_selection_length,
"truncate": self.packer.truncate,
}
)
return config

def call(self, x, y=None, sample_weight=None):
if y is not None:
raise ValueError(
"`RobertaMaskedLMPreprocessor` received labeled data (`y` is "
"not `None`). No labels should be passed in as "
"this layer generates training labels dynamically from raw "
"text features passed as `x`. Received: `y={y}`."
)

x = convert_inputs_to_list_of_tensor_segments(x)
x = [self.tokenizer(segment) for segment in x]
token_ids = self.packer(x)
masker_outputs = self.masker(token_ids)
x = {
"token_ids": masker_outputs["token_ids"],
"padding_mask": token_ids != self.tokenizer.pad_token_id,
"mask_positions": masker_outputs["mask_positions"],
}
y = masker_outputs["mask_ids"]
sample_weight = masker_outputs["mask_weights"]
return pack_x_y_sample_weight(x, y, sample_weight)

@classproperty
def tokenizer_cls(cls):
return RobertaTokenizer

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
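A hedged sketch of what `call()` returns in practice (assumes the `roberta_base_en` preset is available): the feature dict feeds `RobertaMaskedLM` directly, while the labels and sample weights come from the mask generator.

```python
import tensorflow as tf
import keras_nlp

preprocessor = keras_nlp.models.RobertaMaskedLMPreprocessor.from_preset(
    "roberta_base_en"
)
x, y, sample_weight = preprocessor(tf.constant("The quick brown fox jumped."))
print(sorted(x.keys()))  # ['mask_positions', 'padding_mask', 'token_ids']
# `y` holds the original ids of the masked tokens; `sample_weight` is 1 for
# real mask slots and 0 for padding up to `mask_selection_length`.
```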