From 51fb5fb69176dc59f21f463b1d122f8adda71100 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Sun, 5 Feb 2023 22:32:34 +0530 Subject: [PATCH 01/16] albert lm init commit --- keras_nlp/models/__init__.py | 4 + keras_nlp/models/albert/albert_masked_lm.py | 70 ++++++++ .../albert/albert_masked_lm_preprocessor.py | 89 ++++++++++ .../albert_masked_lm_preprocessor_test.py | 161 ++++++++++++++++++ .../models/albert/albert_masked_lm_test.py | 141 +++++++++++++++ keras_nlp/models/albert/albert_tokenizer.py | 2 + 6 files changed, 467 insertions(+) create mode 100644 keras_nlp/models/albert/albert_masked_lm.py create mode 100644 keras_nlp/models/albert/albert_masked_lm_preprocessor.py create mode 100644 keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py create mode 100644 keras_nlp/models/albert/albert_masked_lm_test.py diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index db70d43bce..e2cb3f323e 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. from keras_nlp.models.albert.albert_backbone import AlbertBackbone +from keras_nlp.models.albert.albert_masked_lm import AlbertMaskedLM +from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer from keras_nlp.models.bart.bart_backbone import BartBackbone diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py new file mode 100644 index 0000000000..7b46c0facb --- /dev/null +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -0,0 +1,70 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Albert masked lm model.""" + +import copy + +from tensorflow import keras + +from keras_nlp.layers.masked_lm_head import MaskedLMHead +from keras_nlp.models.albert.albert_backbone import AlbertBackbone +from keras_nlp.models.albert.albert_backbone import albert_kernel_initializer +from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) +from keras_nlp.models.albert.albert_presets import backbone_presets +from keras_nlp.models.task import Task +from keras_nlp.utils.python_utils import classproperty + + +@keras.utils.register_keras_serializable(package="keras_nlp") +class AlbertMaskedLM(Task): + def __init__(self, backbone, preprocessor=None, **kwargs): + inputs = { + **backbone.input, + "mask_positions": keras.Input( + shape=(None,), dtype="int32", name="mask_positions" + ), + } + + backbone_outputs = backbone(backbone.input) + outputs = MaskedLMHead( + vocabulary_size=backbone.vocabulary_size, + embedding_weights=backbone.token_embedding.embeddings, + intermediate_activation="gelu", + kernel_initializer=albert_kernel_initializer(), + name="mlm_head", + )(backbone_outputs, inputs["mask_positions"]) + + super().__init__( + inputs=inputs, + outputs=outputs, + include_preprocessing=preprocessor is not None, + **kwargs + ) + + self.backbone = backbone + self.preprocessor = preprocessor + + @classproperty + def backbone_cls(cls): + return AlbertBackbone + + @classproperty + def preprocessor_cls(cls): + return AlbertMaskedLMPreprocessor + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py new file mode 100644 index 0000000000..8e0a64469f --- /dev/null +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -0,0 +1,89 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Albert masked language model preprocessor layer""" + +from absl import logging +from tensorflow import keras + +from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator +from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras.utils.register_keras_serializable(package="keras_nlp") +class AlbertMaskedLMPreprocessor(AlbertPreprocessor): + def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + mask_selection_rate=0.15, + mask_selection_length=96, + mask_token_rate=0.8, + random_token_rate=0.1, + **kwargs, + ): + super().__init__( + tokenizer, + sequence_length=sequence_length, + truncate=truncate, + **kwargs, + ) + + self.masker = MaskedLMMaskGenerator( + mask_selection_rate=mask_selection_rate, + mask_selection_length=mask_selection_length, + mask_token_rate=mask_token_rate, + random_token_rate=random_token_rate, + vocabulary_size=tokenizer.vocabulary_size(), + mask_token_id=tokenizer.mask_token_id, + unselectable_token_ids=[ + tokenizer.cls_token_id, + tokenizer.sep_token_id, + tokenizer.pad_token_id, + ], + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "mask_selection_rate": self.masker.mask_selection_rate, + "mask_selection_length": self.masker.mask_selection_length, + "mask_token_rate": self.masker.mask_token_rate, + "random_token_rate": self.masker.random_token_rate, + } + ) + return config + + def call(self, x, y=None, sample_weight=None): + if y is not None or sample_weight is not None: + logging.warning( + f"{self.__class__.__name__} generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + + x = super().call(x) + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + masker_outputs = self.masker(token_ids) + x = { + "token_ids": masker_outputs["token_ids"], + "padding_mask": padding_mask, + "mask_positions": masker_outputs["mask_positions"], + } + y = masker_outputs["mask_ids"] + sample_weight = masker_outputs["mask_weights"] + return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py new file mode 100644 index 0000000000..f70b8d58b8 --- /dev/null +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -0,0 +1,161 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for Albert masked language model preprocessor layer.""" + +import io +import os + +import sentencepiece +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) +from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer + + +class AlbertaMaskedLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + + bytes_io = io.BytesIO() + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=10, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + ) + + proto = bytes_io.getvalue() + + tokenizer = AlbertTokenizer(proto=proto) + + self.preprocessor = AlbertMaskedLMPreprocessor( + tokenizer=tokenizer, + # Simplify out testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=12, + ) + + def test_preprocess_strings(self): + input_data = " airplane at airport" + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [0, 12, 12, 12, 12, 12, 2, 1, 1, 1, 1, 1] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4, 5]) + self.assertAllEqual(y, [3, 4, 5, 3, 6]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0]) + + def test_preprocess_list_of_strings(self): + input_data = [" airplane at airport"] * 4 + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [[0, 12, 12, 12, 12, 12, 2, 1, 1, 1, 1, 1]] * 4 + ) + self.assertAllEqual( + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]] * 4 + ) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) + self.assertAllEqual(y, [[3, 4, 5, 3, 6]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + + def test_preprocess_dataset(self): + sentences = tf.constant([" airplane at airport"] * 4) + ds = tf.data.Dataset.from_tensor_slices(sentences) + ds = ds.map(self.preprocessor) + x, y, sw = ds.batch(4).take(1).get_single_element() + self.assertAllEqual( + x["token_ids"], [[0, 12, 12, 12, 12, 12, 2, 1, 1, 1, 1, 1]] * 4 + ) + self.assertAllEqual( + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]] * 4 + ) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) + self.assertAllEqual(y, [[3, 4, 5, 3, 6]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + + def test_mask_multiple_sentences(self): + sentence_one = tf.constant(" airplane") + sentence_two = tf.constant(" kohli") + + x, y, sw = self.preprocessor((sentence_one, sentence_two)) + self.assertAllEqual( + x["token_ids"], [0, 12, 12, 2, 2, 12, 12, 2, 1, 1, 1, 1] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [1, 2, 5, 6, 0]) + self.assertAllEqual(y, [3, 4, 7, 8, 0]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 0.0]) + + def test_no_masking_zero_rate(self): + no_mask_preprocessor = AlbertMaskedLMPreprocessor( + self.preprocessor.tokenizer, + mask_selection_rate=0.0, + mask_selection_length=5, + sequence_length=12, + ) + input_data = " airplane at airport" + + x, y, sw = 
no_mask_preprocessor(input_data) + self.assertAllEqual( + x["token_ids"], [0, 3, 4, 5, 3, 6, 2, 1, 1, 1, 1, 1] + ) + self.assertAllEqual( + x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + ) + self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0]) + self.assertAllEqual(y, [0, 0, 0, 0, 0]) + self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0, 0.0]) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + input_data = tf.constant([" airplane at airport"]) + + inputs = keras.Input(dtype="string", shape=()) + outputs = self.preprocessor(inputs) + model = keras.Model(inputs, outputs) + + path = os.path.join(self.get_temp_dir(), filename) + model.save(path, save_format=save_format) + + restored_model = keras.models.load_model(path) + outputs = model(input_data)[0]["token_ids"] + restored_outputs = restored_model(input_data)[0]["token_ids"] + self.assertAllEqual(outputs, restored_outputs) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py new file mode 100644 index 0000000000..133adf4aa7 --- /dev/null +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -0,0 +1,141 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for Albert masked language model.""" + +import os + +import tensorflow as tf +from absl.testing import parameterized +from tensorflow import keras + +from keras_nlp.models.albert.albert_backbone import AlbertBackbone +from keras_nlp.models.albert.albert_masked_lm import AlbertMaskedLM +from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( + AlbertMaskedLMPreprocessor, +) +from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer + + +class AlbertMaskedLMTest(tf.test.TestCase, parameterized.TestCase): + def setUp(self): + self.backbone = AlbertBackbone( + vocabulary_size=1000, + num_layers=2, + num_heads=2, + hidden_dim=64, + intermediate_dim=128, + max_sequence_length=128, + ) + self.vocab = { + "": 0, + "": 1, + "": 2, + "Ġair": 3, + "plane": 4, + "Ġat": 5, + "port": 6, + "Ġkoh": 7, + "li": 8, + "Ġis": 9, + "Ġthe": 10, + "Ġbest": 11, + "": 12, + } + + merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] + merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"] + merges += ["pla ne"] + self.merges = merges + self.preprocessor = AlbertMaskedLMPreprocessor( + AlbertTokenizer(vocabulary=self.vocab, merges=self.merges), + sequence_length=8, + mask_selection_length=2, + ) + self.masked_lm = AlbertMaskedLM( + self.backbone, + preprocessor=self.preprocessor, + ) + self.masked_lm_no_preprocessing = AlbertMaskedLM( + self.backbone, + preprocessor=None, + ) + + self.raw_batch = tf.constant( + [ + " airplane at airport", + " the airplane is the best", + " the best airport", + " kohli is the best", + ] + ) + self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] + self.raw_dataset = tf.data.Dataset.from_tensor_slices( + self.raw_batch + ).batch(2) + self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) + + def test_valid_call_masked_lm(self): + self.masked_lm(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_predict(self, jit_compile): + self.masked_lm.compile(jit_compile=jit_compile) + self.masked_lm.predict(self.raw_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_predict_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) + self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_fit(self, jit_compile): + self.masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm.fit(self.raw_dataset) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_fit_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + save_path = os.path.join(self.get_temp_dir(), filename) + self.masked_lm.save(save_path, save_format=save_format) + restored_model = keras.models.load_model(save_path) 
+
+        # Check we got the real object back.
+        self.assertIsInstance(restored_model, AlbertMaskedLM)
+
+        model_output = self.masked_lm(self.preprocessed_batch)
+        restored_output = restored_model(self.preprocessed_batch)
+
+        self.assertAllClose(model_output, restored_output)
diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py
index 35dc7ec8d5..cf87a818e7 100644
--- a/keras_nlp/models/albert/albert_tokenizer.py
+++ b/keras_nlp/models/albert/albert_tokenizer.py
@@ -72,6 +72,7 @@ def __init__(self, proto, **kwargs):
         cls_token = "[CLS]"
         sep_token = "[SEP]"
         pad_token = "<pad>"
+        mask_token = "<mask>"
         for token in [cls_token, sep_token, pad_token]:
             if token not in self.get_vocabulary():
                 raise ValueError(
@@ -83,6 +84,7 @@ def __init__(self, proto, **kwargs):
         self.cls_token_id = self.token_to_id(cls_token)
         self.sep_token_id = self.token_to_id(sep_token)
         self.pad_token_id = self.token_to_id(pad_token)
+        self.mask_token_id = self.token_to_id(mask_token)
 
     @classproperty
     def presets(cls):

From 9eb6ff4250d94814cde3f5a5cdca0f378c3a88ae Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Wed, 8 Feb 2023 22:02:22 +0530
Subject: [PATCH 02/16] fixing preprocessor tests

---
 keras_nlp/models/albert/albert_masked_lm.py   |   4 +-
 .../albert_masked_lm_preprocessor_test.py     |  44 ++---
 .../models/albert/albert_masked_lm_test.py    | 158 ++++++++++--------
 3 files changed, 109 insertions(+), 97 deletions(-)

diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py
index 7b46c0facb..db665a5b0b 100644
--- a/keras_nlp/models/albert/albert_masked_lm.py
+++ b/keras_nlp/models/albert/albert_masked_lm.py
@@ -42,7 +42,9 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
         outputs = MaskedLMHead(
             vocabulary_size=backbone.vocabulary_size,
             embedding_weights=backbone.token_embedding.embeddings,
-            intermediate_activation="gelu",
+            intermediate_activation=lambda x: keras.activations.gelu(
+                x, approximate=True
+            ),
             kernel_initializer=albert_kernel_initializer(),
             name="mlm_head",
         )(backbone_outputs, inputs["mask_positions"])
diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py
index f70b8d58b8..72a8c2d481 100644
--- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py
+++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py
@@ -68,28 +68,28 @@ def test_preprocess_strings(self):
 
         x, y, sw = self.preprocessor(input_data)
         self.assertAllEqual(
-            x["token_ids"], [0, 12, 12, 12, 12, 12, 2, 1, 1, 1, 1, 1]
+            x["token_ids"], [2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]
         )
         self.assertAllEqual(
-            x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
+            x["padding_mask"], [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
         )
-        self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4, 5])
-        self.assertAllEqual(y, [3, 4, 5, 3, 6])
-        self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0])
+        self.assertAllEqual(x["mask_positions"], [1, 0, 0, 0, 0])
+        self.assertAllEqual(y, [1, 0, 0, 0, 0])
+        self.assertAllEqual(sw, [1.0, 0.0, 0.0, 0.0, 0.0])
 
     def test_preprocess_list_of_strings(self):
         input_data = [" airplane at airport"] * 4
 
         x, y, sw = self.preprocessor(input_data)
         self.assertAllEqual(
-            x["token_ids"], [[0, 12, 12, 12, 12, 12, 2, 1, 1, 1, 1, 1]] * 4
+            x["token_ids"], [[2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
        )
         self.assertAllEqual(
-            x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]] * 4
+            x["padding_mask"], [[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4
         )
-
self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) - self.assertAllEqual(y, [[3, 4, 5, 3, 6]] * 4) - self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1.0, 0.0, 0.0, 0.0, 0.0]] * 4) def test_preprocess_dataset(self): sentences = tf.constant([" airplane at airport"] * 4) @@ -97,14 +97,14 @@ def test_preprocess_dataset(self): ds = ds.map(self.preprocessor) x, y, sw = ds.batch(4).take(1).get_single_element() self.assertAllEqual( - x["token_ids"], [[0, 12, 12, 12, 12, 12, 2, 1, 1, 1, 1, 1]] * 4 + x["token_ids"], [[2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual( - x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]] * 4 + x["padding_mask"], [[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 ) - self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) - self.assertAllEqual(y, [[3, 4, 5, 3, 6]] * 4) - self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1.0, 0.0, 0.0, 0.0, 0.0]] * 4) def test_mask_multiple_sentences(self): sentence_one = tf.constant(" airplane") @@ -112,14 +112,14 @@ def test_mask_multiple_sentences(self): x, y, sw = self.preprocessor((sentence_one, sentence_two)) self.assertAllEqual( - x["token_ids"], [0, 12, 12, 2, 2, 12, 12, 2, 1, 1, 1, 1] + x["token_ids"], [2, 1, 3, 1, 3, 0, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( - x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0] + x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] ) - self.assertAllEqual(x["mask_positions"], [1, 2, 5, 6, 0]) - self.assertAllEqual(y, [3, 4, 7, 8, 0]) - self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 0.0]) + self.assertAllEqual(x["mask_positions"], [1, 3, 0, 0, 0]) + self.assertAllEqual(y, [1, 1, 0, 0, 0]) + self.assertAllEqual(sw, [1.0, 1.0, 0.0, 0.0, 0.0]) def test_no_masking_zero_rate(self): no_mask_preprocessor = AlbertMaskedLMPreprocessor( @@ -132,10 +132,10 @@ def test_no_masking_zero_rate(self): x, y, sw = no_mask_preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [0, 3, 4, 5, 3, 6, 2, 1, 1, 1, 1, 1] + x["token_ids"], [2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( - x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + x["padding_mask"], [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0]) self.assertAllEqual(y, [0, 0, 0, 0, 0]) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 133adf4aa7..93c9dcacdf 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -14,7 +14,9 @@ """Tests for Albert masked language model.""" import os +import io +import sentencepiece import tensorflow as tf from absl.testing import parameterized from tensorflow import keras @@ -33,35 +35,43 @@ def setUp(self): vocabulary_size=1000, num_layers=2, num_heads=2, + embedding_dim=8, hidden_dim=64, intermediate_dim=128, max_sequence_length=128, ) - self.vocab = { - "": 0, - "": 1, - "": 2, - "Ġair": 3, - "plane": 4, - "Ġat": 5, - "port": 6, - "Ġkoh": 7, - "li": 8, - "Ġis": 9, - "Ġthe": 10, - "Ġbest": 11, - "": 12, - } - - merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"] - merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"] - merges += ["Ġt h", "Ġai r", "pl a", "Ġk 
oh", "Ġth e", "Ġbe st", "po rt"] - merges += ["pla ne"] - self.merges = merges + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + + bytes_io = io.BytesIO() + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=10, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + ) + + proto = bytes_io.getvalue() + + tokenizer = AlbertTokenizer(proto=proto) + self.preprocessor = AlbertMaskedLMPreprocessor( - AlbertTokenizer(vocabulary=self.vocab, merges=self.merges), - sequence_length=8, - mask_selection_length=2, + tokenizer=tokenizer, + # Simplify out testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=12, ) self.masked_lm = AlbertMaskedLM( self.backbone, @@ -89,53 +99,53 @@ def setUp(self): def test_valid_call_masked_lm(self): self.masked_lm(self.preprocessed_batch) - @parameterized.named_parameters( - ("jit_compile_false", False), ("jit_compile_true", True) - ) - def test_albert_masked_lm_predict(self, jit_compile): - self.masked_lm.compile(jit_compile=jit_compile) - self.masked_lm.predict(self.raw_batch) - - @parameterized.named_parameters( - ("jit_compile_false", False), ("jit_compile_true", True) - ) - def test_albert_masked_lm_predict_no_preprocessing(self, jit_compile): - self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) - self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) - - @parameterized.named_parameters( - ("jit_compile_false", False), ("jit_compile_true", True) - ) - def test_albert_masked_lm_fit(self, jit_compile): - self.masked_lm.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - jit_compile=jit_compile, - ) - self.masked_lm.fit(self.raw_dataset) - - @parameterized.named_parameters( - ("jit_compile_false", False), ("jit_compile_true", True) - ) - def test_albert_masked_lm_fit_no_preprocessing(self, jit_compile): - self.masked_lm_no_preprocessing.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - jit_compile=jit_compile, - ) - self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset) - - @parameterized.named_parameters( - ("tf_format", "tf", "model"), - ("keras_format", "keras_v3", "model.keras"), - ) - def test_saved_model(self, save_format, filename): - save_path = os.path.join(self.get_temp_dir(), filename) - self.masked_lm.save(save_path, save_format=save_format) - restored_model = keras.models.load_model(save_path) - - # Check we got the real object back. 
- self.assertIsInstance(restored_model, AlbertMaskedLM) - - model_output = self.masked_lm(self.preprocessed_batch) - restored_output = restored_model(self.preprocessed_batch) - - self.assertAllClose(model_output, restored_output) + # @parameterized.named_parameters( + # ("jit_compile_false", False), ("jit_compile_true", True) + # ) + # def test_albert_masked_lm_predict(self, jit_compile): + # self.masked_lm.compile(jit_compile=jit_compile) + # self.masked_lm.predict(self.raw_batch) + + # @parameterized.named_parameters( + # ("jit_compile_false", False), ("jit_compile_true", True) + # ) + # def test_albert_masked_lm_predict_no_preprocessing(self, jit_compile): + # self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) + # self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + + # @parameterized.named_parameters( + # ("jit_compile_false", False), ("jit_compile_true", True) + # ) + # def test_albert_masked_lm_fit(self, jit_compile): + # self.masked_lm.compile( + # loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + # jit_compile=jit_compile, + # ) + # self.masked_lm.fit(self.raw_dataset) + + # @parameterized.named_parameters( + # ("jit_compile_false", False), ("jit_compile_true", True) + # ) + # def test_albert_masked_lm_fit_no_preprocessing(self, jit_compile): + # self.masked_lm_no_preprocessing.compile( + # loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + # jit_compile=jit_compile, + # ) + # self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset) + + # @parameterized.named_parameters( + # ("tf_format", "tf", "model"), + # ("keras_format", "keras_v3", "model.keras"), + # ) + # def test_saved_model(self, save_format, filename): + # save_path = os.path.join(self.get_temp_dir(), filename) + # self.masked_lm.save(save_path, save_format=save_format) + # restored_model = keras.models.load_model(save_path) + + # # Check we got the real object back. + # self.assertIsInstance(restored_model, AlbertMaskedLM) + + # model_output = self.masked_lm(self.preprocessed_batch) + # restored_output = restored_model(self.preprocessed_batch) + + # self.assertAllClose(model_output, restored_output) From 7ee4bd65a21830cbd93a630cb0264d9bbaf95729 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 8 Feb 2023 23:17:26 +0530 Subject: [PATCH 03/16] fixing the main model test + formatting + docstrings --- keras_nlp/models/albert/albert_masked_lm.py | 85 ++++++++++++- .../albert/albert_masked_lm_preprocessor.py | 114 +++++++++++++++++- .../models/albert/albert_masked_lm_test.py | 104 ++++++++-------- 3 files changed, 247 insertions(+), 56 deletions(-) diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index db665a5b0b..f4bbc9498d 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Albert masked lm model.""" import copy @@ -30,6 +31,84 @@ @keras.utils.register_keras_serializable(package="keras_nlp") class AlbertMaskedLM(Task): + """An end-to-end Albert model for the masked language modeling task. + + This model will train Albert on a masked language modeling task. + The model will predict labels for a number of masked tokens in the + input data. For usage of this model with pre-trained weights, see the + `from_preset()` method. 
+ + This model can optionally be configured with a `preprocessor` layer, in + which case inputs can be raw string features during `fit()`, `predict()`, + and `evaluate()`. Inputs will be tokenized and dynamically masked during + training and evaluation. This is done by default when creating the model + with `from_preset()`. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. The underlying model is provided by a + third party and subject to a separate license, available + [here](https://github.com/facebookresearch/fairseq). + + Args: + backbone: A `keras_nlp.models.AlbertBackbone` instance. + preprocessor: A `keras_nlp.models.AlbertMaskedLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + + Example usage: + + Raw string inputs and pretrained backbone. + ```python + # Create a dataset with raw string features. Labels are inferred. + features = ["The quick brown fox jumped.", "I forgot my homework."] + + # Create a AlbertaskedLM with a pretrained backbone and further train + # on an MLM task. + masked_lm = keras_nlp.models.AlbertMaskedLM.from_preset( + "albert_base_en_uncased", + ) + masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + masked_lm.fit(x=features, batch_size=2) + ``` + + Preprocessed inputs and custom backbone. + ```python + # Create a preprocessed dataset where 0 is the mask token. + preprocessed_features = { + "token_ids": tf.constant( + [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8) + ), + "padding_mask": tf.constant( + [[1, 1, 1, 1, 1, 1, 1, 1]] * 2, shape=(2, 8) + ), + "mask_positions": tf.constant([[2, 4]] * 2, shape=(2, 2)) + } + # Labels are the original masked values. + labels = [[3, 5]] * 2 + + # Randomly initialize a Albert encoder + backbone = keras_nlp.models.AlbertBackbone( + vocabulary_size=50265, + num_layers=12, + num_heads=12, + hidden_dim=768, + intermediate_dim=3072, + max_sequence_length=12 + ) + # Create a Albert masked_lm and fit the data. + masked_lm = keras_nlp.models.AlbertMaskedLM( + backbone, + preprocessor=None, + ) + masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + ) + masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2) + ``` + """ + def __init__(self, backbone, preprocessor=None, **kwargs): inputs = { **backbone.input, @@ -43,11 +122,11 @@ def __init__(self, backbone, preprocessor=None, **kwargs): vocabulary_size=backbone.vocabulary_size, embedding_weights=backbone.token_embedding.embeddings, intermediate_activation=lambda x: keras.activations.gelu( - x, approximate=True - ), + x, approximate=True + ), kernel_initializer=albert_kernel_initializer(), name="mlm_head", - )(backbone_outputs, inputs["mask_positions"]) + )(backbone_outputs["sequence_output"], inputs["mask_positions"]) super().__init__( inputs=inputs, diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index 8e0a64469f..e85cff69e7 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+
 """Albert masked language model preprocessor layer"""
 
 from absl import logging
 from tensorflow import keras
 
 from keras_nlp.layers.masked_lm_mask_generator import MaskedLMMaskGenerator
 from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
 
 
 @keras.utils.register_keras_serializable(package="keras_nlp")
 class AlbertMaskedLMPreprocessor(AlbertPreprocessor):
+    """ALBERT preprocessing for the masked language modeling task.
+
+    This preprocessing layer will prepare inputs for a masked language modeling
+    task. It is primarily intended for use with the
+    `keras_nlp.models.AlbertMaskedLM` task model. Preprocessing will occur in
+    multiple steps.
+
+    - Tokenize any number of input segments using the `tokenizer`.
+    - Pack the inputs together with the appropriate `"[CLS]"`, `"[SEP]"` and
+      `"<pad>"` tokens, i.e., adding a single `"[CLS]"` at the start of the
+      entire sequence, `"[SEP]"` between each segment,
+      and a `"[SEP]"` at the end of the entire sequence.
+    - Randomly select non-special tokens to mask, controlled by
+      `mask_selection_rate`.
+    - Construct a `(x, y, sample_weight)` tuple suitable for training with a
+      `keras_nlp.models.AlbertMaskedLM` task model.
+
+    Args:
+        tokenizer: A `keras_nlp.models.AlbertTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        mask_selection_rate: The probability an input token will be dynamically
+            masked.
+        mask_selection_length: The maximum number of masked tokens supported
+            by the layer.
+        mask_token_rate: float, defaults to 0.8. `mask_token_rate` must be
+            between 0 and 1 which indicates how often the mask_token is
+            substituted for tokens selected for masking.
+        random_token_rate: float, defaults to 0.1. `random_token_rate` must be
+            between 0 and 1 which indicates how often a random token is
+            substituted for tokens selected for masking. Default is 0.1.
+            Note: mask_token_rate + random_token_rate <= 1, and for
+            (1 - mask_token_rate - random_token_rate), the token will not be
+            changed.
+        truncate: string. The algorithm to truncate a list of batched segments
+            to fit within `sequence_length`. The value can be either
+            `round_robin` or `waterfall`:
+            - `"round_robin"`: Available space is assigned one token at a
+                time in a round-robin fashion to the inputs that still need
+                some, until the limit is reached.
+            - `"waterfall"`: The allocation of the budget is done using a
+                "waterfall" algorithm that allocates quota in a
+                left-to-right manner and fills up the buckets until we run
+                out of budget. It supports an arbitrary number of segments.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset(
+        "albert_base_en_uncased"
+    )
+
+    # Tokenize and mask a single sentence.
+    sentence = tf.constant("The quick brown fox jumped.")
+    preprocessor(sentence)
+
+    # Tokenize and mask a batch of sentences.
+    sentences = tf.constant(
+        ["The quick brown fox jumped.", "Call me Ishmael."]
+    )
+    preprocessor(sentences)
+
+    # Tokenize and mask a dataset of sentences.
+    features = tf.constant(
+        ["The quick brown fox jumped.", "Call me Ishmael."]
+    )
+    ds = tf.data.Dataset.from_tensor_slices((features))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Alternatively, you can create a preprocessor from your own vocabulary.
+ tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + + bytes_io = io.BytesIO() + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=10, + model_type="WORD", + pad_id=0, + unk_id=1, + bos_id=2, + eos_id=3, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + ) + + proto = bytes_io.getvalue() + + tokenizer = AlbertTokenizer(proto=proto) + + preprocessor = AlbertMaskedLMPreprocessor( + tokenizer=tokenizer, + # Simplify out testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=12, + ) + + ``` + """ + def __init__( self, tokenizer, @@ -77,10 +184,15 @@ def call(self, x, y=None, sample_weight=None): ) x = super().call(x) - token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids, segment_ids, padding_mask = ( + x["token_ids"], + x["segment_ids"], + x["padding_mask"], + ) masker_outputs = self.masker(token_ids) x = { "token_ids": masker_outputs["token_ids"], + "segment_ids": segment_ids, "padding_mask": padding_mask, "mask_positions": masker_outputs["mask_positions"], } diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 93c9dcacdf..9a93d0dc98 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -13,8 +13,8 @@ # limitations under the License. """Tests for Albert masked language model.""" -import os import io +import os import sentencepiece import tensorflow as tf @@ -35,7 +35,7 @@ def setUp(self): vocabulary_size=1000, num_layers=2, num_heads=2, - embedding_dim=8, + embedding_dim=64, hidden_dim=64, intermediate_dim=128, max_sequence_length=128, @@ -99,53 +99,53 @@ def setUp(self): def test_valid_call_masked_lm(self): self.masked_lm(self.preprocessed_batch) - # @parameterized.named_parameters( - # ("jit_compile_false", False), ("jit_compile_true", True) - # ) - # def test_albert_masked_lm_predict(self, jit_compile): - # self.masked_lm.compile(jit_compile=jit_compile) - # self.masked_lm.predict(self.raw_batch) - - # @parameterized.named_parameters( - # ("jit_compile_false", False), ("jit_compile_true", True) - # ) - # def test_albert_masked_lm_predict_no_preprocessing(self, jit_compile): - # self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) - # self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) - - # @parameterized.named_parameters( - # ("jit_compile_false", False), ("jit_compile_true", True) - # ) - # def test_albert_masked_lm_fit(self, jit_compile): - # self.masked_lm.compile( - # loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - # jit_compile=jit_compile, - # ) - # self.masked_lm.fit(self.raw_dataset) - - # @parameterized.named_parameters( - # ("jit_compile_false", False), ("jit_compile_true", True) - # ) - # def test_albert_masked_lm_fit_no_preprocessing(self, jit_compile): - # self.masked_lm_no_preprocessing.compile( - # loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - # jit_compile=jit_compile, - # ) - # self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset) - - # @parameterized.named_parameters( - # ("tf_format", "tf", "model"), - # ("keras_format", "keras_v3", "model.keras"), - # ) - # def test_saved_model(self, save_format, filename): - # save_path = os.path.join(self.get_temp_dir(), filename) - # 
self.masked_lm.save(save_path, save_format=save_format) - # restored_model = keras.models.load_model(save_path) - - # # Check we got the real object back. - # self.assertIsInstance(restored_model, AlbertMaskedLM) - - # model_output = self.masked_lm(self.preprocessed_batch) - # restored_output = restored_model(self.preprocessed_batch) - - # self.assertAllClose(model_output, restored_output) + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_predict(self, jit_compile): + self.masked_lm.compile(jit_compile=jit_compile) + self.masked_lm.predict(self.raw_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_predict_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) + self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_fit(self, jit_compile): + self.masked_lm.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm.fit(self.raw_dataset) + + @parameterized.named_parameters( + ("jit_compile_false", False), ("jit_compile_true", True) + ) + def test_albert_masked_lm_fit_no_preprocessing(self, jit_compile): + self.masked_lm_no_preprocessing.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=jit_compile, + ) + self.masked_lm_no_preprocessing.fit(self.preprocessed_dataset) + + @parameterized.named_parameters( + ("tf_format", "tf", "model"), + ("keras_format", "keras_v3", "model.keras"), + ) + def test_saved_model(self, save_format, filename): + save_path = os.path.join(self.get_temp_dir(), filename) + self.masked_lm.save(save_path, save_format=save_format) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. 
+ self.assertIsInstance(restored_model, AlbertMaskedLM) + + model_output = self.masked_lm(self.preprocessed_batch) + restored_output = restored_model(self.preprocessed_batch) + + self.assertAllClose(model_output, restored_output) From 6138b0453b763771b15415cb6fcf78ce3d02b442 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Fri, 10 Feb 2023 00:00:31 +0530 Subject: [PATCH 04/16] fixing bug in masked lm head --- keras_nlp/layers/masked_lm_head.py | 6 ++-- keras_nlp/models/albert/albert_masked_lm.py | 33 ++++++++++--------- .../albert/albert_masked_lm_preprocessor.py | 6 ++-- .../models/albert/albert_masked_lm_test.py | 2 +- 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/keras_nlp/layers/masked_lm_head.py b/keras_nlp/layers/masked_lm_head.py index 2a31aa7c91..5ecc592ea0 100644 --- a/keras_nlp/layers/masked_lm_head.py +++ b/keras_nlp/layers/masked_lm_head.py @@ -140,10 +140,10 @@ def __init__( self.vocabulary_size = shape[0] def build(self, input_shapes): - feature_size = input_shapes[-1] + embedding_dim = self.embedding_weights.shape[-1] self._dense = keras.layers.Dense( - feature_size, + embedding_dim, activation=self.intermediate_activation, kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer, @@ -154,7 +154,7 @@ def build(self, input_shapes): if self.embedding_weights is None: self._kernel = self.add_weight( name="output_kernel", - shape=[feature_size, self.vocabulary_size], + shape=[embedding_dim, self.vocabulary_size], initializer=self.kernel_initializer, dtype=self.dtype, ) diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index f4bbc9498d..b6c53daf45 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Albert masked lm model.""" +"""ALBERT masked LM model.""" import copy @@ -31,9 +31,9 @@ @keras.utils.register_keras_serializable(package="keras_nlp") class AlbertMaskedLM(Task): - """An end-to-end Albert model for the masked language modeling task. + """An end-to-end ALBERT model for the masked language modeling task. - This model will train Albert on a masked language modeling task. + This model will train ALBERT on a masked language modeling task. The model will predict labels for a number of masked tokens in the input data. For usage of this model with pre-trained weights, see the `from_preset()` method. @@ -45,9 +45,7 @@ class AlbertMaskedLM(Task): with `from_preset()`. Disclaimer: Pre-trained models are provided on an "as is" basis, without - warranties or conditions of any kind. The underlying model is provided by a - third party and subject to a separate license, available - [here](https://github.com/facebookresearch/fairseq). + warranties or conditions of any kind. Args: backbone: A `keras_nlp.models.AlbertBackbone` instance. @@ -62,7 +60,7 @@ class AlbertMaskedLM(Task): # Create a dataset with raw string features. Labels are inferred. features = ["The quick brown fox jumped.", "I forgot my homework."] - # Create a AlbertaskedLM with a pretrained backbone and further train + # Create a AlbertMaskedLM with a pretrained backbone and further train # on an MLM task. masked_lm = keras_nlp.models.AlbertMaskedLM.from_preset( "albert_base_en_uncased", @@ -77,6 +75,9 @@ class AlbertMaskedLM(Task): ```python # Create a preprocessed dataset where 0 is the mask token. 
preprocessed_features = { + "segment_ids": tf.constant( + [[1, 0, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8) + ), "token_ids": tf.constant( [[1, 2, 0, 4, 0, 6, 7, 8]] * 2, shape=(2, 8) ), @@ -88,22 +89,24 @@ class AlbertMaskedLM(Task): # Labels are the original masked values. labels = [[3, 5]] * 2 - # Randomly initialize a Albert encoder + # Randomly initialize a ALBERT encoder backbone = keras_nlp.models.AlbertBackbone( - vocabulary_size=50265, - num_layers=12, - num_heads=12, - hidden_dim=768, - intermediate_dim=3072, - max_sequence_length=12 + vocabulary_size=1000, + num_layers=2, + num_heads=2, + embedding_dim=64, + hidden_dim=64, + intermediate_dim=128, + max_sequence_length=128, ) - # Create a Albert masked_lm and fit the data. + # Create a ALBERT masked LM and fit the data. masked_lm = keras_nlp.models.AlbertMaskedLM( backbone, preprocessor=None, ) masked_lm.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + jit_compile=True ) masked_lm.fit(x=preprocessed_features, y=labels, batch_size=2) ``` diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index e85cff69e7..180c301c27 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Albert masked language model preprocessor layer""" +"""ALBERT masked language model preprocessor layer.""" from absl import logging from tensorflow import keras @@ -24,7 +24,7 @@ @keras.utils.register_keras_serializable(package="keras_nlp") class AlbertMaskedLMPreprocessor(AlbertPreprocessor): - """Albert preprocessing for the masked language modeling task. + """ALBERT preprocessing for the masked language modeling task. This preprocessing layer will prepare inputs for a masked language modeling task. It is primarily intended for use with the @@ -93,7 +93,7 @@ class AlbertMaskedLMPreprocessor(AlbertPreprocessor): ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) # Alternatively, you can create a preprocessor from your own vocabulary. 
- tf.data.Dataset.from_tensor_slices( + vocab_data = tf.data.Dataset.from_tensor_slices( ["the quick brown fox", "the earth is round"] ) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 9a93d0dc98..93ccc16c8d 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -35,7 +35,7 @@ def setUp(self): vocabulary_size=1000, num_layers=2, num_heads=2, - embedding_dim=64, + embedding_dim=128, hidden_dim=64, intermediate_dim=128, max_sequence_length=128, From 4dd31f727646a375507d20cf9c27360e4f90a2b7 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Fri, 10 Feb 2023 00:45:44 +0530 Subject: [PATCH 05/16] fixing none condition in masked_lm_head_test --- keras_nlp/layers/masked_lm_head.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/keras_nlp/layers/masked_lm_head.py b/keras_nlp/layers/masked_lm_head.py index 5ecc592ea0..9291d3d1b4 100644 --- a/keras_nlp/layers/masked_lm_head.py +++ b/keras_nlp/layers/masked_lm_head.py @@ -140,10 +140,13 @@ def __init__( self.vocabulary_size = shape[0] def build(self, input_shapes): - embedding_dim = self.embedding_weights.shape[-1] + feature_size = input_shapes[-1] + + if self.embedding_weights is not None: + feature_size = self.embedding_weights.shape[-1] self._dense = keras.layers.Dense( - embedding_dim, + feature_size, activation=self.intermediate_activation, kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer, @@ -154,7 +157,7 @@ def build(self, input_shapes): if self.embedding_weights is None: self._kernel = self.add_weight( name="output_kernel", - shape=[embedding_dim, self.vocabulary_size], + shape=[feature_size, self.vocabulary_size], initializer=self.kernel_initializer, dtype=self.dtype, ) From ae25305bd8ddefd4fe17092f4ee3006edfc352a5 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Fri, 10 Feb 2023 00:54:12 +0530 Subject: [PATCH 06/16] fixing formatting --- keras_nlp/layers/masked_lm_head.py | 2 +- keras_nlp/models/albert/albert_masked_lm_preprocessor.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/keras_nlp/layers/masked_lm_head.py b/keras_nlp/layers/masked_lm_head.py index 9291d3d1b4..c0d6e18dd9 100644 --- a/keras_nlp/layers/masked_lm_head.py +++ b/keras_nlp/layers/masked_lm_head.py @@ -141,7 +141,7 @@ def __init__( def build(self, input_shapes): feature_size = input_shapes[-1] - + if self.embedding_weights is not None: feature_size = self.embedding_weights.shape[-1] diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index 180c301c27..bf617db852 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -118,13 +118,7 @@ class AlbertMaskedLMPreprocessor(AlbertPreprocessor): tokenizer = AlbertTokenizer(proto=proto) preprocessor = AlbertMaskedLMPreprocessor( - tokenizer=tokenizer, - # Simplify out testing by masking every available token. 
- mask_selection_rate=1.0, - mask_token_rate=1.0, - random_token_rate=0.0, - mask_selection_length=5, - sequence_length=12, + tokenizer=tokenizer ) ``` From 27be51941bfa2cb3dc594cc0598e7a42fb7654ba Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Fri, 10 Feb 2023 01:02:20 +0530 Subject: [PATCH 07/16] fixing test_valid_call_with_embedding_weights --- keras_nlp/layers/masked_lm_head_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/layers/masked_lm_head_test.py b/keras_nlp/layers/masked_lm_head_test.py index 15734c88c8..591a24ed1e 100644 --- a/keras_nlp/layers/masked_lm_head_test.py +++ b/keras_nlp/layers/masked_lm_head_test.py @@ -46,11 +46,11 @@ def test_valid_call_with_embedding_weights(self): embedding_weights=embedding.embeddings, activation="softmax", ) - encoded_tokens = keras.Input(shape=(10, 16)) + encoded_tokens = keras.Input(shape=(32, 16)) positions = keras.Input(shape=(5,), dtype="int32") outputs = head(encoded_tokens, mask_positions=positions) model = keras.Model((encoded_tokens, positions), outputs) - token_data = tf.random.uniform(shape=(4, 10, 16)) + token_data = tf.random.uniform(shape=(4, 32, 16)) position_data = tf.random.uniform( shape=(4, 5), maxval=10, dtype="int32" ) From e7287b85e2a74b255ee1f03c1ba22e0d225364fc Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Fri, 10 Feb 2023 14:14:51 +0530 Subject: [PATCH 08/16] minor docstring changes --- keras_nlp/models/albert/albert_masked_lm.py | 4 ++-- keras_nlp/models/albert/albert_masked_lm_preprocessor.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index b6c53daf45..2fe248856c 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -97,8 +97,8 @@ class AlbertMaskedLM(Task): embedding_dim=64, hidden_dim=64, intermediate_dim=128, - max_sequence_length=128, - ) + max_sequence_length=128) + # Create a ALBERT masked LM and fit the data. 
masked_lm = keras_nlp.models.AlbertMaskedLM( backbone, diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index bf617db852..2e705b813d 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -97,7 +97,9 @@ class AlbertMaskedLMPreprocessor(AlbertPreprocessor): ["the quick brown fox", "the earth is round"] ) + # Creating sentencepiece tokenizer for ALBERT LM preprocessor bytes_io = io.BytesIO() + sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, From 5a036ef7dd498f17560d8f536a2dbafb61b1ebfe Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Fri, 10 Feb 2023 12:17:54 -0800 Subject: [PATCH 09/16] Minor fixes --- keras_nlp/layers/masked_lm_head.py | 4 ++-- keras_nlp/layers/masked_lm_head_test.py | 12 +++++++----- .../albert/albert_masked_lm_preprocessor_test.py | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/keras_nlp/layers/masked_lm_head.py b/keras_nlp/layers/masked_lm_head.py index c0d6e18dd9..2f749eb0c7 100644 --- a/keras_nlp/layers/masked_lm_head.py +++ b/keras_nlp/layers/masked_lm_head.py @@ -140,10 +140,10 @@ def __init__( self.vocabulary_size = shape[0] def build(self, input_shapes): - feature_size = input_shapes[-1] - if self.embedding_weights is not None: feature_size = self.embedding_weights.shape[-1] + else: + feature_size = input_shapes[-1] self._dense = keras.layers.Dense( feature_size, diff --git a/keras_nlp/layers/masked_lm_head_test.py b/keras_nlp/layers/masked_lm_head_test.py index 591a24ed1e..f66f0b1918 100644 --- a/keras_nlp/layers/masked_lm_head_test.py +++ b/keras_nlp/layers/masked_lm_head_test.py @@ -46,15 +46,17 @@ def test_valid_call_with_embedding_weights(self): embedding_weights=embedding.embeddings, activation="softmax", ) - encoded_tokens = keras.Input(shape=(32, 16)) + # Use a difference "hidden dim" for the model than "embedding dim", we + # need to support this in the layer. 
+ sequence = keras.Input(shape=(10, 32)) positions = keras.Input(shape=(5,), dtype="int32") - outputs = head(encoded_tokens, mask_positions=positions) - model = keras.Model((encoded_tokens, positions), outputs) - token_data = tf.random.uniform(shape=(4, 32, 16)) + outputs = head(sequence, mask_positions=positions) + model = keras.Model((sequence, positions), outputs) + sequence_data = tf.random.uniform(shape=(4, 10, 32)) position_data = tf.random.uniform( shape=(4, 5), maxval=10, dtype="int32" ) - model((token_data, position_data)) + model((sequence_data, position_data)) def test_get_config_and_from_config(self): head = masked_lm_head.MaskedLMHead( diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py index 72a8c2d481..dd77e9d525 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -27,7 +27,7 @@ from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer -class AlbertaMaskedLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase): +class AlbertMaskedLMPreprocessorTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): vocab_data = tf.data.Dataset.from_tensor_slices( ["the quick brown fox", "the earth is round"] From fb24d306b95aae5a4a3fb852194ebf4cb7e21731 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Mon, 13 Feb 2023 21:05:36 +0530 Subject: [PATCH 10/16] addressing some comments --- keras_nlp/layers/masked_lm_mask_generator.py | 6 +- .../models/albert/albert_classifier_test.py | 2 +- .../albert_masked_lm_preprocessor_test.py | 71 ++++++++++--------- .../models/albert/albert_masked_lm_test.py | 2 +- keras_nlp/models/albert/albert_tokenizer.py | 4 +- .../models/albert/albert_tokenizer_test.py | 11 +-- 6 files changed, 47 insertions(+), 49 deletions(-) diff --git a/keras_nlp/layers/masked_lm_mask_generator.py b/keras_nlp/layers/masked_lm_mask_generator.py index 069f35b8f6..ab03c530e9 100644 --- a/keras_nlp/layers/masked_lm_mask_generator.py +++ b/keras_nlp/layers/masked_lm_mask_generator.py @@ -147,11 +147,7 @@ def call(self, inputs): # convert dense to ragged. inputs = tf.RaggedTensor.from_tensor(inputs) - ( - token_ids, - mask_positions, - mask_ids, - ) = tf_text.mask_language_model( + (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model( inputs, item_selector=self._random_selector, mask_values_chooser=self._mask_values_chooser, diff --git a/keras_nlp/models/albert/albert_classifier_test.py b/keras_nlp/models/albert/albert_classifier_test.py index 6f586b9c95..24125b42b7 100644 --- a/keras_nlp/models/albert/albert_classifier_test.py +++ b/keras_nlp/models/albert/albert_classifier_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for BERT classification model.""" +"""Tests for ALBERT classification model.""" import io import os diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py index dd77e9d525..6efc084b0c 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Albert masked language model preprocessor layer.""" +"""Tests for ALBERT masked language model preprocessor layer.""" import io import os @@ -47,6 +47,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]", ) proto = bytes_io.getvalue() @@ -59,94 +60,94 @@ def setUp(self): mask_selection_rate=1.0, mask_token_rate=1.0, random_token_rate=0.0, - mask_selection_length=5, + mask_selection_length=4, sequence_length=12, ) def test_preprocess_strings(self): - input_data = " airplane at airport" + input_data = "the quick brown fox" x, y, sw = self.preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0] + x["token_ids"], [1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( - x["padding_mask"], [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] ) - self.assertAllEqual(x["mask_positions"], [1, 0, 0, 0, 0]) - self.assertAllEqual(y, [1, 0, 0, 0, 0]) - self.assertAllEqual(sw, [1.0, 0.0, 0.0, 0.0, 0.0]) + self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4]) + self.assertAllEqual(y, [5, 10, 6, 8]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) def test_preprocess_list_of_strings(self): - input_data = [" airplane at airport"] * 4 + input_data = ["the quick brown fox"] * 4 x, y, sw = self.preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [[2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 + x["token_ids"], [[1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual( - x["padding_mask"], [[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 ) - self.assertAllEqual(x["mask_positions"], [[1, 0, 0, 0, 0]] * 4) - self.assertAllEqual(y, [[1, 0, 0, 0, 0]] * 4) - self.assertAllEqual(sw, [[1.0, 0.0, 0.0, 0.0, 0.0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) + self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) def test_preprocess_dataset(self): - sentences = tf.constant([" airplane at airport"] * 4) + sentences = tf.constant(["the quick brown fox"] * 4) ds = tf.data.Dataset.from_tensor_slices(sentences) ds = ds.map(self.preprocessor) x, y, sw = ds.batch(4).take(1).get_single_element() self.assertAllEqual( - x["token_ids"], [[2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 + x["token_ids"], [[1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual( - x["padding_mask"], [[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 + x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 ) - self.assertAllEqual(x["mask_positions"], [[1, 0, 0, 0, 0]] * 4) - self.assertAllEqual(y, [[1, 0, 0, 0, 0]] * 4) - self.assertAllEqual(sw, [[1.0, 0.0, 0.0, 0.0, 0.0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) + self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) def test_mask_multiple_sentences(self): - sentence_one = tf.constant(" airplane") - sentence_two = tf.constant(" kohli") + sentence_one = tf.constant("the quick") + sentence_two = tf.constant("brown fox") x, y, sw = self.preprocessor((sentence_one, sentence_two)) self.assertAllEqual( - x["token_ids"], [2, 1, 3, 1, 3, 0, 0, 0, 0, 0, 0, 0] + x["token_ids"], [1, 4, 4, 2, 4, 4, 2, 0, 0, 0, 0, 0] ) self.assertAllEqual( - x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 
0, 0] ) - self.assertAllEqual(x["mask_positions"], [1, 3, 0, 0, 0]) - self.assertAllEqual(y, [1, 1, 0, 0, 0]) - self.assertAllEqual(sw, [1.0, 1.0, 0.0, 0.0, 0.0]) + self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5]) + self.assertAllEqual(y, [5, 10, 6, 8]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) def test_no_masking_zero_rate(self): no_mask_preprocessor = AlbertMaskedLMPreprocessor( self.preprocessor.tokenizer, mask_selection_rate=0.0, - mask_selection_length=5, + mask_selection_length=4, sequence_length=12, ) - input_data = " airplane at airport" + input_data = "the quick brown fox" x, y, sw = no_mask_preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0] + x["token_ids"], [1, 5, 10, 6, 8, 2, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( - x["padding_mask"], [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] ) - self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0]) - self.assertAllEqual(y, [0, 0, 0, 0, 0]) - self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0, 0.0]) + self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0]) + self.assertAllEqual(y, [0, 0, 0, 0]) + self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0]) @parameterized.named_parameters( ("tf_format", "tf", "model"), ("keras_format", "keras_v3", "model.keras"), ) def test_saved_model(self, save_format, filename): - input_data = tf.constant([" airplane at airport"]) + input_data = tf.constant(["the quick brown fox"]) inputs = keras.Input(dtype="string", shape=()) outputs = self.preprocessor(inputs) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 93ccc16c8d..0075622268 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for Albert masked language model.""" +"""Tests for ALBERT masked language model.""" import io import os diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py index cf87a818e7..5ba47af58d 100644 --- a/keras_nlp/models/albert/albert_tokenizer.py +++ b/keras_nlp/models/albert/albert_tokenizer.py @@ -72,8 +72,8 @@ def __init__(self, proto, **kwargs): cls_token = "[CLS]" sep_token = "[SEP]" pad_token = "" - mask_token = "" - for token in [cls_token, sep_token, pad_token]: + mask_token = "[MASK]" + for token in [cls_token, sep_token, pad_token, mask_token]: if token not in self.get_vocabulary(): raise ValueError( f"Cannot find token `'{token}'` in the provided " diff --git a/keras_nlp/models/albert/albert_tokenizer_test.py b/keras_nlp/models/albert/albert_tokenizer_test.py index b7099b8df6..dfe5b39644 100644 --- a/keras_nlp/models/albert/albert_tokenizer_test.py +++ b/keras_nlp/models/albert/albert_tokenizer_test.py @@ -34,7 +34,7 @@ def setUp(self): sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=12, model_type="WORD", pad_id=0, unk_id=1, @@ -44,6 +44,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]", ) self.proto = bytes_io.getvalue() @@ -52,21 +53,21 @@ def setUp(self): def test_tokenize(self): input_data = "the quick brown fox" output = self.tokenizer(input_data) - self.assertAllEqual(output, [4, 9, 5, 7]) + self.assertAllEqual(output, [5, 10, 6, 8]) def test_tokenize_batch(self): input_data = tf.constant(["the quick brown fox", "the earth is round"]) output = self.tokenizer(input_data) - self.assertAllEqual(output, [[4, 9, 5, 7], [4, 6, 8, 1]]) + self.assertAllEqual(output, [[5, 10, 6, 8], [5, 7, 9, 11]]) def test_detokenize(self): - input_data = tf.constant([[4, 9, 5, 7]]) + input_data = tf.constant([[5, 10, 6, 8]]) output = self.tokenizer.detokenize(input_data) self.assertEqual(output, tf.constant(["the quick brown fox"])) def test_vocabulary_size(self): tokenizer = AlbertTokenizer(proto=self.proto) - self.assertEqual(tokenizer.vocabulary_size(), 10) + self.assertEqual(tokenizer.vocabulary_size(), 12) @parameterized.named_parameters( ("tf_format", "tf", "model"), From 6755a208b347b62eaaa644431b336a1fca7aa922 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 15 Feb 2023 17:52:35 +0530 Subject: [PATCH 11/16] working on fixing unit tests for masking --- keras_nlp/layers/masked_lm_mask_generator.py | 6 +++++- .../albert_masked_lm_preprocessor_test.py | 20 +++++++++---------- .../models/albert/albert_masked_lm_test.py | 1 + 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/keras_nlp/layers/masked_lm_mask_generator.py b/keras_nlp/layers/masked_lm_mask_generator.py index ab03c530e9..069f35b8f6 100644 --- a/keras_nlp/layers/masked_lm_mask_generator.py +++ b/keras_nlp/layers/masked_lm_mask_generator.py @@ -147,7 +147,11 @@ def call(self, inputs): # convert dense to ragged. 
inputs = tf.RaggedTensor.from_tensor(inputs) - (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model( + ( + token_ids, + mask_positions, + mask_ids, + ) = tf_text.mask_language_model( inputs, item_selector=self._random_selector, mask_values_chooser=self._mask_values_chooser, diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py index 6efc084b0c..aa9e5cc3b5 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -69,13 +69,13 @@ def test_preprocess_strings(self): x, y, sw = self.preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0] + x["token_ids"], [2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4]) - self.assertAllEqual(y, [5, 10, 6, 8]) + self.assertAllEqual(y, [5, 1, 6, 8]) self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) def test_preprocess_list_of_strings(self): @@ -83,13 +83,13 @@ def test_preprocess_list_of_strings(self): x, y, sw = self.preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [[1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0]] * 4 + x["token_ids"], [[2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual( x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) - self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) + self.assertAllEqual(y, [[5, 1, 6, 8]] * 4) self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) def test_preprocess_dataset(self): @@ -98,13 +98,13 @@ def test_preprocess_dataset(self): ds = ds.map(self.preprocessor) x, y, sw = ds.batch(4).take(1).get_single_element() self.assertAllEqual( - x["token_ids"], [[1, 4, 4, 4, 4, 2, 0, 0, 0, 0, 0, 0]] * 4 + x["token_ids"], [[2, 4, 4, 4, 4, 3, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual( x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) - self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) + self.assertAllEqual(y, [[5, 1, 6, 8]] * 4) self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) def test_mask_multiple_sentences(self): @@ -113,13 +113,13 @@ def test_mask_multiple_sentences(self): x, y, sw = self.preprocessor((sentence_one, sentence_two)) self.assertAllEqual( - x["token_ids"], [1, 4, 4, 2, 4, 4, 2, 0, 0, 0, 0, 0] + x["token_ids"], [2, 4, 4, 3, 4, 4, 3, 0, 0, 0, 0, 0] ) self.assertAllEqual( x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] ) self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5]) - self.assertAllEqual(y, [5, 10, 6, 8]) + self.assertAllEqual(y, [5, 1, 6, 8]) self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) def test_no_masking_zero_rate(self): @@ -133,10 +133,10 @@ def test_no_masking_zero_rate(self): x, y, sw = no_mask_preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [1, 5, 10, 6, 8, 2, 0, 0, 0, 0, 0, 0] + x["token_ids"], [2, 5, 1, 6, 8, 3, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( - x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0]) self.assertAllEqual(y, [0, 0, 0, 0]) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 0075622268..c48b10098c 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ 
b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -58,6 +58,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]", ) proto = bytes_io.getvalue() From 59f65b55cdca4c4045e9876ebeed0fb8006f2f7c Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 15 Feb 2023 18:16:55 +0530 Subject: [PATCH 12/16] working on fixing unit tests for masking --- keras_nlp/models/albert/albert_masked_lm_preprocessor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index 2e705b813d..6aa9956f14 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -113,6 +113,7 @@ class AlbertMaskedLMPreprocessor(AlbertPreprocessor): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]" ) proto = bytes_io.getvalue() From d11971c83a5c9bc48dc226cc986c9e5874010595 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 15 Feb 2023 18:41:18 +0530 Subject: [PATCH 13/16] adding mask to preprocessor + fixing tests --- .../albert/albert_masked_lm_preprocessor.py | 2 +- .../albert/albert_masked_lm_preprocessor_test.py | 12 ++++++------ keras_nlp/models/albert/albert_masked_lm_test.py | 13 ++++++------- .../models/albert/albert_preprocessor_test.py | 15 ++++++++------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index 6aa9956f14..1c874a0b7b 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -103,7 +103,7 @@ class AlbertMaskedLMPreprocessor(AlbertPreprocessor): sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=12, model_type="WORD", pad_id=0, unk_id=1, diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py index aa9e5cc3b5..c77f24c213 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor_test.py @@ -37,7 +37,7 @@ def setUp(self): sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=12, model_type="WORD", pad_id=0, unk_id=1, @@ -75,7 +75,7 @@ def test_preprocess_strings(self): x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4]) - self.assertAllEqual(y, [5, 1, 6, 8]) + self.assertAllEqual(y, [5, 10, 6, 8]) self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) def test_preprocess_list_of_strings(self): @@ -89,7 +89,7 @@ def test_preprocess_list_of_strings(self): x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) - self.assertAllEqual(y, [[5, 1, 6, 8]] * 4) + self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) def test_preprocess_dataset(self): @@ -104,7 +104,7 @@ def test_preprocess_dataset(self): x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]] * 4 ) self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4]] * 4) - self.assertAllEqual(y, [[5, 1, 6, 8]] * 4) + self.assertAllEqual(y, [[5, 10, 6, 8]] * 4) 
self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0]] * 4) def test_mask_multiple_sentences(self): @@ -119,7 +119,7 @@ def test_mask_multiple_sentences(self): x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] ) self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5]) - self.assertAllEqual(y, [5, 1, 6, 8]) + self.assertAllEqual(y, [5, 10, 6, 8]) self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0]) def test_no_masking_zero_rate(self): @@ -133,7 +133,7 @@ def test_no_masking_zero_rate(self): x, y, sw = no_mask_preprocessor(input_data) self.assertAllEqual( - x["token_ids"], [2, 5, 1, 6, 8, 3, 0, 0, 0, 0, 0, 0] + x["token_ids"], [2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index c48b10098c..09630d77fc 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -41,14 +41,14 @@ def setUp(self): max_sequence_length=128, ) vocab_data = tf.data.Dataset.from_tensor_slices( - ["the quick brown fox", "the earth is round"] + ["the quick brown fox", "the earth is round", "an eagle flew"] ) bytes_io = io.BytesIO() sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=15, model_type="WORD", pad_id=0, unk_id=1, @@ -85,11 +85,10 @@ def setUp(self): self.raw_batch = tf.constant( [ - " airplane at airport", - " the airplane is the best", - " the best airport", - " kohli is the best", - ] + "quick brown fox", + "eagle flew over fox", + "the eagle flew quick", + "a brown eagle", ] ) self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] self.raw_dataset = tf.data.Dataset.from_tensor_slices( diff --git a/keras_nlp/models/albert/albert_preprocessor_test.py b/keras_nlp/models/albert/albert_preprocessor_test.py index ee6038f839..fae2445be7 100644 --- a/keras_nlp/models/albert/albert_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_preprocessor_test.py @@ -35,7 +35,7 @@ def setUp(self): sentencepiece.SentencePieceTrainer.train( sentence_iterator=vocab_data.as_numpy_iterator(), model_writer=bytes_io, - vocab_size=10, + vocab_size=12, model_type="WORD", pad_id=0, unk_id=1, @@ -45,6 +45,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]" ) self.proto = bytes_io.getvalue() @@ -57,7 +58,7 @@ def test_tokenize_strings(self): input_data = "the quick brown fox" output = self.preprocessor(input_data) self.assertAllEqual( - output["token_ids"], [2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0] + output["token_ids"], [2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0] ) self.assertAllEqual( output["segment_ids"], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] @@ -72,7 +73,7 @@ def test_tokenize_list_of_strings(self): output = self.preprocessor(input_data) self.assertAllEqual( output["token_ids"], - [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0]] * 4, ) self.assertAllEqual( output["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 @@ -88,7 +89,7 @@ def test_tokenize_labeled_batch(self): x_out, y_out, sw_out = self.preprocessor(x, y, sw) self.assertAllEqual( x_out["token_ids"], - [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0]] * 4, ) self.assertAllEqual( x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 @@ -108,7 +109,7 @@ def test_tokenize_labeled_dataset(self): x_out, y_out, 
sw_out = ds.batch(4).take(1).get_single_element() self.assertAllEqual( x_out["token_ids"], - [[2, 4, 9, 5, 7, 3, 0, 0, 0, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 0, 0, 0, 0, 0, 0]] * 4, ) self.assertAllEqual( x_out["segment_ids"], [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] * 4 @@ -125,7 +126,7 @@ def test_tokenize_multiple_sentences(self): output = self.preprocessor((sentence_one, sentence_two)) self.assertAllEqual( output["token_ids"], - [2, 4, 9, 5, 7, 3, 4, 6, 3, 0, 0, 0], + [2, 5, 10, 6, 8, 3, 5, 7, 3, 0, 0, 0], ) self.assertAllEqual( output["segment_ids"], [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0] @@ -142,7 +143,7 @@ def test_tokenize_multiple_batched_sentences(self): output = self.preprocessor((sentence_one, sentence_two)) self.assertAllEqual( output["token_ids"], - [[2, 4, 9, 5, 7, 3, 4, 6, 3, 0, 0, 0]] * 4, + [[2, 5, 10, 6, 8, 3, 5, 7, 3, 0, 0, 0]] * 4, ) self.assertAllEqual( output["segment_ids"], [[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0]] * 4 From 4498ab865db5693c406d55676f6f6b8ac098b9fa Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 15 Feb 2023 18:41:48 +0530 Subject: [PATCH 14/16] code format --- keras_nlp/models/albert/albert_masked_lm_test.py | 3 ++- keras_nlp/models/albert/albert_preprocessor_test.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 09630d77fc..557b1bd411 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -88,7 +88,8 @@ def setUp(self): "quick brown fox", "eagle flew over fox", "the eagle flew quick", - "a brown eagle", ] + "a brown eagle", + ] ) self.preprocessed_batch = self.preprocessor(self.raw_batch)[0] self.raw_dataset = tf.data.Dataset.from_tensor_slices( diff --git a/keras_nlp/models/albert/albert_preprocessor_test.py b/keras_nlp/models/albert/albert_preprocessor_test.py index fae2445be7..53639517ea 100644 --- a/keras_nlp/models/albert/albert_preprocessor_test.py +++ b/keras_nlp/models/albert/albert_preprocessor_test.py @@ -45,7 +45,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", - user_defined_symbols="[MASK]" + user_defined_symbols="[MASK]", ) self.proto = bytes_io.getvalue() From 9036ceca772c5338f82d4964be1536ad3456ee24 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 15 Feb 2023 19:42:13 +0530 Subject: [PATCH 15/16] fixing classifier test failures --- keras_nlp/models/albert/albert_classifier_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_nlp/models/albert/albert_classifier_test.py b/keras_nlp/models/albert/albert_classifier_test.py index 2f447194f0..508c11a231 100644 --- a/keras_nlp/models/albert/albert_classifier_test.py +++ b/keras_nlp/models/albert/albert_classifier_test.py @@ -57,6 +57,7 @@ def setUp(self): unk_piece="", bos_piece="[CLS]", eos_piece="[SEP]", + user_defined_symbols="[MASK]" ) self.proto = bytes_io.getvalue() From a82350d2860d079d0c4aacf5ed62679790397b75 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 15 Feb 2023 19:42:52 +0530 Subject: [PATCH 16/16] fixing formatting --- keras_nlp/models/albert/albert_classifier_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/albert/albert_classifier_test.py b/keras_nlp/models/albert/albert_classifier_test.py index 508c11a231..40fec53486 100644 --- a/keras_nlp/models/albert/albert_classifier_test.py +++ b/keras_nlp/models/albert/albert_classifier_test.py @@ -57,7 +57,7 @@ def setUp(self): unk_piece="", 
bos_piece="[CLS]", eos_piece="[SEP]", - user_defined_symbols="[MASK]" + user_defined_symbols="[MASK]", ) self.proto = bytes_io.getvalue()