diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index db70d43bce..893d93000b 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -16,6 +16,7 @@ from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer from keras_nlp.models.bart.bart_backbone import BartBackbone +from keras_nlp.models.bart.bart_tokenizer import BartTokenizer from keras_nlp.models.bert.bert_backbone import BertBackbone from keras_nlp.models.bert.bert_classifier import BertClassifier from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py index 32c5287310..a66a6210cd 100644 --- a/keras_nlp/models/bart/bart_backbone.py +++ b/keras_nlp/models/bart/bart_backbone.py @@ -14,6 +14,8 @@ """BART backbone model.""" +import copy + import tensorflow as tf from tensorflow import keras @@ -21,6 +23,8 @@ from keras_nlp.layers.transformer_decoder import TransformerDecoder from keras_nlp.layers.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone +from keras_nlp.models.bart.bart_presets import backbone_presets +from keras_nlp.utils.python_utils import classproperty def bart_kernel_initializer(stddev=0.02): @@ -247,3 +251,7 @@ def get_config(self): @property def token_embedding(self): return self.get_layer("token_embedding") + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_presets.py b/keras_nlp/models/bart/bart_presets.py new file mode 100644 index 0000000000..053eaddfd7 --- /dev/null +++ b/keras_nlp/models/bart/bart_presets.py @@ -0,0 +1,63 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BART model preset configurations.""" + +backbone_presets = { + "bart_base_en": { + "config": { + "vocabulary_size": 50265, + "num_layers": 6, + "num_heads": 12, + "hidden_dim": 768, + "intermediate_dim": 3072, + "dropout": 0.1, + "max_sequence_length": 1024, + }, + "preprocessor_config": {}, + "description": ( + "Base size of BART where case is maintained. " + "Trained on a 160GB English dataset comprising BookCorpus, " + "English Wikipedia and CommonCrawl." 
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5",
+        "weights_hash": "5b59403f0cafafbd89680e0785791163",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json",
+        "vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354",
+        "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt",
+        "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
+    },
+    "bart_large_en": {
+        "config": {
+            "vocabulary_size": 50265,
+            "num_layers": 12,
+            "num_heads": 16,
+            "hidden_dim": 1024,
+            "intermediate_dim": 4096,
+            "dropout": 0.1,
+            "max_sequence_length": 1024,
+        },
+        "preprocessor_config": {},
+        "description": (
+            "Large size of BART where case is maintained. "
+            "Trained on a 160GB English dataset comprising BookCorpus, "
+            "English Wikipedia and CommonCrawl."
+        ),
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/model.h5",
+        "weights_hash": "6bfe7e591af8c5699ce6f9f18753af9a",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/vocab.json",
+        "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
+        "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt",
+        "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
+    },
+}
diff --git a/keras_nlp/models/bart/bart_presets_test.py b/keras_nlp/models/bart/bart_presets_test.py
new file mode 100644
index 0000000000..2015877ddf
--- /dev/null
+++ b/keras_nlp/models/bart/bart_presets_test.py
@@ -0,0 +1,126 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for loading pretrained model presets."""
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+
+from keras_nlp.models.bart.bart_backbone import BartBackbone
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+
+
+@pytest.mark.large
+class BartPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
+    """
+    A smoke test for BART presets we run continuously.
+
+    This only tests the smallest weights we have available.
Run with: + `pytest keras_nlp/models/bart/bart_presets_test.py --run_large` + """ + + def test_tokenizer_output(self): + tokenizer = BartTokenizer.from_preset( + "bart_base_en", + ) + outputs = tokenizer("The quick brown fox.") + expected_outputs = [133, 2119, 6219, 23602, 4] + self.assertAllEqual(outputs, expected_outputs) + + @parameterized.named_parameters( + ("preset_weights", True), ("random_weights", False) + ) + def test_backbone_output(self, load_weights): + input_data = { + "encoder_token_ids": tf.constant([[0, 133, 2119, 2]]), + "encoder_padding_mask": tf.constant([[1, 1, 1, 1]]), + "decoder_token_ids": tf.constant([[0, 7199, 14, 2119, 2]]), + "decoder_padding_mask": tf.constant([[1, 1, 1, 1, 1]]), + } + model = BartBackbone.from_preset( + "bart_base_en", load_weights=load_weights + ) + outputs = model(input_data) + if load_weights: + encoder_output = outputs["encoder_sequence_output"][0, 0, :5] + expected_encoder_output = [-0.033, 0.013, -0.003, -0.012, -0.002] + decoder_output = outputs["decoder_sequence_output"][0, 0, :5] + expected_decoder_output = [2.516, 2.489, 0.695, 8.057, 1.245] + + self.assertAllClose( + encoder_output, expected_encoder_output, atol=0.01, rtol=0.01 + ) + self.assertAllClose( + decoder_output, expected_decoder_output, atol=0.01, rtol=0.01 + ) + + @parameterized.named_parameters( + ("bart_tokenizer", BartTokenizer), + ("bart", BartBackbone), + ) + def test_preset_docstring(self, cls): + """Check we did our docstring formatting correctly.""" + for name in cls.presets: + self.assertRegex(cls.from_preset.__doc__, name) + + @parameterized.named_parameters( + ("bart_tokenizer", BartTokenizer), + ("bart", BartBackbone), + ) + def test_unknown_preset_error(self, cls): + # Not a preset name + with self.assertRaises(ValueError): + cls.from_preset("bart_base_en_clowntown") + + +@pytest.mark.extra_large +class BartPresetFullTest(tf.test.TestCase, parameterized.TestCase): + """ + Test the full enumeration of our preset. + + This tests every BART preset and is only run manually. + Run with: + `pytest keras_nlp/models/bart/bart_presets_test.py --run_extra_large` + """ + + @parameterized.named_parameters( + ("preset_weights", True), ("random_weights", False) + ) + def test_load_bart(self, load_weights): + for preset in BartBackbone.presets: + model = BartBackbone.from_preset(preset, load_weights=load_weights) + input_data = { + "encoder_token_ids": tf.random.uniform( + shape=(1, 1024), + dtype=tf.int64, + maxval=model.vocabulary_size, + ), + "encoder_padding_mask": tf.constant( + [1] * 768 + [0] * 256, shape=(1, 1024) + ), + "decoder_token_ids": tf.random.uniform( + shape=(1, 1024), + dtype=tf.int64, + maxval=model.vocabulary_size, + ), + "decoder_padding_mask": tf.constant( + [1] * 489 + [0] * 535, shape=(1, 1024) + ), + } + model(input_data) + + def test_load_tokenizers(self): + for preset in BartTokenizer.presets: + tokenizer = BartTokenizer.from_preset(preset) + tokenizer("The quick brown fox.") diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py new file mode 100644 index 0000000000..c5e2a25e1f --- /dev/null +++ b/keras_nlp/models/bart/bart_tokenizer.py @@ -0,0 +1,120 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BART tokenizer."""
+
+import copy
+
+from tensorflow import keras
+
+from keras_nlp.models.bart.bart_presets import backbone_presets
+from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class BartTokenizer(BytePairTokenizer):
+    """A BART tokenizer using Byte-Pair Encoding subword segmentation.
+
+    This tokenizer class will tokenize raw strings into integer sequences and
+    is based on `keras_nlp.tokenizers.BytePairTokenizer`. Unlike the
+    underlying tokenizer, it will check for all special tokens needed by BART
+    models and provides a `from_preset()` method to automatically download
+    a matching vocabulary for a BART preset.
+
+    This tokenizer does not provide truncation or padding of inputs. It can be
+    combined with a `keras_nlp.models.BartPreprocessor` layer for input
+    packing.
+
+    If input is a batch of strings (rank > 0), the layer will output a
+    `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+    If input is a scalar string (rank == 0), the layer will output a dense
+    `tf.Tensor` with static shape `[None]`.
+
+    Args:
+        vocabulary: string or dict, maps token to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rule. If it is a string,
+            it should be the file path to merge rules. The merge rule file
+            should have one merge rule per line. Every merge rule contains
+            merge entities separated by a space.
+
+    Examples:
+
+    Batched inputs.
+    >>> vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
+    >>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
+    >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+    >>> merges += ["Ġ f", "o x", "Ġf ox"]
+    >>> tokenizer = keras_nlp.models.BartTokenizer(
+    ...     vocabulary=vocab, merges=merges
+    ... )
+    >>> tokenizer(["a quick fox", "a fox quick"])
+    <tf.RaggedTensor [[4, 5, 6], [4, 6, 5]]>
+
+    Unbatched input.
+    >>> vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
+    >>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
+    >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+    >>> merges += ["Ġ f", "o x", "Ġf ox"]
+    >>> tokenizer = keras_nlp.models.BartTokenizer(
+    ...     vocabulary=vocab, merges=merges
+    ... )
+    >>> tokenizer("a quick fox")
+    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>
+
+    Detokenization.
+    >>> vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
+    >>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
+    >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+    >>> merges += ["Ġ f", "o x", "Ġf ox"]
+    >>> tokenizer = keras_nlp.models.BartTokenizer(
+    ...     vocabulary=vocab, merges=merges
+    ... )
+    >>> tokenizer.detokenize(tokenizer("a quick fox")).numpy().decode('utf-8')
+    'a quick fox'
+    """
+
+    def __init__(
+        self,
+        vocabulary,
+        merges,
+        **kwargs,
+    ):
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            **kwargs,
+        )
+
+        # Check for necessary special tokens.
+        start_token = "<s>"
+        pad_token = "<pad>"
+        end_token = "</s>"
+        for token in [start_token, pad_token, end_token]:
+            if token not in self.get_vocabulary():
+                raise ValueError(
+                    f"Cannot find token `'{token}'` in the provided "
+                    f"`vocabulary`. Please provide `'{token}'` in your "
+                    "`vocabulary` or use a pretrained `vocabulary` name."
+                )
+
+        self.start_token_id = self.token_to_id(start_token)
+        self.pad_token_id = self.token_to_id(pad_token)
+        self.end_token_id = self.token_to_id(end_token)
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py
new file mode 100644
index 0000000000..cd8fd62437
--- /dev/null
+++ b/keras_nlp/models/bart/bart_tokenizer_test.py
@@ -0,0 +1,86 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for BART tokenizer."""
+
+import os
+
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+
+
+class BartTokenizerTest(tf.test.TestCase, parameterized.TestCase):
+    def setUp(self):
+        vocab = {
+            "<s>": 0,
+            "<pad>": 1,
+            "</s>": 2,
+            "Ġair": 3,
+            "plane": 4,
+            "Ġat": 5,
+            "port": 6,
+            "Ġkoh": 7,
+            "li": 8,
+            "Ġis": 9,
+            "Ġthe": 10,
+            "Ġbest": 11,
+        }
+
+        merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"]
+        merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
+        merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
+        merges += ["pla ne"]
+
+        self.tokenizer = BartTokenizer(vocabulary=vocab, merges=merges)
+
+    def test_tokenize(self):
+        input_data = " airplane at airport"
+        output = self.tokenizer(input_data)
+        self.assertAllEqual(output, [3, 4, 5, 3, 6])
+
+    def test_tokenize_batch(self):
+        input_data = tf.constant([" airplane at airport", " kohli is the best"])
+        output = self.tokenizer(input_data)
+        self.assertAllEqual(output, [[3, 4, 5, 3, 6], [7, 8, 9, 10, 11]])
+
+    def test_detokenize(self):
+        input_tokens = [[3, 4, 5, 3, 6]]
+        output = self.tokenizer.detokenize(input_tokens)
+        self.assertAllEqual(output, [" airplane at airport"])
+
+    def test_vocabulary_size(self):
+        self.assertEqual(self.tokenizer.vocabulary_size(), 12)
+
+    @parameterized.named_parameters(
+        ("tf_format", "tf", "model"),
+        ("keras_format", "keras_v3", "model.keras"),
+    )
+    def test_saved_model(self, save_format, filename):
+        input_data = tf.constant([" airplane at airport"])
+
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = self.tokenizer(inputs)
+        model = keras.Model(inputs, outputs)
+
+        path = os.path.join(self.get_temp_dir(), filename)
+        model.save(path, save_format=save_format)
+
+        restored_model = keras.models.load_model(path)
+        self.assertAllEqual(
+            model(input_data),
+            restored_model(input_data),
+        )
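
For reviewers trying the change locally, the snippet below is a minimal usage sketch of the preset API this PR wires up, mirroring the smoke tests above. It assumes a `keras_nlp` build that includes this change and network access to download the preset files; the short input sequences and printed values are illustrative, taken from the smoke test.

```python
import tensorflow as tf

import keras_nlp

# Download the matching vocabulary, merges, and weights for the base preset.
tokenizer = keras_nlp.models.BartTokenizer.from_preset("bart_base_en")
backbone = keras_nlp.models.BartBackbone.from_preset("bart_base_en")

# The tokenizer maps raw strings to token ids; it does not add special
# tokens or padding on its own.
print(tokenizer("The quick brown fox."))  # [133, 2119, 6219, 23602, 4]

# The backbone takes explicit encoder/decoder token ids and padding masks,
# in the same dict format used by the preset smoke test.
outputs = backbone(
    {
        "encoder_token_ids": tf.constant([[0, 133, 2119, 2]]),
        "encoder_padding_mask": tf.constant([[1, 1, 1, 1]]),
        "decoder_token_ids": tf.constant([[0, 7199, 14, 2119, 2]]),
        "decoder_padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
    }
)
print(outputs["encoder_sequence_output"].shape)  # (1, 4, 768)
print(outputs["decoder_sequence_output"].shape)  # (1, 5, 768)
```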