Add BartTokenizer and BART Presets #685

Merged · 4 commits · Feb 2, 2023
1 change: 1 addition & 0 deletions keras_nlp/models/__init__.py
@@ -16,6 +16,7 @@
from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor
from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
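With this export, the new tokenizer is available directly from the `keras_nlp.models` namespace alongside the backbone. A minimal sketch of the two equivalent import paths:

# Public, re-exported path added by this change.
from keras_nlp.models import BartTokenizer
# Fully qualified module path, equivalent to the above.
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer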
8 changes: 8 additions & 0 deletions keras_nlp/models/bart/bart_backbone.py
@@ -14,13 +14,17 @@

"""BART backbone model."""

import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.layers.position_embedding import PositionEmbedding
from keras_nlp.layers.transformer_decoder import TransformerDecoder
from keras_nlp.layers.transformer_encoder import TransformerEncoder
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.bart.bart_presets import backbone_presets
from keras_nlp.utils.python_utils import classproperty


def bart_kernel_initializer(stddev=0.02):
@@ -247,3 +251,7 @@ def get_config(self):
@property
def token_embedding(self):
return self.get_layer("token_embedding")

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
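The `classproperty` returns a deep copy of the preset registry so callers cannot mutate the shared configuration dict. A minimal sketch of inspecting it; the preset names come from `bart_presets.py` below:

from keras_nlp.models.bart.bart_backbone import BartBackbone

# List available preset names and read one config value. A deep copy is
# returned each time, so editing the result does not affect the registry.
print(list(BartBackbone.presets))  # ['bart_base_en', 'bart_large_en']
print(BartBackbone.presets["bart_base_en"]["config"]["hidden_dim"])  # 768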
63 changes: 63 additions & 0 deletions keras_nlp/models/bart/bart_presets.py
@@ -0,0 +1,63 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BART model preset configurations."""

backbone_presets = {
"bart_base_en": {
"config": {
"vocabulary_size": 50265,
"num_layers": 6,
"num_heads": 12,
"hidden_dim": 768,
"intermediate_dim": 3072,
"dropout": 0.1,
"max_sequence_length": 1024,
},
"preprocessor_config": {},
"description": (
"Base size of BART where case is maintained. "
"Trained on a 160GB English dataset comprising BookCorpus, "
"English Wikipedia and CommonCrawl."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/model.h5",
"weights_hash": "5b59403f0cafafbd89680e0785791163",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/vocab.json",
"vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_base_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
},
"bart_large_en": {
"config": {
"vocabulary_size": 50265,
"num_layers": 12,
"num_heads": 16,
"hidden_dim": 1024,
"intermediate_dim": 4096,
"dropout": 0.1,
"max_sequence_length": 1024,
},
"preprocessor_config": {},
"description": (
"Large size of BART where case is maintained. "
"Trained on a 160GB English dataset comprising BookCorpus, "
"English Wikipedia and CommonCrawl."
),
"weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/model.h5",
"weights_hash": "6bfe7e591af8c5699ce6f9f18753af9a",
"vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/vocab.json",
"vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
"merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt",
"merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
},
}
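Each preset entry above bundles a backbone config with URLs and checksums for the weights, vocabulary, and merges files; `from_preset()` reads this registry to download and build a model. A minimal usage sketch, mirroring the calls exercised in the smoke test below (requires network access for the downloads):

from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer

# Build the tokenizer and backbone from the "bart_base_en" preset.
tokenizer = BartTokenizer.from_preset("bart_base_en")
backbone = BartBackbone.from_preset("bart_base_en")

# Tokenize a raw string into BART token ids.
token_ids = tokenizer("The quick brown fox.")  # [133, 2119, 6219, 23602, 4]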
139 changes: 139 additions & 0 deletions keras_nlp/models/bart/bart_presets_test.py
@@ -0,0 +1,139 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for loading pretrained model presets."""

import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras_nlp.models.bart.bart_backbone import BartBackbone
from keras_nlp.models.bart.bart_tokenizer import BartTokenizer


@pytest.mark.large
class BartPresetSmokeTest(tf.test.TestCase, parameterized.TestCase):
"""
A smoke test for BART presets we run continuously.

This only tests the smallest weights we have available. Run with:
`pytest keras_nlp/models/bart/bart_presets_test.py --run_large`
"""

def test_tokenizer_output(self):
tokenizer = BartTokenizer.from_preset(
"bart_base_en",
)
outputs = tokenizer("The quick brown fox.")
expected_outputs = [133, 2119, 6219, 23602, 4]
self.assertAllEqual(outputs, expected_outputs)

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_backbone_output(self, load_weights):
input_data = {
"encoder_token_ids": tf.constant([[0, 133, 2119, 2]]),
"encoder_padding_mask": tf.constant([[1, 1, 1, 1]]),
"decoder_token_ids": tf.constant([[0, 7199, 14, 2119, 2]]),
"decoder_padding_mask": tf.constant([[1, 1, 1, 1, 1]]),
}
model = BartBackbone.from_preset(
"bart_base_en", load_weights=load_weights
)
outputs = model(input_data)
if load_weights:
encoder_output = outputs["encoder_sequence_output"][0, 0, :5]
expected_encoder_output = [-0.033, 0.013, -0.003, -0.012, -0.002]
decoder_output = outputs["decoder_sequence_output"][0, 0, :5]
expected_decoder_output = [2.516, 2.489, 0.695, 8.057, 1.245]

self.assertAllClose(
encoder_output, expected_encoder_output, atol=0.01, rtol=0.01
)
self.assertAllClose(
decoder_output, expected_decoder_output, atol=0.01, rtol=0.01
)

@parameterized.named_parameters(
("bart_tokenizer", BartTokenizer),
("bart", BartBackbone),
)
def test_preset_docstring(self, cls):
"""Check we did our docstring formatting correctly."""
for name in cls.presets:
self.assertRegex(cls.from_preset.__doc__, name)

@parameterized.named_parameters(
("bart_tokenizer", BartTokenizer),
("bart", BartBackbone),
)
def test_unknown_preset_error(self, cls):
# Not a preset name
with self.assertRaises(ValueError):
cls.from_preset("bart_base_en_clowntown")


@pytest.mark.extra_large
class BartPresetFullTest(tf.test.TestCase, parameterized.TestCase):
"""
Test the full enumeration of our presets.

This tests every BART preset and is only run manually.
Run with:
`pytest keras_nlp/models/bart/bart_presets_test.py --run_extra_large`
"""

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
def test_load_bart(self, load_weights):
for preset in BartBackbone.presets:
model = BartBackbone.from_preset(preset, load_weights=load_weights)
input_data = {
"encoder_token_ids": tf.random.uniform(
shape=(1, 1024),
dtype=tf.int64,
maxval=model.vocabulary_size,
),
"encoder_padding_mask": tf.constant(
[1] * 768 + [0] * 256, shape=(1, 1024)
),
"decoder_token_ids": tf.random.uniform(
shape=(1, 1024),
dtype=tf.int64,
maxval=model.vocabulary_size,
),
"decoder_padding_mask": tf.constant(
[1] * 489 + [0] * 535, shape=(1, 1024)
),
}
model(input_data)

def test_load_tokenizers(self):
for preset in BartTokenizer.presets:
tokenizer = BartTokenizer.from_preset(preset)
tokenizer("The quick brown fox.")
120 changes: 120 additions & 0 deletions keras_nlp/models/bart/bart_tokenizer.py
@@ -0,0 +1,120 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BART tokenizer."""

import copy

from tensorflow import keras

from keras_nlp.models.bart.bart_presets import backbone_presets
from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class BartTokenizer(BytePairTokenizer):
"""A BART tokenizer using Byte-Pair Encoding subword segmentation.

This tokenizer class will tokenize raw strings into integer sequences and
is based on `keras_nlp.tokenizers.BytePairTokenizer`. Unlike the
underlying tokenizer, it will check for all special tokens needed by BART
models and provide a `from_preset()` method to automatically download
a matching vocabulary for a BART preset.

This tokenizer does not provide truncation or padding of inputs. It can be
combined with a `keras_nlp.models.BartPreprocessor` layer for input
packing.

If input is a batch of strings (rank > 0), the layer will output a
`tf.RaggedTensor` where the last dimension of the output is ragged.

If input is a scalar string (rank == 0), the layer will output a dense
`tf.Tensor` with static shape `[None]`.

Args:
vocabulary: string or dict, maps token to integer ids. If it is a
string, it should be the file path to a json file.
merges: string or list, contains the merge rule. If it is a string,
it should be the file path to merge rules. The merge rule file
should have one merge rule per line. Every merge rule contains
merge entities separated by a space.

Examples:

Batched inputs.
>>> vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
>>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
>>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
>>> merges += ["Ġ f", "o x", "Ġf ox"]
>>> tokenizer = keras_nlp.models.BartTokenizer(
... vocabulary=vocab, merges=merges
... )
>>> tokenizer(["a quick fox", "a fox quick"])
<tf.RaggedTensor [[4, 5, 6], [4, 6, 5]]>

Unbatched input.
>>> vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
>>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
>>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
>>> merges += ["Ġ f", "o x", "Ġf ox"]
>>> tokenizer = keras_nlp.models.BartTokenizer(
... vocabulary=vocab, merges=merges
... )
>>> tokenizer("a quick fox")
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>

Detokenization.
>>> vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3}
>>> vocab = {**vocab, "a": 4, "Ġquick": 5, "Ġfox": 6}
>>> merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
>>> merges += ["Ġ f", "o x", "Ġf ox"]
>>> tokenizer = keras_nlp.models.BartTokenizer(
... vocabulary=vocab, merges=merges
... )
>>> tokenizer.detokenize(tokenizer("a quick fox")).numpy().decode('utf-8')
'a quick fox'
"""

def __init__(
self,
vocabulary,
merges,
**kwargs,
):
super().__init__(
vocabulary=vocabulary,
merges=merges,
**kwargs,
)

# Check for necessary special tokens.
start_token = "<s>"
pad_token = "<pad>"
end_token = "</s>"
for token in [start_token, pad_token, end_token]:
if token not in self.get_vocabulary():
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self.start_token_id = self.token_to_id(start_token)
self.pad_token_id = self.token_to_id(pad_token)
self.end_token_id = self.token_to_id(end_token)

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)
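Like the backbone, the tokenizer shares the same `backbone_presets` registry, so tokenizer and backbone presets stay in sync. A minimal sketch of constructing the tokenizer from an in-memory vocabulary, based on the docstring example above, and reading back the special token ids the constructor resolves:

from keras_nlp.models.bart.bart_tokenizer import BartTokenizer

vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "<mask>": 3, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

tokenizer = BartTokenizer(vocabulary=vocab, merges=merges)
# The constructor verifies "<s>", "<pad>" and "</s>" exist and records their ids.
print(tokenizer.start_token_id, tokenizer.pad_token_id, tokenizer.end_token_id)  # 0 1 2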