From 91fe6bd380d5dec450976bcb79f3f07ece9aebf6 Mon Sep 17 00:00:00 2001
From: Maaz Karim <maazkarim02@gmail.com>
Date: Sat, 4 Mar 2023 01:57:28 +0530
Subject: [PATCH] BertMaskedLM Task Model and Preprocessor (#774)

* bert_masked_lm init

* Merge branch 'master' into BertMaskedLM

* WIP: BERT MASKED LM

* Added Tests

* Black Formatting

* Fixed Format

* Fixed formatting

* black + lint.sh

* Reformat code

* Updated Docstring for bert_tokenizer

* Updated masked_lm_generator.py

* fixed linting

* Changed Boolean Variables to Numeric

* Formatted using shell/format.sh

* Updated bert_masked_lm.py

* typo fix

---------

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
---
 keras_nlp/models/__init__.py                  |   4 +
 keras_nlp/models/bert/bert_masked_lm.py       | 152 ++++++++++++++++++
 .../bert/bert_masked_lm_preprocessor.py       | 138 ++++++++++++++++
 .../bert/bert_masked_lm_preprocessor_test.py  | 146 +++++++++++++++++
 keras_nlp/models/bert/bert_masked_lm_test.py  | 131 +++++++++++++++
 keras_nlp/models/bert/bert_tokenizer.py       |  14 +-
 6 files changed, 579 insertions(+), 6 deletions(-)
 create mode 100644 keras_nlp/models/bert/bert_masked_lm.py
 create mode 100644 keras_nlp/models/bert/bert_masked_lm_preprocessor.py
 create mode 100644 keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py
 create mode 100644 keras_nlp/models/bert/bert_masked_lm_test.py

diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index 7751b37165..ceac30391a 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -24,6 +24,10 @@
 from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
 from keras_nlp.models.bert.bert_backbone import BertBackbone
 from keras_nlp.models.bert.bert_classifier import BertClassifier
+from keras_nlp.models.bert.bert_masked_lm import BertMaskedLM
+from keras_nlp.models.bert.bert_masked_lm_preprocessor import (
+    BertMaskedLMPreprocessor,
+)
 from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
 from keras_nlp.models.bert.bert_tokenizer import BertTokenizer
 from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py
new file mode 100644
index 0000000000..fd9b29f9eb
--- /dev/null
+++ b/keras_nlp/models/bert/bert_masked_lm.py
@@ -0,0 +1,152 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BERT masked LM model."""
+
+import copy
+
+from tensorflow import keras
+
+from keras_nlp.layers.masked_lm_head import MaskedLMHead
+from keras_nlp.models.bert.bert_backbone import BertBackbone
+from keras_nlp.models.bert.bert_backbone import bert_kernel_initializer
+from keras_nlp.models.bert.bert_masked_lm_preprocessor import (
+    BertMaskedLMPreprocessor,
+)
+from keras_nlp.models.bert.bert_presets import backbone_presets
+from keras_nlp.models.task import Task
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class BertMaskedLM(Task):
+    """An end-to-end BERT model for the masked language modeling task.
+
+    This model will train BERT on a masked language modeling task.
+    The model will predict labels for a number of masked tokens in the
+    input data. For usage of this model with pre-trained weights, see the
+    `from_preset()` method.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case inputs can be raw string features during `fit()`, `predict()`,
+    and `evaluate()`. Inputs will be tokenized and dynamically masked during
+    training and evaluation. This is done by default when creating the model
+    with `from_preset()`.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind.
+
+    Args:
+        backbone: A `keras_nlp.models.BertBackbone` instance.
+        preprocessor: A `keras_nlp.models.BertMaskedLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+
+    Example usage:
+
+    Raw string inputs and pretrained backbone.
+    ```python
+    # Create a dataset with raw string features. Labels are inferred.
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+
+    # Create a BertMaskedLM with a pretrained backbone and further train
+    # on an MLM task.
+    masked_lm = keras_nlp.models.BertMaskedLM.from_preset(
+        "bert_base_en",
+    )
+    masked_lm.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    )
+    masked_lm.fit(x=features, batch_size=2)
+    ```
+
+    Preprocessed inputs and custom backbone.
+    ```python
+    # Create a preprocessed dataset where 0 is the mask token.
+    preprocessed_features = {
+        "token_ids": tf.constant(
+            [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8)
+        ),
+        "padding_mask": tf.constant(
+            [[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8)
+        ),
+        "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2)),
+        "segment_ids": tf.constant([[0, 0, 0, 0, 0, 0, 0, 0]] * 2, shape=(2, 8))
+    }
+    # Labels are the original masked values.
+    labels = [[3, 5]] * 2
+
+    # Randomly initialize a BERT encoder
+    backbone = keras_nlp.models.BertBackbone(
+        vocabulary_size=50265,
+        num_layers=12,
+        num_heads=12,
+        hidden_dim=768,
+        intermediate_dim=3072,
+        max_sequence_length=12
+    )
+    # Create a BERT masked LM model and fit the data.
+    masked_lm = keras_nlp.models.BertMaskedLM(
+        backbone,
+        preprocessor=None,
+    )
+    masked_lm.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+    )
+    masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2)
+    ```
+    """
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
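+        # The functional inputs are the backbone's own inputs plus a
+        # `mask_positions` tensor holding the indices of the masked tokens.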
+        inputs = {
+            **backbone.input,
+            "mask_positions": keras.Input(
+                shape=(None,), dtype="int32", name="mask_positions"
+            ),
+        }
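+        # Run the backbone, then predict vocabulary logits at each masked
+        # position with a `MaskedLMHead` that reuses the backbone's token
+        # embedding matrix as the output projection.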
+        backbone_outputs = backbone(backbone.input)
+        outputs = MaskedLMHead(
+            vocabulary_size=backbone.vocabulary_size,
+            embedding_weights=backbone.token_embedding.embeddings,
+            intermediate_activation="gelu",
+            kernel_initializer=bert_kernel_initializer(),
+            name="mlm_head",
+        )(backbone_outputs["sequence_output"], inputs["mask_positions"])
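+        # `outputs` has shape
+        # `(batch_size, number_of_mask_positions, vocabulary_size)`, holding
+        # the token logits for each masked position.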
+
+        # Instantiate using Functional API Model constructor
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            include_preprocessing=preprocessor is not None,
+            **kwargs,
+        )
+        # All references to `self` below this line
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+    @classproperty
+    def backbone_cls(cls):
+        return BertBackbone
+
+    @classproperty
+    def preprocessor_cls(cls):
+        return BertMaskedLMPreprocessor
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor.py
new file mode 100644
index 0000000000..0d52ed78ec
--- /dev/null
+++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor.py
@@ -0,0 +1,138 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT masked language model preprocessor layer."""
+
+from absl import logging
+from tensorflow import keras
+
+from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
+from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class BertMaskedLMPreprocessor(BertPreprocessor):
+    """BERT preprocessing for the masked language modeling task.
+    This preprocessing layer will prepare inputs for a masked language modeling
+    task. It is primarily intended for use with the
+    `keras_nlp.models.BertMaskedLM` task model. Preprocessing will occur in
+    multiple steps.
+    - Tokenize any number of input segments using the `tokenizer`.
+    - Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`.
+       with the appropriate `"[CLS]"`, `"[SEP]"`, `"[SEP]"`, `"[SEP]"` and `"[PAD]"` tokens.
+    - Randomly select non-special tokens to mask, controlled by
+      `mask_selection_rate`.
+    - Construct a `(x, y, sample_weight)` tuple suitable for training with a
+      `keras_nlp.models.BertMaskedLM` task model.
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.BertMaskedLMPreprocessor.from_preset(
+        "bert_base_en"
+    )
+
+    # Tokenize and mask a single sentence.
+    sentence = tf.constant("The quick brown fox jumped.")
+    preprocessor(sentence)
+
+    # Tokenize and mask a batch of sentences.
+    sentences = tf.constant(
+        ["The quick brown fox jumped.", "Call me Ishmael."]
+    )
+    preprocessor(sentences)
+
+    # Tokenize and mask a dataset of sentences.
+    features = tf.constant(
+        ["The quick brown fox jumped.", "Call me Ishmael."]
+    )
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Alternatively, you can create a preprocessor from your own vocabulary.
+    # The usage is exactly the same as above.
+    vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
+    vocab += ["THE", "QUICK", "BROWN", "FOX"]
+    vocab += ["Call", "me", "Ishmael"]
+    tokenizer = keras_nlp.models.BertTokenizer(vocabulary=vocab)
+    preprocessor = keras_nlp.models.BertMaskedLMPreprocessor(tokenizer)
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=512,
+        truncate="round_robin",
+        mask_selection_rate=0.15,
+        mask_selection_length=96,
+        mask_token_rate=0.8,
+        random_token_rate=0.1,
+        **kwargs,
+    ):
+        super().__init__(
+            tokenizer,
+            sequence_length=sequence_length,
+            truncate=truncate,
+            **kwargs,
+        )
+
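+        # `MaskedLMMaskGenerator` selects `mask_selection_rate` of the
+        # non-special tokens. Of the selected tokens, `mask_token_rate` are
+        # replaced with `"[MASK]"`, `random_token_rate` with random vocabulary
+        # tokens, and the remainder are left unchanged.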
+        self.masker = MaskedLMMaskGenerator(
+            mask_selection_rate=mask_selection_rate,
+            mask_selection_length=mask_selection_length,
+            mask_token_rate=mask_token_rate,
+            random_token_rate=random_token_rate,
+            vocabulary_size=tokenizer.vocabulary_size(),
+            mask_token_id=tokenizer.mask_token_id,
+            unselectable_token_ids=[
+                tokenizer.cls_token_id,
+                tokenizer.sep_token_id,
+                tokenizer.pad_token_id,
+            ],
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "mask_selection_rate": self.masker.mask_selection_rate,
+                "mask_selection_length": self.masker.mask_selection_length,
+                "mask_token_rate": self.masker.mask_token_rate,
+                "random_token_rate": self.masker.random_token_rate,
+            }
+        )
+        return config
+
+    def call(self, x, y=None, sample_weight=None):
+        if y is not None or sample_weight is not None:
+            logging.warning(
+                f"{self.__class__.__name__} generates `y` and `sample_weight` "
+                "based on your input data, but your data already contains `y` "
+                "or `sample_weight`. Your `y` and `sample_weight` will be "
+                "ignored."
+            )
+
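+        # Tokenize and pack the raw inputs with the base `BertPreprocessor`.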
+        x = super().call(x)
+
+        token_ids, padding_mask, segment_ids = (
+            x["token_ids"],
+            x["padding_mask"],
+            x["segment_ids"],
+        )
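+        # Apply dynamic masking to the packed token ids.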
+        masker_outputs = self.masker(token_ids)
+        x = {
+            "token_ids": masker_outputs["token_ids"],
+            "padding_mask": padding_mask,
+            "segment_ids": segment_ids,
+            "mask_positions": masker_outputs["mask_positions"],
+        }
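+        # Labels are the original ids of the masked tokens; sample weights
+        # zero out any unused (padded) mask positions.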
+        y = masker_outputs["mask_ids"]
+        sample_weight = masker_outputs["mask_weights"]
+        return pack_x_y_sample_weight(x, y, sample_weight)
diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py
new file mode 100644
index 0000000000..d88f667574
--- /dev/null
+++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor_test.py
@@ -0,0 +1,146 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for BERT masked language model preprocessor layer."""
+
+import os
+
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.bert.bert_masked_lm_preprocessor import (
+    BertMaskedLMPreprocessor,
+)
+from keras_nlp.models.bert.bert_tokenizer import BertTokenizer
+
+
+class BertMaskedLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase):
+    def setUp(self):
+        self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
+        self.vocab += ["THE", "QUICK", "BROWN", "FOX"]
+        self.vocab += ["the", "quick", "brown", "fox"]
+
+        tokenizer = BertTokenizer(vocabulary=self.vocab)
+
+        self.preprocessor = BertMaskedLMPreprocessor(
+            tokenizer=tokenizer,
+            # Simplify our testing by masking every available token.
+            mask_selection_rate=1.0,
+            mask_token_rate=1.0,
+            random_token_rate=0.0,
+            mask_selection_length=4,
+            sequence_length=12,
+        )
+
+    def test_preprocess_strings(self):
+        input_data = "the quick brown fox"
+
+        x, y, sw = self.preprocessor(input_data)
+        self.assertAllEqual(
+            x["token_ids"], [2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0]
+        )
+        self.assertAllEqual(
+            x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
+        )
+        self.assertAllEqual(
+            x["segment_ids"], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+        )
+        self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4])
+        self.assertAllEqual(y, [9, 10, 11, 12])
+        self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0])
+
+    def test_preprocess_list_of_strings(self):
+        input_data = ["the quick brown fox"] * 4
+
+        x, y, sw = self.preprocessor(input_data)
+        self.assertAllEqual(
+            x["token_ids"], [[2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0]] * 4
+        )
+        self.assertAllEqual(
+            x["padding_mask"],
+            [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4,
+        )
+        self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4)
+        self.assertAllEqual(y, [[9, 10, 11, 12]] * 4)
+        self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4)
+
+    def test_preprocess_dataset(self):
+        sentences = tf.constant(["the quick brown fox"] * 4)
+        ds = tf.data.Dataset.from_tensor_slices(sentences)
+        ds = ds.map(self.preprocessor)
+        x, y, sw = ds.batch(4).take(1).get_single_element()
+        self.assertAllEqual(
+            x["token_ids"], [[2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0]] * 4
+        )
+        self.assertAllEqual(
+            x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4
+        )
+        self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4)
+        self.assertAllEqual(y, [[9, 10, 11, 12]] * 4)
+        self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4)
+
+    def test_mask_multiple_sentences(self):
+        sentence_one = tf.constant("the quick")
+        sentence_two = tf.constant("brown fox")
+
+        x, y, sw = self.preprocessor((sentence_one, sentence_two))
+        self.assertAllEqual(
+            x["token_ids"], [2, 4, 4, 3, 4, 4, 3, 0, 0, 0, 0, 0]
+        )
+        self.assertAllEqual(
+            x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
+        )
+        self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5])
+        self.assertAllEqual(y, [9, 10, 11, 12])
+        self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0])
+
+    def test_no_masking_zero_rate(self):
+        no_mask_preprocessor = BertMaskedLMPreprocessor(
+            self.preprocessor.tokenizer,
+            mask_selection_rate=0.0,
+            mask_selection_length=4,
+            sequence_length=12,
+        )
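+        # With a zero selection rate nothing is masked, so mask positions,
+        # labels and sample weights are all zero padding.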
+        input_data = "the quick brown fox"
+
+        x, y, sw = no_mask_preprocessor(input_data)
+        self.assertAllEqual(
+            x["token_ids"], [2, 9, 10, 11, 12, 3, 0, 0, 0, 0, 0, 0]
+        )
+        self.assertAllEqual(
+            x["padding_mask"],
+            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
+        )
+        self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0])
+        self.assertAllEqual(y, [0, 0, 0, 0])
+        self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0])
+
+    @parameterized.named_parameters(
+        ("tf_format", "tf", "model"),
+        ("keras_format", "keras_v3", "model.keras"),
+    )
+    def test_saved_model(self, save_format, filename):
+        input_data = tf.constant(["the quick brown fox"])
+
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = self.preprocessor(inputs)
+        model = keras.Model(inputs, outputs)
+
+        path = os.path.join(self.get_temp_dir(), filename)
+        model.save(path, save_format=save_format)
+
+        restored_model = keras.models.load_model(path)
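+        # The preprocessor output is an `(x, y, sample_weight)` tuple; compare
+        # the `token_ids` from the `x` element.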
+        outputs = model(input_data)[0]["token_ids"]
+        restored_outputs = restored_model(input_data)[0]["token_ids"]
+        self.assertAllEqual(outputs, restored_outputs)
diff --git a/keras_nlp/models/bert/bert_masked_lm_test.py b/keras_nlp/models/bert/bert_masked_lm_test.py
new file mode 100644
index 0000000000..0fa10ce326
--- /dev/null
+++ b/keras_nlp/models/bert/bert_masked_lm_test.py
@@ -0,0 +1,131 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for BERT masked language model."""
+
+import os
+
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.bert.bert_backbone import BertBackbone
+from keras_nlp.models.bert.bert_masked_lm import BertMaskedLM
+from keras_nlp.models.bert.bert_masked_lm_preprocessor import (
+    BertMaskedLMPreprocessor,
+)
+from keras_nlp.models.bert.bert_tokenizer import BertTokenizer
+
+
+class BertMaskedLMTest(tf.test.TestCase, parameterized.TestCase):
+    def setUp(self):
+        self.backbone = BertBackbone(
+            vocabulary_size=1000,
+            num_layers=2,
+            num_heads=2,
+            hidden_dim=64,
+            intermediate_dim=128,
+            max_sequence_length=128,
+        )
+
+        self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
+        self.vocab += ["THE", "QUICK", "BROWN", "FOX"]
+        self.vocab += ["the", "quick", "brown", "fox"]
+
+        tokenizer = BertTokenizer(vocabulary=self.vocab)
+
+        self.preprocessor = BertMaskedLMPreprocessor(
+            tokenizer=tokenizer,
+            # Simplify our testing by masking every available token.
+            mask_selection_rate=1.0,
+            mask_token_rate=1.0,
+            random_token_rate=0.0,
+            mask_selection_length=2,
+            sequence_length=10,
+        )
+        self.masked_lm = BertMaskedLM(
+            self.backbone,
+            preprocessor=self.preprocessor,
+        )
+        self.masked_lm_no_preprocessing = BertMaskedLM(
+            self.backbone,
+            preprocessor=None,
+        )
+
+        self.raw_batch = tf.constant(
+            [
+                "quick brown fox",
+                "eagle flew over fox",
+                "the eagle flew quick",
+                "a brown eagle",
+            ]
+        )
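+        # Keep only the `x` features from the `(x, y, sample_weight)` output.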
+        self.preprocessed_batch = self.preprocessor(self.raw_batch)[0]
+        self.raw_dataset = tf.data.Dataset.from_tensor_slices(
+            self.raw_batch
+        ).batch(2)
+        self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor)
+
+    def test_valid_call_masked_lm(self):
+        self.masked_lm(self.preprocessed_batch)
+
+    @parameterized.named_parameters(
+        ("jit_compile_false", False), ("jit_compile_true", True)
+    )
+    def test_bert_masked_lm_predict(self, jit_compile):
+        self.masked_lm.compile(jit_compile=jit_compile)
+        self.masked_lm.predict(self.raw_batch)
+
+    @parameterized.named_parameters(
+        ("jit_compile_false", False), ("jit_compile_true", True)
+    )
+    def test_bert_masked_lm_predict_no_preprocessing(self, jit_compile):
+        self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile)
+        self.masked_lm_no_preprocessing.predict(self.preprocessed_batch)
+
+    @parameterized.named_parameters(
+        ("jit_compile_false", False), ("jit_compile_true", True)
+    )
+    def test_bert_masked_lm_fit(self, jit_compile):
+        self.masked_lm.compile(
+            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+            jit_compile=jit_compile,
+        )
+        self.masked_lm.fit(self.raw_dataset)
+
+    @parameterized.named_parameters(
+        ("jit_compile_false", False), ("jit_compile_true", True)
+    )
+    def test_bert_masked_lm_fit_no_preprocessing(self, jit_compile):
+        self.masked_lm_no_preprocessing.compile(
+            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+            jit_compile=jit_compile,
+        )
+        self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset)
+
+    @parameterized.named_parameters(
+        ("tf_format", "tf", "model"),
+        ("keras_format", "keras_v3", "model.keras"),
+    )
+    def test_saved_model(self, save_format, filename):
+        save_path = os.path.join(self.get_temp_dir(), filename)
+        self.masked_lm.save(save_path, save_format=save_format)
+        restored_model = keras.models.load_model(save_path)
+
+        # Check we got the real object back.
+        self.assertIsInstance(restored_model, BertMaskedLM)
+
+        model_output = self.masked_lm(self.preprocessed_batch)
+        restored_output = restored_model(self.preprocessed_batch)
+
+        self.assertAllClose(model_output, restored_output)
diff --git a/keras_nlp/models/bert/bert_tokenizer.py b/keras_nlp/models/bert/bert_tokenizer.py
index 629e2a3c44..5bee999638 100644
--- a/keras_nlp/models/bert/bert_tokenizer.py
+++ b/keras_nlp/models/bert/bert_tokenizer.py
@@ -55,23 +55,23 @@ class BertTokenizer(WordPieceTokenizer):
     Examples:
 
     Batched input.
-    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"]
-    >>> vocab += ["The", "qu", "##ick", "brown", "fox", "."]
+    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    >>> vocab += ["The", "quick", "brown", "fox.", "jumped", "over"]
     >>> inputs = ["The quick brown fox.", "The fox."]
     >>> tokenizer = keras_nlp.models.BertTokenizer(vocabulary=vocab)
     >>> tokenizer(inputs)
-    <tf.RaggedTensor [[4, 5, 6, 7, 8, 9], [4, 8, 9]]>
+    <tf.RaggedTensor [[5, 6, 7, 0, 0], [5, 0, 0]]>
 
     Unbatched input.
-    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"]
+    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
     >>> vocab += ["The", "qu", "##ick", "brown", "fox", "."]
     >>> inputs = "The fox."
     >>> tokenizer = keras_nlp.models.BertTokenizer(vocabulary=vocab)
     >>> tokenizer(inputs)
-    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 8, 9], dtype=int32)>
+    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 5,  9, 10], dtype=int32)>
 
     Detokenization.
-    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]"]
+    >>> vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
     >>> vocab += ["The", "qu", "##ick", "brown", "fox", "."]
     >>> inputs = "The quick brown fox."
     >>> tokenizer = keras_nlp.models.BertTokenizer(vocabulary=vocab)
@@ -95,6 +95,7 @@ def __init__(
         cls_token = "[CLS]"
         sep_token = "[SEP]"
         pad_token = "[PAD]"
+        mask_token = "[MASK]"
-        for token in [cls_token, pad_token, sep_token]:
+        for token in [cls_token, pad_token, sep_token, mask_token]:
             if token not in self.get_vocabulary():
                 raise ValueError(
@@ -106,6 +107,7 @@ def __init__(
         self.cls_token_id = self.token_to_id(cls_token)
         self.sep_token_id = self.token_to_id(sep_token)
         self.pad_token_id = self.token_to_id(pad_token)
+        self.mask_token_id = self.token_to_id(mask_token)
 
     @classproperty
     def presets(cls):