diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index db70d43bce..893d93000b 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -16,6 +16,7 @@
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.models.bart.bart_backbone import BartBackbone
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py
index 32c5287310..a66a6210cd 100644
--- a/keras_nlp/models/bart/bart_backbone.py
+++ b/keras_nlp/models/bart/bart_backbone.py
@@ -14,6 +14,8 @@
"""BART backbone model."""
+import copy
+
import tensorflow as tf
from tensorflow import keras
@@ -21,6 +23,8 @@
from keras_nlp.layers.transformer_decoder import TransformerDecoder
from keras_nlp.layers.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.bart.bart_presets import backbone_presets
+from keras_nlp.utils.python_utils import classproperty
def bart_kernel_initializer(stddev=0.02):
@@ -247,3 +251,7 @@ def get_config(self):
@property
def token_embedding(self):
return self.get_layer("token_embedding")
+
+ @classproperty
+ def presets(cls):
+ return copy.deepcopy(backbone_presets)
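For context, a minimal sketch of what the new `presets` classproperty exposes, assuming a keras_nlp build that includes this change and the `bart_presets.py` registry added below:

```python
from keras_nlp.models import BartBackbone, BartTokenizer

# Both classes deep-copy the same registry from bart_presets.py, so callers
# can mutate the returned dict without touching the module-level config.
print(sorted(BartBackbone.presets))   # ['bart_base_en', 'bart_large_en']
print(sorted(BartTokenizer.presets))  # ['bart_base_en', 'bart_large_en']
print(BartBackbone.presets["bart_base_en"]["config"]["hidden_dim"])  # 768
```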
diff --git a/keras_nlp/models/bart/bart_presets.py b/keras_nlp/models/bart/bart_presets.py
new file mode 100644
index 0000000000..053eaddfd7
--- /dev/null
+++ b/keras_nlp/models/bart/bart_presets.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BART model preset configurations."""
+
+backbone_presets = {
+ "bart_base_en": {
+ "config": {
+ "vocabulary_size": 50265,
+ "num_layers": 6,
+ "num_heads": 12,
+ "hidden_dim": 768,
+ "intermediate_dim": 3072,
+ "dropout": 0.1,
+ "max_sequence_length": 1024,
+ },
+ "preprocessor_config": {},
+ "description": (
+ "Base size of BART where case is maintained. "
+ "Trained on a 160GB English dataset comprising BookCorpus, "
+ "English Wikipedia and CommonCrawl."
+ ),
+ "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5",
+ "weights_hash": "5b59403f0cafafbd89680e0785791163",
+ "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json",
+ "vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354",
+ "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt",
+ "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
+ },
+ "bart_large_en": {
+ "config": {
+ "vocabulary_size": 50265,
+ "num_layers": 12,
+ "num_heads": 16,
+ "hidden_dim": 1024,
+ "intermediate_dim": 4096,
+ "dropout": 0.1,
+ "max_sequence_length": 1024,
+ },
+ "preprocessor_config": {},
+ "description": (
+ "Large size of BART where case is maintained. "
+ "Trained on a 160GB English dataset comprising BookCorpus, "
+ "English Wikipedia and CommonCrawl."
+ ),
+ "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/model.h5",
+ "weights_hash": "6bfe7e591af8c5699ce6f9f18753af9a",
+ "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/vocab.json",
+ "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
+ "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt",
+ "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
+ },
+}
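A hedged usage sketch of how `from_preset` consumes this registry; the input names and output keys mirror `bart_presets_test.py` below, and `load_weights=True` downloads the `model.h5` listed above:

```python
import tensorflow as tf
from keras_nlp.models import BartBackbone

# load_weights=False would build the architecture from the preset config only,
# skipping the weights download.
backbone = BartBackbone.from_preset("bart_base_en", load_weights=True)

# BART is an encoder-decoder, so the backbone takes token ids and padding
# masks for both sides (same input names as in the preset smoke test).
outputs = backbone(
    {
        "encoder_token_ids": tf.constant([[0, 133, 2119, 2]]),
        "encoder_padding_mask": tf.constant([[1, 1, 1, 1]]),
        "decoder_token_ids": tf.constant([[0, 7199, 14, 2119, 2]]),
        "decoder_padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
    }
)
print(outputs["encoder_sequence_output"].shape)  # (1, 4, 768)
print(outputs["decoder_sequence_output"].shape)  # (1, 5, 768)
```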
diff --git a/keras_nlp/models/bart/bart_presets_test.py b/keras_nlp/models/bart/bart_presets_test.py
new file mode 100644
index 0000000000..2015877ddf
--- /dev/null
+++ b/keras_nlp/models/bart/bart_presets_test.py
@@ -0,0 +1,126 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for loading pretrained model presets."""
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+
+from keras_nlp.models.bart.bart_backbone import BartBackbone
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+
+
+@pytest.mark.large
+class BartPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
+ """
+ A smoke test for BART presets we run continuously.
+
+ This only tests the smallest weights we have available. Run with:
+ `pytest keras_nlp/models/bart/bart_presets_test.py --run_large`
+ """
+
+ def test_tokenizer_output(self):
+ tokenizer = BartTokenizer.from_preset(
+ "bart_base_en",
+ )
+ outputs = tokenizer("The quick brown fox.")
+ expected_outputs = [133, 2119, 6219, 23602, 4]
+ self.assertAllEqual(outputs, expected_outputs)
+
+ @parameterized.named_parameters(
+ ("preset_weights", True), ("random_weights", False)
+ )
+ def test_backbone_output(self, load_weights):
+ input_data = {
+ "encoder_token_ids": tf.constant([[0, 133, 2119, 2]]),
+ "encoder_padding_mask": tf.constant([[1, 1, 1, 1]]),
+ "decoder_token_ids": tf.constant([[0, 7199, 14, 2119, 2]]),
+ "decoder_padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
+ }
+ model = BartBackbone.from_preset(
+ "bart_base_en", load_weights=load_weights
+ )
+ outputs = model(input_data)
+ if load_weights:
+ encoder_output = outputs["encoder_sequence_output"][0, 0, :5]
+ expected_encoder_output = [-0.033, 0.013, -0.003, -0.012, -0.002]
+ decoder_output = outputs["decoder_sequence_output"][0, 0, :5]
+ expected_decoder_output = [2.516, 2.489, 0.695, 8.057, 1.245]
+
+ self.assertAllClose(
+ encoder_output, expected_encoder_output, atol=0.01, rtol=0.01
+ )
+ self.assertAllClose(
+ decoder_output, expected_decoder_output, atol=0.01, rtol=0.01
+ )
+
+ @parameterized.named_parameters(
+ ("bart_tokenizer", BartTokenizer),
+ ("bart", BartBackbone),
+ )
+ def test_preset_docstring(self, cls):
+ """Check we did our docstring formatting correctly."""
+ for name in cls.presets:
+ self.assertRegex(cls.from_preset.__doc__, name)
+
+ @parameterized.named_parameters(
+ ("bart_tokenizer", BartTokenizer),
+ ("bart", BartBackbone),
+ )
+ def test_unknown_preset_error(self, cls):
+ # Not a preset name
+ with self.assertRaises(ValueError):
+ cls.from_preset("bart_base_en_clowntown")
+
+
+@pytest.mark.extra_large
+class BartPresetFullTest(tf.test.TestCase, parameterized.TestCase):
+ """
+ Test the full enumeration of our presets.
+
+ This tests every BART preset and is only run manually.
+ Run with:
+ `pytest keras_nlp/models/bart/bart_presets_test.py --run_extra_large`
+ """
+
+ @parameterized.named_parameters(
+ ("preset_weights", True), ("random_weights", False)
+ )
+ def test_load_bart(self, load_weights):
+ for preset in BartBackbone.presets:
+ model = BartBackbone.from_preset(preset, load_weights=load_weights)
+ input_data = {
+ "encoder_token_ids": tf.random.uniform(
+ shape=(1, 1024),
+ dtype=tf.int64,
+ maxval=model.vocabulary_size,
+ ),
+ "encoder_padding_mask": tf.constant(
+ [1] * 768 + [0] * 256, shape=(1, 1024)
+ ),
+ "decoder_token_ids": tf.random.uniform(
+ shape=(1, 1024),
+ dtype=tf.int64,
+ maxval=model.vocabulary_size,
+ ),
+ "decoder_padding_mask": tf.constant(
+ [1] * 489 + [0] * 535, shape=(1, 1024)
+ ),
+ }
+ model(input_data)
+
+ def test_load_tokenizers(self):
+ for preset in BartTokenizer.presets:
+ tokenizer = BartTokenizer.from_preset(preset)
+ tokenizer("The quick brown fox.")
diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py
new file mode 100644
index 0000000000..c5e2a25e1f
--- /dev/null
+++ b/keras_nlp/models/bart/bart_tokenizer.py
@@ -0,0 +1,120 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BART tokenizer."""
+
+import copy
+
+from tensorflow import keras
+
+from keras_nlp.models.bart.bart_presets import backbone_presets
+from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class BartTokenizer(BytePairTokenizer):
+ """A BART tokenizer using Byte-Pair Encoding subword segmentation.
+
+ This tokenizer class will tokenize raw strings into integer sequences and
+ is based on `keras_nlp.tokenizers.BytePairTokenizer`. Unlike the
+ underlying tokenizer, it will check for all special tokens needed by BART
+ models and provides a `from_preset()` method to automatically download
+ a matching vocabulary for a BART preset.
+
+ This tokenizer does not provide truncation or padding of inputs. It can be
+ combined with a `keras_nlp.models.BartPreprocessor` layer for input
+ packing.
+
+ If input is a batch of strings (rank > 0), the layer will output a
+ `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+ If input is a scalar string (rank == 0), the layer will output a dense
+ `tf.Tensor` with static shape `[None]`.
+
+ Args:
+ vocabulary: string or dict, maps token to integer ids. If it is a
+ string, it should be the file path to a json file.
+ merges: string or list, contains the merge rule. If it is a string,
+ it should be the file path to merge rules. The merge rule file
+ should have one merge rule per line. Every merge rule contains
+ merge entities separated by a space.
+
+ Examples:
+
+ Batched inputs.
+ >>> vocab = {"": 0, "": 1, "": 2, "": 3}
+ >>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
+ >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+ >>> merges += ["Ġ f", "o x", "Ġf ox"]
+ >>> tokenizer = keras_nlp.models.RobertaTokenizer(
+ ... vocabulary=vocab, merges=merges
+ ... )
+ >>> tokenizer(["a quick fox", "a fox quick"])
+
+
+ Unbatched input.
+ >>> vocab = {"": 0, "": 1, "": 2, "": 3}
+ >>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
+ >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+ >>> merges += ["Ġ f", "o x", "Ġf ox"]
+ >>> tokenizer = keras_nlp.models.RobertaTokenizer(
+ ... vocabulary=vocab, merges=merges
+ ... )
+ >>> tokenizer("a quick fox")
+
+
+ Detokenization.
+ >>> vocab = {"": 0, "": 1, "": 2, "": 3}
+ >>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
+ >>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+ >>> merges += ["Ġ f", "o x", "Ġf ox"]
+ >>> tokenizer = keras_nlp.models.RobertaTokenizer(
+ ... vocabulary=vocab, merges=merges
+ ... )
+ >>> tokenizer.detokenize(tokenizer("a quick fox")).numpy().decode('utf-8')
+ 'a quick fox'
+ """
+
+ def __init__(
+ self,
+ vocabulary,
+ merges,
+ **kwargs,
+ ):
+ super().__init__(
+ vocabulary=vocabulary,
+ merges=merges,
+ **kwargs,
+ )
+
+ # Check for necessary special tokens.
+ start_token = "<s>"
+ pad_token = "<pad>"
+ end_token = "</s>"
+ for token in [start_token, pad_token, end_token]:
+ if token not in self.get_vocabulary():
+ raise ValueError(
+ f"Cannot find token `'{token}'` in the provided "
+ f"`vocabulary`. Please provide `'{token}'` in your "
+ "`vocabulary` or use a pretrained `vocabulary` name."
+ )
+
+ self.start_token_id = self.token_to_id(start_token)
+ self.pad_token_id = self.token_to_id(pad_token)
+ self.end_token_id = self.token_to_id(end_token)
+
+ @classproperty
+ def presets(cls):
+ return copy.deepcopy(backbone_presets)
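A short end-to-end sketch of the tokenizer above loaded from the `bart_base_en` preset; the expected token ids mirror `test_tokenizer_output` in the preset smoke test:

```python
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer

# Fetches the vocab.json and merges.txt registered for the preset.
tokenizer = BartTokenizer.from_preset("bart_base_en")

token_ids = tokenizer("The quick brown fox.")
print(token_ids.numpy())  # [133, 2119, 6219, 23602, 4]

# Special token ids resolved in __init__ from the "<s>", "<pad>" and "</s>"
# vocabulary entries.
print(tokenizer.start_token_id, tokenizer.pad_token_id, tokenizer.end_token_id)

# Round-trip back to text, as in the class docstring.
print(tokenizer.detokenize(token_ids).numpy().decode("utf-8"))
```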
diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py
new file mode 100644
index 0000000000..cd8fd62437
--- /dev/null
+++ b/keras_nlp/models/bart/bart_tokenizer_test.py
@@ -0,0 +1,86 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for BART tokenizer."""
+
+import os
+
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+
+
+class BartTokenizerTest(tf.test.TestCase, parameterized.TestCase):
+ def setUp(self):
+ vocab = {
+ "": 0,
+ "": 1,
+ "": 2,
+ "Ġair": 3,
+ "plane": 4,
+ "Ġat": 5,
+ "port": 6,
+ "Ġkoh": 7,
+ "li": 8,
+ "Ġis": 9,
+ "Ġthe": 10,
+ "Ġbest": 11,
+ }
+
+ merges = ["Ġ a", "Ġ t", "Ġ k", "Ġ i", "Ġ b", "Ġa i", "p l", "n e"]
+ merges += ["Ġa t", "p o", "r t", "o h", "l i", "Ġi s", "Ġb e", "s t"]
+ merges += ["Ġt h", "Ġai r", "pl a", "Ġk oh", "Ġth e", "Ġbe st", "po rt"]
+ merges += ["pla ne"]
+
+ self.tokenizer = BartTokenizer(vocabulary=vocab, merges=merges)
+
+ def test_tokenize(self):
+ input_data = " airplane at airport"
+ output = self.tokenizer(input_data)
+ self.assertAllEqual(output, [3, 4, 5, 3, 6])
+
+ def test_tokenize_batch(self):
+ input_data = tf.constant([" airplane at airport", " kohli is the best"])
+ output = self.tokenizer(input_data)
+ self.assertAllEqual(output, [[3, 4, 5, 3, 6], [7, 8, 9, 10, 11]])
+
+ def test_detokenize(self):
+ input_tokens = [[3, 4, 5, 3, 6]]
+ output = self.tokenizer.detokenize(input_tokens)
+ self.assertAllEqual(output, [" airplane at airport"])
+
+ def test_vocabulary_size(self):
+ self.assertEqual(self.tokenizer.vocabulary_size(), 12)
+
+ @parameterized.named_parameters(
+ ("tf_format", "tf", "model"),
+ ("keras_format", "keras_v3", "model.keras"),
+ )
+ def test_saved_model(self, save_format, filename):
+ input_data = tf.constant([" airplane at airport"])
+
+ inputs = keras.Input(dtype="string", shape=())
+ outputs = self.tokenizer(inputs)
+ model = keras.Model(inputs, outputs)
+
+ path = os.path.join(self.get_temp_dir(), filename)
+ model.save(path, save_format=save_format)
+
+ restored_model = keras.models.load_model(path)
+ self.assertAllEqual(
+ model(input_data),
+ restored_model(input_data),
+ )