diff --git a/keras_nlp/tokenizers/__init__.py b/keras_nlp/tokenizers/__init__.py
index d6eb0b6382..7206410a14 100644
--- a/keras_nlp/tokenizers/__init__.py
+++ b/keras_nlp/tokenizers/__init__.py
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer
 from keras_nlp.tokenizers.tokenizer import Tokenizer
 from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer
diff --git a/keras_nlp/tokenizers/byte_tokenizer.py b/keras_nlp/tokenizers/byte_tokenizer.py
new file mode 100644
index 0000000000..2f142004de
--- /dev/null
+++ b/keras_nlp/tokenizers/byte_tokenizer.py
@@ -0,0 +1,256 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Byte Tokenizer."""
+
+from typing import Any
+from typing import Dict
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_text as tf_text
+
+from keras_nlp.tokenizers import tokenizer
+
+
+class ByteTokenizer(tokenizer.Tokenizer):
+    """Raw byte tokenizer.
+
+    This tokenizer is a vocabulary-free tokenizer which will tokenize text
+    as raw bytes from [0, 256).
+
+    If input is a batch of strings:
+    By default, the layer will output a `tf.RaggedTensor` where the last
+    dimension of the output is ragged. If `sequence_length` is set, the layer
+    will output a dense `tf.Tensor` where all inputs have been padded or
+    truncated to `sequence_length`. The output dtype can be controlled via the
+    `dtype` argument, which should be an integer type
+    (tf.int16, tf.int32, etc.).
+
+    If input is a scalar string:
+    There are two cases here. If `sequence_length` is set, the output will be
+    a dense `tf.Tensor` of shape `[sequence_length]`. Otherwise, the output
+    will be a dense `tf.Tensor` of shape `[None]`.
+
+    Args:
+        lowercase: boolean. If True, the input text will be converted to
+            lowercase before tokenization.
+        sequence_length: int. If set, the output will be converted to a dense
+            tensor and padded/trimmed so all outputs are of `sequence_length`.
+        normalization_form: string. One of the following values: (None, "NFC",
+            "NFKC", "NFD", "NFKD"). If set, every UTF-8 string in the input
+            tensor text will be normalized to the given form before
+            tokenizing.
+        errors: string. One of ("strict", "replace", "ignore"). Defaults to
+            "replace". Specifies the `detokenize()` behaviour when an invalid
+            byte sequence is encountered (same behaviour as
+            https://www.tensorflow.org/api_docs/python/tf/strings/unicode_transcode).
+        replacement_char: int. Defaults to 65533. The replacement character to
+            use when an invalid byte sequence is encountered and when `errors`
+            is set to "replace" (same behaviour as
+            https://www.tensorflow.org/api_docs/python/tf/strings/unicode_transcode).
+
+    Examples:
+
+    Basic usage.
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer()
+    >>> tokenizer("hello")
+
+    Ragged outputs.
+    >>> inputs = tf.constant(["hello", "hi"])
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer()
+    >>> tokenizer(inputs)
+
+    Dense outputs.
+    >>> inputs = tf.constant(["hello", "hi"])
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer(sequence_length=8)
+    >>> tokenizer(inputs)
+
+    Dense outputs with a single-element batch.
+    >>> inputs = tf.constant(["hello"])
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer(sequence_length=8)
+    >>> tokenizer(inputs)
+
+    Tokenize first, then batch the dataset up.
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer()
+    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
+    >>> ds = ds.map(tokenizer)
+    >>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
+    >>> ds.take(1).get_single_element()
+
+    Batch up the inputs and then tokenize.
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer()
+    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
+    >>> ds = ds.batch(2).map(tokenizer)
+    >>> ds.take(1).get_single_element()
+
+    Tokenize first, then batch the dataset up (`sequence_length` provided).
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer(sequence_length=5)
+    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
+    >>> ds = ds.map(tokenizer)
+    >>> ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(2))
+    >>> ds.take(1).get_single_element()
+
+    Batch up the inputs and then tokenize (`sequence_length` provided).
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer(sequence_length=5)
+    >>> ds = tf.data.Dataset.from_tensor_slices(["hello", "fun"])
+    >>> ds = ds.batch(2).map(tokenizer)
+    >>> ds.take(1).get_single_element()
+
+    Detokenization.
+    >>> inputs = tf.constant([104, 101, 108, 108, 111], dtype=tf.int32)
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer()
+    >>> tokenizer.detokenize(inputs)
+
+    Detokenization with invalid bytes.
+    >>> # The 255 below is invalid utf-8.
+    >>> inputs = tf.constant([104, 101, 255, 108, 108, 111], dtype=tf.int32)
+    >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer(
+    ...     errors="replace", replacement_char=88)
+    >>> tokenizer.detokenize(inputs).numpy().decode('utf-8')
+    'heXllo'
+    """
+
+    def __init__(
+        self,
+        lowercase: bool = True,
+        sequence_length: int = None,
+        normalization_form: str = None,
+        errors: str = "replace",
+        replacement_char: int = 65533,
+        **kwargs,
+    ):
+        # Check dtype and provide a default.
+        if "dtype" not in kwargs or kwargs["dtype"] is None:
+            kwargs["dtype"] = tf.int32
+        else:
+            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
+            if not dtype.is_integer:
+                raise ValueError(
+                    "Output dtype must be an integer type. "
+                    f"Received: dtype={dtype}"
+                )
+
+        # Check normalization_form.
+        if normalization_form not in (None, "NFC", "NFKC", "NFD", "NFKD"):
+            raise ValueError(
+                '`normalization_form` must be one of None, "NFC", "NFKC", '
+                '"NFD", "NFKD". Received: normalization_form='
+                f"{normalization_form}"
+            )
+
+        # Check errors.
+        if errors not in ("strict", "replace", "ignore"):
+            raise ValueError(
+                '`errors` must be one of "strict", "replace", "ignore". '
+                f"Received: errors={errors}"
+            )
+
+        super().__init__(**kwargs)
+
+        self.lowercase = lowercase
+        self.sequence_length = sequence_length
+        self.normalization_form = normalization_form
+        self.errors = errors
+        self.replacement_char = replacement_char
+
+        self._char_lst = tf.constant(
+            [i.tobytes() for i in np.arange(256, dtype=np.uint8)]
+        )
+
+    def vocabulary_size(self) -> int:
+        """Get the size of the tokenizer vocabulary."""
+        return 256
+
+    def tokenize(self, inputs):
+
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        # Optional: Lowercase the input.
+        if self.lowercase:
+            inputs = tf_text.case_fold_utf8(inputs)
+
+        # Optional: Normalize unicode.
+        if self.normalization_form is not None:
+            inputs = tf_text.normalize_utf8(inputs, self.normalization_form)
+
+        # Tokenize input strings.
+        tokens = tf.strings.bytes_split(inputs)
+        tokens = tf.squeeze(
+            tf.ragged.map_flat_values(tf.io.decode_raw, tokens, tf.uint8), -1
+        )
+        tokens = tf.cast(tokens, self.compute_dtype)
+
+        # Convert to a dense output if `sequence_length` is set.
+        if self.sequence_length:
+            output_shape = tokens.shape.as_list()
+            output_shape[-1] = self.sequence_length
+            tokens = tokens.to_tensor(shape=output_shape)
+
+        if scalar_input:
+            tokens = tf.squeeze(tokens, 0)
+        return tokens
+
+    def detokenize(self, inputs):
+        # Remove trailing padding tokens, so that trailing "\x00" bytes don't
+        # show up in the detokenized output.
+        inputs = tf.ragged.boolean_mask(inputs, tf.not_equal(inputs, 0))
+
+        decoded = tf.strings.reduce_join(
+            tf.gather(self._char_lst, inputs), axis=-1
+        )
+
+        # Handle errors if an invalid byte sequence is encountered.
+        decoded = tf.strings.unicode_transcode(
+            decoded,
+            "UTF-8",
+            "UTF-8",
+            errors=self.errors,
+            replacement_char=self.replacement_char,
+        )
+        return decoded
+
+    def get_config(self) -> Dict[str, Any]:
+        config = super().get_config()
+        config.update(
+            {
+                "lowercase": self.lowercase,
+                "sequence_length": self.sequence_length,
+                "normalization_form": self.normalization_form,
+                "errors": self.errors,
+                "replacement_char": self.replacement_char,
+            }
+        )
+        return config
diff --git a/keras_nlp/tokenizers/byte_tokenizer_test.py b/keras_nlp/tokenizers/byte_tokenizer_test.py
new file mode 100644
index 0000000000..80b86bc702
--- /dev/null
+++ b/keras_nlp/tokenizers/byte_tokenizer_test.py
@@ -0,0 +1,266 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from tensorflow import keras
+
+from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer
+
+
+class ByteTokenizerTest(tf.test.TestCase):
+    def test_tokenize(self):
+        input_data = tf.constant(["hello", "fun", "▀▁▂▃"])
+        tokenizer = ByteTokenizer()
+        call_output = tokenizer(input_data)
+        tokenize_output = tokenizer.tokenize(input_data)
+        self.assertIsInstance(call_output, tf.RaggedTensor)
+        exp_outputs = [
+            [104, 101, 108, 108, 111],
+            [102, 117, 110],
+            [226, 150, 128, 226, 150, 129, 226, 150, 130, 226, 150, 131],
+        ]
+        for i in range(call_output.shape[0]):
+            self.assertAllEqual(call_output[i], exp_outputs[i])
+            self.assertAllEqual(tokenize_output[i], exp_outputs[i])
+
+    def test_tokenize_scalar(self):
+        input_data = "hello"
+        tokenizer = ByteTokenizer()
+        call_output = tokenizer(input_data)
+        tokenize_output = tokenizer.tokenize(input_data)
+
+        self.assertAllEqual(call_output, [104, 101, 108, 108, 111])
+        self.assertAllEqual(tokenize_output, [104, 101, 108, 108, 111])
+
+    def test_dense_output(self):
+        input_data = tf.constant(["hello", "fun", "▀▁▂▃"])
+        tokenizer = ByteTokenizer(sequence_length=10)
+        call_output = tokenizer(input_data)
+        self.assertIsInstance(call_output, tf.Tensor)
+        self.assertAllEqual(
+            call_output,
+            [
+                [104, 101, 108, 108, 111, 0, 0, 0, 0, 0],
+                [102, 117, 110, 0, 0, 0, 0, 0, 0, 0],
+                [226, 150, 128, 226, 150, 129, 226, 150, 130, 226],
+            ],
+        )
+
+    def test_detokenize(self):
+        input_data = tf.ragged.constant(
+            [
+                [104, 101, 108, 108, 111],
+                [102, 117, 110],
+                [226, 150, 128, 226, 150, 129, 226, 150, 130, 226, 150, 131],
+            ]
+        )
+
+        tokenizer = ByteTokenizer()
+        detokenize_output = tokenizer.detokenize(input_data)
+        self.assertAllEqual(detokenize_output, ["hello", "fun", "▀▁▂▃"])
+
+    def test_detokenize_replace_error(self):
+        # The [226, 150] pair below is an incomplete (invalid) UTF-8 sequence.
+        input_data = tf.ragged.constant([[104, 101, 226, 150, 108, 108, 111]])
+
+        tokenizer = ByteTokenizer(errors="replace", replacement_char=341)
+        detokenize_output = tokenizer.detokenize(input_data)
+        self.assertAllEqual(detokenize_output, [b"he\xc5\x95llo"])
+
+    def test_detokenize_ignore_error(self):
+        input_data = tf.ragged.constant([[104, 101, 226, 150, 108, 108, 111]])
+
+        tokenizer = ByteTokenizer(errors="ignore")
+        detokenize_output = tokenizer.detokenize(input_data)
+        self.assertAllEqual(detokenize_output, [b"hello"])
+
+    def test_detokenize_strict_error(self):
+        input_data = tf.ragged.constant([[104, 101, 226, 150, 108, 108, 111]])
+
+        tokenizer = ByteTokenizer(errors="strict")
+        with self.assertRaises(tf.errors.InvalidArgumentError):
+            _ = tokenizer.detokenize(input_data)
+
+    def test_vocab_size(self):
+        tokenizer = ByteTokenizer()
+        self.assertEqual(tokenizer.vocabulary_size(), 256)
+
+    def test_lowercase(self):
+        input_data = tf.constant(["HeLlO wOrLd"])
+        tokenizer = ByteTokenizer()
+        call_output = tokenizer(input_data)
+        self.assertAllEqual(
+            call_output,
+            [[104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]],
+        )
+
+    def test_skip_lowercase(self):
+        input_data = tf.constant(["HeLlO wOrLd"])
+        tokenizer = ByteTokenizer(lowercase=False)
+        call_output = tokenizer(input_data)
+        self.assertAllEqual(
+            call_output, [[72, 101, 76, 108, 79, 32, 119, 79, 114, 76, 100]]
+        )
+
+    def test_tokenize_first_batch_second(self):
+        tokenizer = ByteTokenizer()
+
+        ds = tf.data.Dataset.from_tensor_slices(
+            ["hello", "fun", "▀▁▂▃", "haha"]
+        )
+        ds = ds.map(tokenizer)
+        ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(4))
+        output = ds.take(1).get_single_element()
+
+        exp_output = [
+            [104, 101, 108, 108, 111],
+            [102, 117, 110],
+            [226, 150, 128, 226, 150, 129, 226, 150, 130, 226, 150, 131],
+            [104, 97, 104, 97],
+        ]
+        for i in range(output.shape[0]):
+            print(output[i])
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_tokenize_first_batch_second_with_sequence_length(self):
+        tokenizer = ByteTokenizer(sequence_length=10)
+
+        ds = tf.data.Dataset.from_tensor_slices(
+            ["hello", "fun", "▀▁▂▃", "haha"]
+        )
+        ds = ds.map(tokenizer)
+        ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(4))
+        output = ds.take(1).get_single_element()
+
+        exp_output = [
+            [104, 101, 108, 108, 111, 0, 0, 0, 0, 0],
+            [102, 117, 110, 0, 0, 0, 0, 0, 0, 0],
+            [226, 150, 128, 226, 150, 129, 226, 150, 130, 226],
+            [104, 97, 104, 97, 0, 0, 0, 0, 0, 0],
+        ]
+        for i in range(output.shape[0]):
+            print(output[i])
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_batch_first_tokenize_second(self):
+        tokenizer = ByteTokenizer()
+
+        ds = tf.data.Dataset.from_tensor_slices(
+            ["hello", "fun", "▀▁▂▃", "haha"]
+        )
+        ds = ds.batch(4).map(tokenizer)
+        output = ds.take(1).get_single_element()
+
+        exp_output = [
+            [104, 101, 108, 108, 111],
+            [102, 117, 110],
+            [226, 150, 128, 226, 150, 129, 226, 150, 130, 226, 150, 131],
+            [104, 97, 104, 97],
+        ]
+        for i in range(output.shape[0]):
+            print(output[i])
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_batch_first_tokenize_second_with_sequence_length(self):
+        tokenizer = ByteTokenizer(sequence_length=10)
+
+        ds = tf.data.Dataset.from_tensor_slices(
+            ["hello", "fun", "▀▁▂▃", "haha"]
+        )
+        ds = ds.batch(4).map(tokenizer)
+        output = ds.take(1).get_single_element()
+
+        exp_output = [
+            [104, 101, 108, 108, 111, 0, 0, 0, 0, 0],
+            [102, 117, 110, 0, 0, 0, 0, 0, 0, 0],
+            [226, 150, 128, 226, 150, 129, 226, 150, 130, 226],
+            [104, 97, 104, 97, 0, 0, 0, 0, 0, 0],
+        ]
+        for i in range(output.shape[0]):
+            print(output[i])
+            self.assertAllEqual(output[i], exp_output[i])
+
+    def test_functional_model(self):
+        input_data = tf.constant(["hello", "fun", "▀▁▂▃"])
+        tokenizer = ByteTokenizer()
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = tokenizer.detokenize(tokenizer.tokenize(inputs))
+        model = keras.Model(inputs, outputs)
+        model_output = model(input_data)
+        self.assertAllEqual(model_output, ["hello", "fun", "▀▁▂▃"])
+
+    def test_load_model_with_config(self):
+        input_data = tf.constant(["hello"])
+
+        original_tokenizer = ByteTokenizer(
+            lowercase=False,
+            sequence_length=8,
+            normalization_form="NFC",
+            errors="ignore",
+        )
+        cloned_tokenizer = ByteTokenizer.from_config(
+            original_tokenizer.get_config()
+        )
+        self.assertAllEqual(
+            original_tokenizer(input_data),
+            cloned_tokenizer(input_data),
+        )
+
+        decoded_input = [[104, 101, 226, 150, 108, 108, 111]]
+        self.assertAllEqual(
+            original_tokenizer.detokenize(decoded_input),
+            cloned_tokenizer.detokenize(decoded_input),
+        )
+
+    def test_config(self):
+
+        tokenizer = ByteTokenizer(
+            name="byte_tokenizer_config_test",
+            lowercase=False,
+            sequence_length=8,
+            normalization_form="NFC",
+            errors="ignore",
+            replacement_char=0,
+        )
+        exp_config = {
+            "dtype": "int32",
+            "errors": "ignore",
+            "lowercase": False,
+            "name": "byte_tokenizer_config_test",
+            "normalization_form": "NFC",
+            "replacement_char": 0,
+            "sequence_length": 8,
+            "trainable": True,
+        }
+        self.assertEqual(tokenizer.get_config(), exp_config)
+
+    def test_saving(self):
+        input_data = tf.constant(["this is fun"])
+
+        tokenizer = ByteTokenizer(
+            name="byte_tokenizer_config_test",
+            lowercase=False,
+            sequence_length=20,
+            normalization_form="NFKC",
+            errors="replace",
+        )
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = tokenizer(inputs)
+        model = keras.Model(inputs, outputs)
+        model.save(self.get_temp_dir())
+        restored_model = keras.models.load_model(self.get_temp_dir())
+        self.assertAllEqual(
+            model(input_data),
+            restored_model(input_data),
+        )
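
A minimal usage sketch of the `ByteTokenizer` added above (not part of the patch; it assumes `keras_nlp` is installed with this change applied). The byte values in the comments follow from the UTF-8 encodings exercised in the tests.

import tensorflow as tf

from keras_nlp.tokenizers import ByteTokenizer

# Tokenize a batch of strings into ragged byte IDs in [0, 256).
tokenizer = ByteTokenizer()
print(tokenizer(tf.constant(["hello", "fun"])))
# <tf.RaggedTensor [[104, 101, 108, 108, 111], [102, 117, 110]]>

# With `sequence_length` set, outputs are padded or truncated to a dense shape.
padded = ByteTokenizer(sequence_length=8)(tf.constant(["hello", "fun"]))
print(padded.shape)  # (2, 8)

# detokenize() strips the zero padding and maps the bytes back to strings.
print(tokenizer.detokenize(padded))
# tf.Tensor([b'hello' b'fun'], shape=(2,), dtype=string)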