-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a RoBERTa masked langauge model task
- Loading branch information
1 parent
73475c6
commit 88b5200
Showing
12 changed files
with
927 additions
and
171 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""RoBERTa classification model.""" | ||
|
||
import copy | ||
|
||
from tensorflow import keras | ||
|
||
from keras_nlp.layers.masked_lm_head import MaskedLMHead | ||
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone | ||
from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer | ||
from keras_nlp.models.roberta.roberta_masked_lm_preprocessor import ( | ||
RobertaMaskedLMPreprocessor, | ||
) | ||
from keras_nlp.models.roberta.roberta_presets import backbone_presets | ||
from keras_nlp.models.task import Task | ||
from keras_nlp.utils.python_utils import classproperty | ||
|
||
|
||
@keras.utils.register_keras_serializable(package="keras_nlp") | ||
class RobertaMaskedLM(Task): | ||
"""An end-to-end RoBERTa model for the masked language modeling task. | ||
This model will train RoBERTa on a masked language modeling task. | ||
The model will predict labels for a number of masked tokens in the | ||
input data. For usage of this model with pre-trained weights, see the | ||
`from_preset()` method. | ||
This model can optionally be configured with a `preprocessor` layer, in | ||
which case inputs can be raw string features during `fit()`, `predict()`, | ||
and `evaluate()`. Inputs will be tokenized and dynamically masked during | ||
training and evaluation. This is done by default when creating the model | ||
with `from_preset()`. | ||
Disclaimer: Pre-trained models are provided on an "as is" basis, without | ||
warranties or conditions of any kind. The underlying model is provided by a | ||
third party and subject to a separate license, available | ||
[here](https://github.com/facebookresearch/fairseq). | ||
Args: | ||
backbone: A `keras_nlp.models.RobertaBackbone` instance. | ||
preprocessor: A `keras_nlp.models.RobertaMaskedLMPreprocessor` or | ||
`None`. If `None`, this model will not apply preprocessing, and | ||
inputs should be preprocessed before calling the model. | ||
Example usage: | ||
Raw string inputs and pretrained backbone. | ||
```python | ||
# Create a dataset with raw string features. Labels are inferred. | ||
features = ["The quick brown fox jumped.", "I forgot my homework."] | ||
# Create a RobertaMaskedLM with a pretrained backbone and further train | ||
# on an MLM task. | ||
classifier = keras_nlp.models.RobertaMaskedLM.from_preset( | ||
"roberta_base_en", | ||
) | ||
classifier.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
) | ||
classifier.fit(x=features, batch_size=2) | ||
``` | ||
Preprocessed inputs and custom backbone. | ||
```python | ||
# Create a preprocessed dataset where 0 is the mask token. | ||
preprocessed_features = { | ||
"token_ids": tf.constant( | ||
[[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8) | ||
), | ||
"padding_mask": tf.constant( | ||
[[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8) | ||
), | ||
"mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2)) | ||
} | ||
# Labels are the original masked values. | ||
labels = [[3, 5]] * 2 | ||
# Randomly initialize a RoBERTa encoder | ||
backbone = keras_nlp.models.RobertaBackbone( | ||
vocabulary_size=50265, | ||
num_layers=12, | ||
num_heads=12, | ||
hidden_dim=768, | ||
intermediate_dim=3072, | ||
max_sequence_length=12 | ||
) | ||
# Create a RoBERTa masked_lm and fit the data. | ||
masked_lm = keras_nlp.models.RobertaMaskedLM( | ||
backbone, | ||
preprocessor=None, | ||
) | ||
masked_lm.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
) | ||
masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
backbone, | ||
preprocessor=None, | ||
**kwargs, | ||
): | ||
inputs = { | ||
**backbone.input, | ||
"mask_positions": keras.Input( | ||
shape=(None,), dtype="int32", name="mask_positions" | ||
), | ||
} | ||
backbone_outputs = backbone(backbone.input) | ||
outputs = MaskedLMHead( | ||
vocabulary_size=backbone.vocabulary_size, | ||
embedding_weights=backbone.token_embedding.embeddings, | ||
intermediate_activation="gelu", | ||
kernel_initializer=roberta_kernel_initializer(), | ||
name="mlm_head", | ||
)(backbone_outputs, inputs["mask_positions"]) | ||
|
||
# Instantiate using Functional API Model constructor | ||
super().__init__( | ||
inputs=inputs, | ||
outputs=outputs, | ||
include_preprocessing=preprocessor is not None, | ||
**kwargs, | ||
) | ||
# All references to `self` below this line | ||
self._backbone = backbone | ||
self._preprocessor = preprocessor | ||
|
||
@classproperty | ||
def backbone_cls(cls): | ||
return RobertaBackbone | ||
|
||
@classproperty | ||
def preprocessor_cls(cls): | ||
return RobertaMaskedLMPreprocessor | ||
|
||
@classproperty | ||
def presets(cls): | ||
return copy.deepcopy(backbone_presets) |
184 changes: 184 additions & 0 deletions
184
keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""RoBERTa masked language model preprocessor layer.""" | ||
|
||
import copy | ||
|
||
from tensorflow import keras | ||
|
||
from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator | ||
from keras_nlp.models.preprocessor import Preprocessor | ||
from keras_nlp.models.roberta.roberta_multi_segment_packer import ( | ||
RobertaMultiSegmentPacker, | ||
) | ||
from keras_nlp.models.roberta.roberta_presets import backbone_presets | ||
from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer | ||
from keras_nlp.utils.keras_utils import ( | ||
convert_inputs_to_list_of_tensor_segments, | ||
) | ||
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight | ||
from keras_nlp.utils.python_utils import classproperty | ||
|
||
|
||
@keras.utils.register_keras_serializable(package="keras_nlp") | ||
class RobertaMaskedLMPreprocessor(Preprocessor): | ||
"""RoBERTa preprocessing for the masked language modeling task. | ||
This preprocessing layer will prepare inputs for a masked language modeling | ||
task. It is primarily intended for use with the | ||
`keras_nlp.models.RobertaMaskedLM` task model. Preprocessing will occur in | ||
multiple steps. | ||
- Tokenize any number of input segments using the `tokenizer`. | ||
- Pack the inputs together with the appropriate `"<s>"`, `"</s>"` and | ||
`"<pad>"` tokens, i.e., adding a single `"<s>"` at the start of the | ||
entire sequence, `"</s></s>"` between each segment, | ||
and a `"</s>"` at the end of the entire sequence. | ||
- Randomly select non-special tokens to mask, controlled by | ||
`mask_selection_rate`. | ||
- Construct a `(x, y, sample_weight)` tuple suitable for training with a | ||
`keras_nlp.models.RobertaMaskedLM` task model. | ||
Args: | ||
tokenizer: A `keras_nlp.models.RobertaTokenizer` instance. | ||
sequence_length: The length of the packed inputs. | ||
mask_selection_rate: The probability an input token will be dynamically | ||
masked. | ||
mask_selection_length: The maximum number of masked tokens supported | ||
by the layer. | ||
truncate: string. The algorithm to truncate a list of batched segments | ||
to fit within `sequence_length`. The value can be either | ||
`round_robin` or `waterfall`: | ||
- `"round_robin"`: Available space is assigned one token at a | ||
time in a round-robin fashion to the inputs that still need | ||
some, until the limit is reached. | ||
- `"waterfall"`: The allocation of the budget is done using a | ||
"waterfall" algorithm that allocates quota in a | ||
left-to-right manner and fills up the buckets until we run | ||
out of budget. It supports an arbitrary number of segments. | ||
Examples: | ||
```python | ||
# Load the preprocessor from a preset. | ||
preprocessor = keras_nlp.models.RobertaMaskedLMPreprocessor.from_preset( | ||
"roberta_base_en" | ||
) | ||
# Tokenize and mask a single sentence. | ||
sentence = tf.constant("The quick brown fox jumped.") | ||
preprocessor(sentence) | ||
# Tokenize and mask a batch of sentences. | ||
sentences = tf.constant( | ||
["The quick brown fox jumped.", "Call me Ishmael."] | ||
) | ||
preprocessor(sentences) | ||
# Tokenize and mask a dataset of sentences. | ||
features = tf.constant( | ||
["The quick brown fox jumped.", "Call me Ishmael."] | ||
) | ||
ds = tf.data.Dataset.from_tensor_slices((features)) | ||
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
# Alternatively, you can create a preprocessor from your own vocabulary. | ||
# The usage is exactly the same as above. | ||
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3} | ||
vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6} | ||
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"] | ||
tokenizer = keras_nlp.models.RobertaTokenizer( | ||
vocabulary=vocab, | ||
merges=merges, | ||
) | ||
preprocessor = keras_nlp.models.RobertaMaskedLMPreprocessor( | ||
tokenizer=tokenizer, | ||
sequence_length=8, | ||
) | ||
preprocessor("a quick fox") | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
tokenizer, | ||
sequence_length=512, | ||
mask_selection_rate=0.15, | ||
mask_selection_length=96, | ||
truncate="round_robin", | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
|
||
self._tokenizer = tokenizer | ||
self.packer = RobertaMultiSegmentPacker( | ||
start_value=tokenizer.start_token_id, | ||
end_value=tokenizer.end_token_id, | ||
pad_value=tokenizer.pad_token_id, | ||
truncate=truncate, | ||
sequence_length=sequence_length, | ||
) | ||
self.masker = MaskedLMMaskGenerator( | ||
mask_selection_rate=mask_selection_rate, | ||
mask_selection_length=mask_selection_length, | ||
vocabulary_size=tokenizer.vocabulary_size(), | ||
mask_token_id=tokenizer.mask_token_id, | ||
unselectable_token_ids=[ | ||
tokenizer.start_token_id, | ||
tokenizer.end_token_id, | ||
tokenizer.pad_token_id, | ||
], | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"sequence_length": self.packer.sequence_length, | ||
"mask_selection_rate": self.masker.mask_selection_rate, | ||
"mask_selection_length": self.masker.mask_selection_length, | ||
"truncate": self.packer.truncate, | ||
} | ||
) | ||
return config | ||
|
||
def call(self, x, y=None, sample_weight=None): | ||
if y is not None: | ||
raise ValueError( | ||
"`RobertaMaskedLMPreprocessor` received labeled data (`y` is " | ||
"not `None`). No labels should be passed in as " | ||
"this layer generates training labels dynamically from raw " | ||
"text features passed as `x`. Received: `y={y}`." | ||
) | ||
|
||
x = convert_inputs_to_list_of_tensor_segments(x) | ||
x = [self.tokenizer(segment) for segment in x] | ||
token_ids = self.packer(x) | ||
masker_outputs = self.masker(token_ids) | ||
x = { | ||
"token_ids": masker_outputs["token_ids"], | ||
"padding_mask": token_ids != self.tokenizer.pad_token_id, | ||
"mask_positions": masker_outputs["mask_positions"], | ||
} | ||
y = masker_outputs["mask_ids"] | ||
sample_weight = masker_outputs["mask_weights"] | ||
return pack_x_y_sample_weight(x, y, sample_weight) | ||
|
||
@classproperty | ||
def tokenizer_cls(cls): | ||
return RobertaTokenizer | ||
|
||
@classproperty | ||
def presets(cls): | ||
return copy.deepcopy(backbone_presets) |
Oops, something went wrong.