From 759a405177cede270531f88022bebcac46bf2611 Mon Sep 17 00:00:00 2001
From: Chen Qian
Date: Fri, 21 Oct 2022 10:51:19 -0700
Subject: [PATCH 1/7] BPE tokenizer (#389)

Add more test cases.

Co-authored-by: jessechancy

add merge file

Make cache a tf module

Delete testdata

address comments

address comments

fix docstring

fix docstring
---
 keras_nlp/tokenizers/__init__.py              |   1 +
 keras_nlp/tokenizers/byte_pair_tokenizer.py   | 547 ++++++++++++++++++
 .../tokenizers/byte_pair_tokenizer_test.py    | 131 +++++
 3 files changed, 679 insertions(+)
 create mode 100644 keras_nlp/tokenizers/byte_pair_tokenizer.py
 create mode 100644 keras_nlp/tokenizers/byte_pair_tokenizer_test.py

diff --git a/keras_nlp/tokenizers/__init__.py b/keras_nlp/tokenizers/__init__.py
index c033b5061f..1147b03375 100644
--- a/keras_nlp/tokenizers/__init__.py
+++ b/keras_nlp/tokenizers/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
 from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer
 from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
 from keras_nlp.tokenizers.tokenizer import Tokenizer
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
new file mode 100644
index 0000000000..def254e8a5
--- /dev/null
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -0,0 +1,547 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Byte-pair encoder implementation.
+
+This file implements the same logic as openai BPE:
+https://github.com/openai/gpt-2/blob/master/src/encoder.py,
+but is TF graph compatible.
+"""
+
+import json
+from typing import Iterable
+from typing import List
+
+import tensorflow as tf
+import tensorflow_text as tf_text
+from tensorflow import keras
+
+from keras_nlp.tokenizers import tokenizer
+
+# As python and TF handles special spaces differently, we need to
+# manually handle special spaces during string split.
+SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}"
+
+# String splitting regex pattern.
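+# As a sketch of the intended behavior, a string like "I've 12 cats!"
+# should split into ["I", "'ve", " 12", " cats", "!"]: contraction
+# suffixes, runs of letters, runs of digits, and runs of punctuation each
+# become one piece, with a single leading space kept attached to the
+# piece that follows it.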
+SPLIT_PATTERN_1 = r"""'s|'t|'re|'ve|'m|'ll|'d
+    |[\s{special_spaces}]+[\n\r\t\f६{special_spaces}]| ?\p{L}+
+    | ?[\p{N}]+| ?[^\s\p{L}\p{N}{special_spaces}]+""".replace(
+    "{special_spaces}", SPECIAL_WHITESPACES
+)
+
+SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
+
+
+def bytes_to_unicode():
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    # removes mapping an int to a whitespace character
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    bs = [n.to_bytes(1, "little") for n in bs]
+    return bs, cs  # int to string mapping
+
+
+def remove_strings_from_inputs(tensor, string_to_remove):
+    """Remove certain strings from input tensor."""
+    non_empty_mask = tensor != string_to_remove
+    flatten_indexes = tf.where(non_empty_mask)
+    flatten_result = tf.gather_nd(tensor, flatten_indexes)
+    row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, tf.int64), axis=1)
+    result = tf.RaggedTensor.from_row_lengths(
+        values=flatten_result,
+        row_lengths=row_lengths,
+    )
+    return result
+
+
+def split_strings_for_bpe(inputs):
+    # We need to recreate the exact behavior of token presplitting in the
+    # original gpt2 tokenizer which uses a lookahead. As re2 does not
+    # support lookahead match, we are using an alternative insert a special
+    # token "६" before leading space of non-space characters and after the
+    # trailing space, e.g., " keras" will be "६ keras".
+    inputs = tf.strings.regex_replace(
+        inputs, rf"( )([^\s{SPECIAL_WHITESPACES}])", r"६\1\2"
+    )
+    inputs = tf.strings.regex_replace(
+        inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
+    )
+    raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1)
+    # Second pass splits out the last whilespace char or "६".
+    raw_tokens = tf_text.regex_split(
+        raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2
+    )
+    if raw_tokens.shape.rank > 2:
+        raw_tokens = raw_tokens.merge_dims(1, 2)
+    return remove_strings_from_inputs(raw_tokens, "६")
+
+
+class BytePairTokenizerCache(tf.Module):
+    """Cache that stores the encoded result of seen tokens.
+
+    The cache key is string tensor or python strings, and the value is split
+    tokens joined by whitespace. For example, "dragonfly" => "dragon fly"
+
+    Examples:
+    ```
+    cache = BytePairTokenizerCache()
+    cache.insert(["butterfly", "dragonfly"], ["but ter fly", "dragon fly"])
+    cache.lookup(["butterfly"])
+    ```
+    """
+
+    def __init__(self):
+        # `tf.lookup.experimental.MutableHashTable` does not support string to
+        # string mapping. So we first convert to string to an integer key, and
+        # use the integer key to find the value.
+        self.factors = tf.pow(256, tf.range(0, 8, dtype=tf.int64))
+        self.id2value = tf.lookup.experimental.MutableHashTable(
+            tf.int64, tf.string, ""
+        )
+
+    def _get_key(self, keys):
+        """Get the hash key for given inputs."""
+        # `tf.fingerprint` converts token to a array of uint8 of length 8, we
+        # need to convert it to a uint64.
+        return tf.squeeze(
+            tf.matmul(
+                tf.cast(tf.fingerprint(keys), dtype=tf.int64),
+                self.factors[:, tf.newaxis],
+            ),
+            -1,
+        )
+
+    def lookup(self, keys):
+        """Look up the encoded outputs of given tokens."""
+        ids = self._get_key(keys)
+        result = self.id2value.lookup(ids)
+        # Ensure output shape for graph mode.
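+        # (Inside a tf.function, the static shape of a MutableHashTable
+        # lookup is unknown, so we re-assert the rank-1 shape here for the
+        # ops that consume the cached values.)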
+        result.set_shape([None])
+        return result
+
+    def insert(self, keys, values):
+        """Insert token <=> encoded outputs pairs."""
+        self.id2value.insert(self._get_key(keys), values)
+
+
+def create_static_hashtable(keys, values, default):
+    return tf.lookup.StaticHashTable(
+        tf.lookup.KeyValueTensorInitializer(
+            tf.convert_to_tensor(keys),
+            tf.convert_to_tensor(values),
+        ),
+        default_value=default,
+    )
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class BytePairTokenizer(tokenizer.Tokenizer):
+    """Bype-pair encoding tokenizer layer.
+
+    This BPE tokenizer provides the same funtionality as official GPT2
+    tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
+    which describes BPE merge rules, it should provide the same output
+    as openai implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
+    Different from openai, this implementation is graph-compatible, so you can
+    use it within a tf.data pipeline.
+
+    If input is a batch of strings (rank > 0):
+        By default, the layer will output a `tf.RaggedTensor` where the last
+        dimension of the output is ragged. If `sequence_length` is set, the layer
+        will output a dense `tf.Tensor` where all inputs have been padded or
+        truncated to `sequence_length`.
+    If input is a scalar string (rank == 0):
+        By default, the layer will output a dense `tf.Tensor` with static shape
+        `[None]`. If `sequence_length` is set, the output will be
+        a dense `tf.Tensor` of shape `[sequence_length]`.
+
+    Args:
+        vocabulary: string or dict, maps token to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rule. If it is a string,
+            it should be the file path to merge rules. The merge rule file
+            should have one merge rule per line.
+        sequence_length: int, defaults to None. If set, the output will be
+            padded or truncated to the `sequence_length`.
+
+    Examples:
+
+    Use in-momery vocabulary and merge list.
+
+    >>> vocab = {"butter": 1, "fly": 2}
+    >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
+    >>> tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocab, merge)
+    >>> tokenizer("butterfly")
+    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>
+    >>> tokenizer(["butterfly"])
+    <tf.RaggedTensor [[1, 2]]>
+    >>> tokenizer(["butterfly", "butter"])
+    <tf.RaggedTensor [[1, 2], [1]]>
+    >>> tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
+    ...     vocab, merge, sequence_length=2)
+    >>> tokenizer(["butterfly", "butter"])
+    <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
+    array([[1, 2], [1, 0]], dtype=int32)>
+
+    Use hosted vocabluary and merge list.
+
+    ```python
+    vocab_path = tf.keras.utils.get_file(
+        "vocab.json",
+        "https://storage.googleapis.com/keras-nlp/models/roberta_base/vocab.json",
+    )
+    merge_path = tf.keras.utils.get_file(
+        "merges.txt",
+        "https://storage.googleapis.com/keras-nlp/models/roberta_base/merges.txt",
+    )
+    tokenizer = BytePairTokenizer(
+        vocabulary=vocab_path, merges=merge_path
+    )
+    tokenizer("Butterfly is not flying butter!")
+    ```
+
+    Detokenize
+    >>> vocab = {"butter": 1, "fly": 2}
+    >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
+    >>> tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocab, merge)
+    >>> tokenizer.detokenize([[1, 2]])
+    <tf.Tensor: shape=(1,), dtype=string,
+    numpy=array([b'butterfly'], dtype=object)>
+
+    """
+
+    def __init__(
+        self,
+        vocabulary,
+        merges,
+        sequence_length=None,
+        **kwargs,
+    ) -> None:
+        # Check dtype and provide a default.
+        if "dtype" not in kwargs or kwargs["dtype"] is None:
+            kwargs["dtype"] = tf.int32
+        else:
+            dtype = tf.dtypes.as_dtype(kwargs["dtype"])
+            if not dtype.is_integer:
+                raise ValueError(
+                    "Output dtype must be an integer type or a string. "
+                    f"Received: `dtype={dtype}`"
+                )
+
+        super().__init__(**kwargs)
+
+        if isinstance(vocabulary, str):
+            with open(vocabulary, "r") as f:
+                self.vocabulary = json.load(f)
+        elif isinstance(vocabulary, dict):
+            self.vocabulary = vocabulary.copy()
+        else:
+            raise ValueError(
+                "Vocabulary must be an file path or dictionary mapping string "
+                f"token to int ids. Received: `type(vocabulary)={type(vocabulary)}`."
+            )
+        if isinstance(merges, str):
+            self.merges = [bp.rstrip() for bp in tf.io.gfile.GFile(merges)]
+        elif isinstance(merges, Iterable):
+            self.merges = list(merges)
+        else:
+            raise ValueError(
+                "Merges must be a file path or a list of merge rules. "
+                f"Received: `type(merges)={type(merges)}`"
+            )
+        self.sequence_length = sequence_length
+
+        # Create byte <=> unicode mapping. This is useful for handling
+        # whitespace tokens.
+        byte_list, unicode_list = bytes_to_unicode()
+        self.byte2unicode = create_static_hashtable(
+            byte_list, unicode_list, default=""
+        )
+        self.unicode2byte = create_static_hashtable(
+            unicode_list, byte_list, default=""
+        )
+
+        self.cache = BytePairTokenizerCache()
+
+        # Create mapping between string tokens to int ids, and vice versa.
+        byte_pairs = [x[0] for x in self.vocabulary.items()]
+        byte_pair_encoding_indices = [x[1] for x in self.vocabulary.items()]
+        self.token_to_id_map = create_static_hashtable(
+            byte_pairs,
+            byte_pair_encoding_indices,
+            default=-1,
+        )
+        self.id_to_token_map = create_static_hashtable(
+            byte_pair_encoding_indices,
+            byte_pairs,
+            default="",
+        )
+
+        # Create ranking of merge rules, this is the same as order of merge
+        # pairs in `self.merges`.
+        self.merge_ranks_lookup_default = len(self.merges) + 1
+        self.merge_ranks = create_static_hashtable(
+            self.merges,
+            list(range(len(self.merges))),
+            default=self.merge_ranks_lookup_default,
+        )
+
+    def get_vocabulary(self) -> List[str]:
+        """Get the tokenizer vocabulary as a list of strings tokens."""
+        return self.vocabulary.keys()
+
+    def vocabulary_size(self) -> int:
+        """Get the size of the tokenizer vocabulary."""
+        return len(self.vocabulary)
+
+    def id_to_token(self, id: int) -> str:
+        """Convert an integer id to a string token."""
+        # This will be slow, but keep memory usage down compared to building a
+        # dict. Assuming the main use case is looking up a few special tokens
+        # early in the vocab, this should be fine.
+
+        keys = self.get_vocabulary()
+        for token in keys:
+            if self.vocabulary[token] == id:
+                return token
+        return None
+
+    def token_to_id(self, token: str) -> int:
+        """Convert a string token to an integer id."""
+        return self.vocabulary[token]
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                # Ideally vocabulary and merge list would be saved as plain text
+                # assets in the saved model. We have no good way to support
+                # this currently, so we save the vocabulary in the config.
+                "vocabulary": self.vocabulary,
+                "merges": self.merges,
+                "sequence_length": self.sequence_length,
+            }
+        )
+        return config
+
+    @tf.function
+    def _bpe_merge_one_step(self, words, mask):
+        """Perform one step of byte-pair merge."""
+        # Get all word pairs.
+        first, second = words[:, :-1], words[:, 1:]
+
+        # Mask empty.
+        non_empty_mask = second.nested_row_lengths()[0] != 0
+        mask = mask & non_empty_mask
+        if not tf.reduce_any(mask):
+            return [words, mask]
+        non_empty_indices = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)
+        filterd_first = tf.ragged.boolean_mask(first, mask)
+        filtered_second = tf.ragged.boolean_mask(second, mask)
+
+        # Get byte pair ranking in merge rules.
+        pairs = tf.strings.join([filterd_first, filtered_second], separator=" ")
+        pair_rank = self.merge_ranks.lookup(pairs)
+
+        # Get BPE pair ranks.
+        min_pair_rank = tf.reduce_min(pair_rank, axis=1)
+        pair_found_mask = min_pair_rank != self.merge_ranks_lookup_default
+
+        # Tokens that cannot be further merged are marked as finished.
+        mask = tf.tensor_scatter_nd_update(
+            mask, tf.expand_dims(non_empty_indices, axis=1), pair_found_mask
+        )
+        if not tf.math.reduce_any(mask):
+            return [words, mask]
+
+        masked_pair_rank = tf.ragged.boolean_mask(pair_rank, pair_found_mask)
+        min_pair_rank_indices = tf.math.argmin(
+            masked_pair_rank.to_tensor(self.merge_ranks_lookup_default), axis=1
+        )
+
+        # Get words and pairs to process.
+        unfinished_words = tf.ragged.boolean_mask(words, mask)
+
+        pair_left = tf.gather(
+            unfinished_words, min_pair_rank_indices, batch_dims=1
+        )
+        pair_right = tf.gather(
+            unfinished_words, min_pair_rank_indices + 1, batch_dims=1
+        )
+
+        merged_pairs = tf.strings.join([pair_left, pair_right])
+        empty_strs = tf.fill(tf.shape(merged_pairs), "")
+
+        unfinished_word_indices = tf.cast(
+            tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask), dtype=tf.int64
+        )
+        merged_pair_indices = tf.concat(
+            [
+                unfinished_word_indices[:, tf.newaxis],
+                min_pair_rank_indices[:, tf.newaxis],
+            ],
+            axis=1,
+        )
+        empty_string_indices = tf.concat(
+            [
+                unfinished_word_indices[:, tf.newaxis],
+                min_pair_rank_indices[:, tf.newaxis] + 1,
+            ],
+            axis=1,
+        )
+
+        tensor_words = words.to_tensor(default_value="")
+        tensor_words = tf.tensor_scatter_nd_update(
+            tensor_words,
+            merged_pair_indices,
+            merged_pairs,
+        )
+
+        words = tf.tensor_scatter_nd_update(
+            tensor_words,
+            empty_string_indices,
+            empty_strs,
+        )
+        # Remove empty strings.
+        words = remove_strings_from_inputs(words, "")
+        return [words, mask]
+
+    def _bpe_merge(self, inputs):
+        """Perform byte-pair merge for each word in the inputs."""
+        num_words = tf.shape(inputs)[0]
+
+        # Merge bytes.
+        def loop_condition(_, mask):
+            return tf.math.reduce_any(mask)
+
+        initial_mask = tf.fill((num_words,), True)
+        merged_words, _ = tf.while_loop(
+            loop_condition,
+            self._bpe_merge_one_step,
+            loop_vars=[
+                inputs,
+                initial_mask,
+            ],
+            shape_invariants=[
+                tf.TensorShape([None, None]),
+                tf.TensorShape([None]),
+            ],
+        )
+        return merged_words
+
+    def tokenize(self, inputs):
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        raw_tokens = split_strings_for_bpe(inputs)
+        token_row_splits = raw_tokens.row_splits
+        flat_tokens = raw_tokens.flat_values
+
+        # Check cache.
+        cache_lookup = self.cache.lookup(flat_tokens)
+        cache_mask = cache_lookup == ""
+
+        has_unseen_words = tf.math.reduce_any(
+            (cache_lookup == "") & (flat_tokens != "")
+        )
+
+        def process_unseen_tokens():
+            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+            self._bpe_merge_and_update_cache(unseen_tokens)
+            return self.cache.lookup(flat_tokens)
+
+        # If `has_unseen_words == True`, it means not all tokens are in cache,
+        # we will process the unseen tokens. Otherwise return the cache lookup.
+        tokenized_words = tf.cond(
+            has_unseen_words,
+            process_unseen_tokens,
+            lambda: cache_lookup,
+        )
+
+        # Encode merged tokens.
+        tokenized_words = tf.strings.split(tokenized_words, sep=" ")
+        encoding = self.token_to_id_map.lookup(tokenized_words)
+
+        # Unflatten to match input.
+        encoding = tf.RaggedTensor.from_row_splits(
+            encoding.flat_values,
+            tf.gather(encoding.row_splits, token_row_splits),
+        )
+
+        # Convert to a dense output if `sequence_length` is set.
+        if self.sequence_length:
+            output_shape = encoding.shape.as_list()
+            output_shape[-1] = self.sequence_length
+            encoding = encoding.to_tensor(shape=output_shape)
+
+        # Convert to a dense output if input in scalar
+        if scalar_input:
+            encoding = tf.squeeze(encoding, 0)
+            tf.ensure_shape(encoding, shape=[self.sequence_length])
+
+        return encoding
+
+    def detokenize(self, inputs):
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        unicode_text = tf.strings.reduce_join(
+            self.id_to_token_map.lookup(inputs), axis=1
+        )
+        split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
+        byte_text = tf.strings.reduce_join(
+            self.unicode2byte.lookup(split_unicode_text)
+        )
+
+        if not scalar_input:
+            byte_text = tf.expand_dims(byte_text, 0)
+
+        return byte_text
+
+    def _transform_bytes(self, tokens):
+        """Map token bytes to unicode using `byte2unicode`."""
+        split_bytes = tf.strings.bytes_split(tokens)
+        split_unicode = self.byte2unicode.lookup(split_bytes)
+        return split_unicode
+
+    def _bpe_merge_and_update_cache(self, tokens):
+        """Process unseen tokens and add to cache."""
+        words = self._transform_bytes(tokens)
+        tokenized_words = self._bpe_merge(words)
+
+        # For each word, join all its token by a whitespace,
+        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+        tokenized_words = tf.strings.reduce_join(
+            tokenized_words, axis=1, separator=" "
+        )
+        self.cache.insert(tokens, tokenized_words)
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
new file mode 100644
index 0000000000..243891e536
--- /dev/null
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -0,0 +1,131 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
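+
+# These tests exercise the tokenizer against the hosted RoBERTa vocabulary
+# and merge rules, which `keras.utils.get_file` downloads below, so running
+# them requires network access (the test case is marked `@pytest.mark.slow`).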
+
+import os
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+VOCAB_PATH = keras.utils.get_file(
+    None,
+    "https://storage.googleapis.com/keras-nlp/models/roberta_base/vocab.json",
+)
+MERGE_PATH = keras.utils.get_file(
+    None,
+    "https://storage.googleapis.com/keras-nlp/models/roberta_base/merges.txt",
+)
+
+
+@pytest.mark.slow
+class BytePairTokenizerTest(tf.test.TestCase, parameterized.TestCase):
+    def setUp(self):
+        super().setUp()
+
+        self.tokenizer = BytePairTokenizer(
+            vocabulary=VOCAB_PATH, merges=MERGE_PATH
+        )
+
+    def test_tokenize_list_input(self):
+        input_data = ["brown.", "black."]
+        call_output = self.tokenizer(input_data)
+        tokenize_output = self.tokenizer.tokenize(input_data)
+        expected = tf.ragged.constant([[31876, 4], [14178, 4]])
+        self.assertAllEqual(call_output, expected)
+        self.assertAllEqual(tokenize_output, expected)
+
+        input_data = tf.convert_to_tensor(["brown.", "black."])
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, expected)
+
+    def test_tokenize_scalar_input(self):
+        input_data = "brown."
+        encoded = self.tokenizer.tokenize(input_data)
+        self.assertAllEqual(encoded, [31876, 4])
+
+    def test_detokenize(self):
+        input_data = ["brown."]
+        encoded = self.tokenizer.tokenize(input_data)
+        decoded = self.tokenizer.detokenize(encoded)
+        self.assertAllEqual(input_data, decoded)
+
+    def test_whitespace_split(self):
+        input_data = "\n\n\n  s"
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, [50140, 50118, 1437, 579])
+
+        input_data = "  \n\n\ns"
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29])
+
+    def test_special_whitespace(self):
+        input_data = "\xa0 \xa0 \x3000 s"
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, [50141, 50143, 12096, 579])
+
+    def test_cjk_input(self):
+        input_data = "素晴らしい!芭比Q啦~"
+        # Black formats long list by one element per line, which is bad to read.
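+        # (Byte-level BPE encodes each multi-byte CJK character as one or
+        # more byte-piece tokens, so the expected id list below is longer
+        # than the character count of the input string.)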
+        expected = [36714, 20024, 21402, 37127, 27, 20024, 48945, 47918]
+        expected += [47780, 43251, 4394, 10172, 36484, 27969, 12410, 37127]
+        expected += [10965, 10674, 1864, 42393, 15722, 18164, 43251, 10809]
+        expected += [17772]
+        encoded = self.tokenizer(input_data)
+        self.assertAllEqual(encoded, expected)
+
+    def test_tokenize_with_tf_data(self):
+        data = [
+            "I am just a test string",
+            "I am also a test string",
+            "I am still a test string",
+            "me too",
+            "I am not a test string (joking)",
+            "You guys should add punctuation!",
+            "Period matters!",
+        ]
+        ds = tf.data.Dataset.from_tensor_slices(data)
+        ds = ds.batch(2).map(self.tokenizer)
+        encoded = next(iter(ds))
+        expected = tf.ragged.constant(
+            [[100, 524, 95, 10, 1296, 6755], [100, 524, 67, 10, 1296, 6755]]
+        )
+        self.assertAllEqual(encoded, expected)
+
+    def test_config(self):
+        input_data = ["the quick brown whale."]
+        cloned_tokenizer = BytePairTokenizer.from_config(
+            self.tokenizer.get_config()
+        )
+        self.assertAllEqual(
+            self.tokenizer(input_data),
+            cloned_tokenizer(input_data),
+        )
+
+    @parameterized.named_parameters(("tf_format", "tf"), ("h5_format", "h5"))
+    def test_saving(self, format):
+        input_data = tf.constant(["the quick brown whale."])
+        tokenizer = self.tokenizer
+        inputs = keras.Input(dtype="string", shape=())
+        outputs = tokenizer(inputs)
+        model = keras.Model(inputs, outputs)
+        path = os.path.join(self.get_temp_dir(), "model")
+        model.save(path, save_format=format)
+        restored_model = keras.models.load_model(path)
+        self.assertAllEqual(
+            model(input_data),
+            restored_model(input_data),
+        )

From 2ebf24f83868eea529856ef180b28e601bc2419a Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Wed, 26 Oct 2022 15:08:25 -0700
Subject: [PATCH 2/7] Fix byte pair detokenization of 2d arrays (#423)

Before this fix, the detokenize function would squish everything down
into a single string. So it would not preserve the structure of what
was passed in.
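
A minimal sketch of the behavior this patch targets, reusing the toy
vocabulary from the class docstring (so the ids below are only meaningful
for that toy vocabulary):

```python
vocab = {"butter": 1, "fly": 2}
merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocab, merge)

# A 2d input now detokenizes row by row, preserving batch structure
# (id 0 is unmapped and decodes to the empty string):
tokenizer.detokenize([[1, 2], [1, 0]])  # -> ["butterfly", "butter"]
# Previously the whole batch collapsed into one joined string.
```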
---
 keras_nlp/tokenizers/byte_pair_tokenizer.py      |  7 ++-----
 keras_nlp/tokenizers/byte_pair_tokenizer_test.py | 10 ++++++++--
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index def254e8a5..0666cdb60e 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -516,16 +516,13 @@ def detokenize(self, inputs):
             inputs = tf.expand_dims(inputs, 0)
 
         unicode_text = tf.strings.reduce_join(
-            self.id_to_token_map.lookup(inputs), axis=1
+            self.id_to_token_map.lookup(inputs), axis=-1
         )
         split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8")
         byte_text = tf.strings.reduce_join(
-            self.unicode2byte.lookup(split_unicode_text)
+            self.unicode2byte.lookup(split_unicode_text), axis=-1
         )
 
-        if not scalar_input:
-            byte_text = tf.expand_dims(byte_text, 0)
-
         return byte_text
 
     def _transform_bytes(self, tokens):
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
index 243891e536..6c54dfcae4 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -57,8 +57,14 @@ def test_tokenize_scalar_input(self):
         encoded = self.tokenizer.tokenize(input_data)
         self.assertAllEqual(encoded, [31876, 4])
 
-    def test_detokenize(self):
-        input_data = ["brown."]
+    def test_detokenize_scalar_input(self):
+        input_data = ["quick brown fox."]
+        encoded = self.tokenizer.tokenize(input_data)
+        decoded = self.tokenizer.detokenize(encoded)
+        self.assertAllEqual(input_data, decoded)
+
+    def test_detokenize_list_input(self):
+        input_data = ["quick brown fox.", "slow black bear."]
         encoded = self.tokenizer.tokenize(input_data)
         decoded = self.tokenizer.detokenize(encoded)
         self.assertAllEqual(input_data, decoded)

From 8ecb0021f9a99108dd01f503e2edc8e5420cee0b Mon Sep 17 00:00:00 2001
From: Abheesht
Date: Tue, 1 Nov 2022 00:56:29 +0530
Subject: [PATCH 3/7] Support String Output for BytePairTokenizer (#438)

* Support string output for BytePairTokenizer

* Add unit test

* Minor edit
---
 keras_nlp/tokenizers/byte_pair_tokenizer.py      | 37 ++++++++++---------
 .../tokenizers/byte_pair_tokenizer_test.py       | 14 +++++++
 2 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 0666cdb60e..fcba55f1bc 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-""" Byte-pair encoder implementation.
+"""Byte-pair encoder implementation.
 
 This file implements the same logic as openai BPE:
 https://github.com/openai/gpt-2/blob/master/src/encoder.py,
@@ -159,12 +159,12 @@ def create_static_hashtable(keys, values, default):
 class BytePairTokenizer(tokenizer.Tokenizer):
     """Bype-pair encoding tokenizer layer.
 
-    This BPE tokenizer provides the same funtionality as official GPT2
+    This BPE tokenizer provides the same functionality as the official GPT-2
     tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
     which describes BPE merge rules, it should provide the same output
-    as openai implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
-    Different from openai, this implementation is graph-compatible, so you can
-    use it within a tf.data pipeline.
+    as OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
+    Different from OpenAI, this implementation is graph-compatible, so you can
+    use it within a `tf.data` pipeline.
 
     If input is a batch of strings (rank > 0):
         By default, the layer will output a `tf.RaggedTensor` where the last
@@ -187,7 +187,7 @@ class BytePairTokenizer(tokenizer.Tokenizer):
 
     Examples:
 
-    Use in-momery vocabulary and merge list.
+    Use in-memory vocabulary and merge list.
 
     >>> vocab = {"butter": 1, "fly": 2}
     >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
     >>> tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocab, merge)
@@ -244,7 +244,7 @@ def __init__(
             kwargs["dtype"] = tf.int32
         else:
             dtype = tf.dtypes.as_dtype(kwargs["dtype"])
-            if not dtype.is_integer:
+            if not dtype.is_integer and dtype != tf.string:
                 raise ValueError(
                     "Output dtype must be an integer type or a string. "
                     f"Received: `dtype={dtype}`"
@@ -484,28 +484,29 @@ def process_unseen_tokens():
             lambda: cache_lookup,
         )
 
-        # Encode merged tokens.
-        tokenized_words = tf.strings.split(tokenized_words, sep=" ")
-        encoding = self.token_to_id_map.lookup(tokenized_words)
+        tokens = tf.strings.split(tokenized_words, sep=" ")
+        if self.compute_dtype != tf.string:
+            # Encode merged tokens.
+            tokens = self.token_to_id_map.lookup(tokens)
 
         # Unflatten to match input.
-        encoding = tf.RaggedTensor.from_row_splits(
-            encoding.flat_values,
-            tf.gather(encoding.row_splits, token_row_splits),
+        tokens = tf.RaggedTensor.from_row_splits(
+            tokens.flat_values,
+            tf.gather(tokens.row_splits, token_row_splits),
         )
 
         # Convert to a dense output if `sequence_length` is set.
         if self.sequence_length:
-            output_shape = encoding.shape.as_list()
+            output_shape = tokens.shape.as_list()
             output_shape[-1] = self.sequence_length
-            encoding = encoding.to_tensor(shape=output_shape)
+            tokens = tokens.to_tensor(shape=output_shape)
 
         # Convert to a dense output if input in scalar
         if scalar_input:
-            encoding = tf.squeeze(encoding, 0)
-            tf.ensure_shape(encoding, shape=[self.sequence_length])
+            tokens = tf.squeeze(tokens, 0)
+            tf.ensure_shape(tokens, shape=[self.sequence_length])
 
-        return encoding
+        return tokens
 
     def detokenize(self, inputs):
         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
index 6c54dfcae4..5d5c5f2622 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -52,6 +52,20 @@ def test_tokenize_list_input(self):
         encoded = self.tokenizer(input_data)
         self.assertAllEqual(encoded, expected)
 
+    def test_tokenize_string_output(self):
+        input_data = ["quick brown fox.", "slow black bear."]
+        tokenizer = BytePairTokenizer(
+            vocabulary=VOCAB_PATH, merges=MERGE_PATH, dtype=tf.string
+        )
+        call_output = tokenizer(input_data)
+        expected = tf.ragged.constant(
+            [
+                ["quick", "Ġbrown", "Ġfox", "."],
+                ["slow", "Ġblack", "Ġbear", "."],
+            ]
+        )
+        self.assertAllEqual(call_output, expected)
+
     def test_tokenize_scalar_input(self):
         input_data = "brown."
         encoded = self.tokenizer.tokenize(input_data)

From 25d0fe3ed51bfc268f88d70b0a02124a119bbe77 Mon Sep 17 00:00:00 2001
From: Jonathan Bischof
Date: Wed, 2 Nov 2022 17:31:38 -0700
Subject: [PATCH 4/7] initial commit (#440)

---
 keras_nlp/tokenizers/byte_pair_tokenizer.py | 21 +--------------------
 1 file changed, 1 insertion(+), 20 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index fcba55f1bc..3e094b5d27 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -187,8 +187,7 @@ class BytePairTokenizer(tokenizer.Tokenizer):
 
     Examples:
 
-    Use in-memory vocabulary and merge list.
-
+    Tokenize
     >>> vocab = {"butter": 1, "fly": 2}
     >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
     >>> tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocab, merge)
@@ -205,23 +204,6 @@ class BytePairTokenizer(tokenizer.Tokenizer):
     <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
     array([[1, 2], [1, 0]], dtype=int32)>
 
-    Use hosted vocabluary and merge list.
-
-    ```python
-    vocab_path = tf.keras.utils.get_file(
-        "vocab.json",
-        "https://storage.googleapis.com/keras-nlp/models/roberta_base/vocab.json",
-    )
-    merge_path = tf.keras.utils.get_file(
-        "merges.txt",
-        "https://storage.googleapis.com/keras-nlp/models/roberta_base/merges.txt",
-    )
-    tokenizer = BytePairTokenizer(
-        vocabulary=vocab_path, merges=merge_path
-    )
-    tokenizer("Butterfly is not flying butter!")
-    ```
-
     Detokenize
     >>> vocab = {"butter": 1, "fly": 2}
     >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
     >>> tokenizer = keras_nlp.tokenizers.BytePairTokenizer(vocab, merge)
     >>> tokenizer.detokenize([[1, 2]])
     <tf.Tensor: shape=(1,), dtype=string,
     numpy=array([b'butterfly'], dtype=object)>
-
     """
 
     def __init__(

From 45376d7de6d985ce348cd6097b628bca83fe27f5 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 11 Oct 2022 16:51:49 -0700
Subject: [PATCH 5/7] Stop running CI on Windows (#386)

Windows nightlies have been broken for a few days now. Probably
related, core Tensorflow is dropping native Windows support for
anything but tensorflow-cpu (which we don't install by default) as of
the upcoming tf release (2.11).

Both core tensorflow and our development guide recommend WSL2. I think
the simplest solution here is to stop running windows CI and continue
recommending WSL2 for running on windows.
---
 .github/workflows/actions.yml | 5 +----
 .github/workflows/nightly.yml | 5 +----
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index cd852986b2..13f2498d55 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -8,10 +8,7 @@ on:
 jobs:
   build:
     name: Run tests
-    runs-on: ${{ matrix.os }}
-    strategy:
-      matrix:
-        os: [windows-latest, ubuntu-latest]
+    runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python 3.7
diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index 7ff9404820..ee58104284 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -9,10 +9,7 @@ on:
 jobs:
   build:
     name: Run tests
-    runs-on: ${{ matrix.os }}
-    strategy:
-      matrix:
-        os: [windows-latest, ubuntu-latest]
+    runs-on: ubuntu-latest
    steps:
       - uses: actions/checkout@v2
       - name: Set up Python 3.7

From 6e4cbefad7709263e7574c08a058140cf96771b8 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Wed, 12 Oct 2022 13:14:17 -0700
Subject: [PATCH 6/7] Improve MacOS support and pin tensorflow version during
 testing (#383)

* Improve MacOS support

* Conditionally import tensorflow_text everywhere

* Use requirements files for continuous testing

* Fix logs

* Bug fixes and improvement for linux testing

* Typo fix

* Address review comments
---
 .github/workflows/actions.yml                 |   8 +-
 .github/workflows/nightly.yml                 |   8 +-
 CONTRIBUTING.md                               | 104 +++++++++----
 examples/bert/README.md                       |   5 -
 .../integration_tests/basic_usage_test.py     |  10 +-
 keras_nlp/layers/mlm_mask_generator.py        |  10 +-
 keras_nlp/layers/multi_segment_packer.py      |  10 +-
 keras_nlp/metrics/rouge_base.py               |   6 +-
 keras_nlp/tokenizers/byte_tokenizer.py        |   9 +-
 .../tokenizers/sentence_piece_tokenizer.py    |  11 +-
 .../tokenizers/unicode_character_tokenizer.py |   9 +-
 keras_nlp/tokenizers/word_piece_tokenizer.py  |   9 +-
 .../utils/{tensor_utils.py => tf_utils.py}    |  14 +++
 ...{tensor_utils_test.py => tf_utils_test.py} |   2 +-
 requirements-common.txt                       |  15 +++
 requirements-macos-m1.txt                     |  15 +++
 requirements-nightly.txt                      |   6 +
 requirements.txt                              |   6 +
 setup.py                                      |   6 +-
 19 files changed, 205 insertions(+), 58 deletions(-)
 rename keras_nlp/utils/{tensor_utils.py => tf_utils.py} (81%)
 rename keras_nlp/utils/{tensor_utils_test.py => tf_utils_test.py} (95%)
 create mode 100644 requirements-common.txt
 create mode 100644 requirements-macos-m1.txt
 create mode 100644 requirements-nightly.txt
 create mode 100644 requirements.txt

diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index 13f2498d55..13c295a658 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -29,8 +29,8 @@ jobs:
           ${{ runner.os }}-pip-
       - name: Install dependencies
         run: |
-          pip install tensorflow
-          pip install -e ".[tests]" --progress-bar off --upgrade
+          pip install -r requirements.txt --progress-bar off
+          pip install -e "." --progress-bar off
       - name: Test with pytest
         run: |
           pytest --cov=keras_nlp --cov-report xml:coverage.xml
@@ -57,7 +57,7 @@ jobs:
           ${{ runner.os }}-pip-
       - name: Install dependencies
         run: |
-          pip install tensorflow
-          pip install -e ".[tests]" --progress-bar off --upgrade
+          pip install -r requirements.txt --progress-bar off
+          pip install -e "." --progress-bar off
       - name: Lint
         run: bash shell/lint.sh
diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index ee58104284..e2b843e56c 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -30,12 +30,8 @@ jobs:
           ${{ runner.os }}-pip-
       - name: Install dependencies
         run: |
-          pip install -e ".[tests]" --progress-bar off --upgrade
-          pip uninstall keras -y
-          pip uninstall tensorflow -y
-          pip uninstall tensorflow_text -y
-          pip install tf-nightly --progress-bar off --upgrade
-          pip install tensorflow-text-nightly --progress-bar off --upgrade
+          pip install -r requirements-nightly.txt --progress-bar off
+          pip install -e "." --progress-bar off
       - name: Test with pytest
         run: |
           pytest --cov=keras_nlp --cov-report xml:coverage.xml
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 5e1e4b919d..1d75864cb8 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -84,25 +84,90 @@ Once the pull request is approved, a team member will take care of merging.
 Python 3.7 or later is required.
 
 Setting up your KerasNLP development environment requires you to fork the
-KerasNLP repository, clone the repository, create a virtual environment, and
-install dependencies.
-
-You can achieve this by running the following commands:
+KerasNLP repository and clone it locally. With the
+[GitHub CLI](https://github.com/cli/cli) installed, you can do this as follows:
 
 ```shell
 gh repo fork keras-team/keras-nlp --clone --remote
 cd keras-nlp
-python -m venv ~/keras-nlp-venv
-source ~/keras-nlp-venv/bin/activate
-pip install -e ".[tests]"
 ```
 
-The first line relies on having an installation of
-[the GitHub CLI](https://github.com/cli/cli).
+Next we must setup a python environment with the correct dependencies. We
+recommend using `conda` to install tensorflow dependencies (such as CUDA), and
+`pip` to install python packages from PyPI. The exact method will depend on your
+OS.
+
+### Linux (recommended)
+
+To setup a complete environment with TensorFlow, a local install of keras-nlp,
+and all development tools, run the following or adapt it to suit your needs.
+
+```shell
+# Create and activate conda environment.
+conda create -n keras-nlp python=3.9
+conda activate keras-nlp
+
+# The following can be omitted if GPU support is not required.
+conda install -c conda-forge cudatoolkit-dev=11.2 cudnn=8.1.0
+mkdir -p $CONDA_PREFIX/etc/conda/activate.d/
+echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+echo 'export XLA_FLAGS=--xla_gpu_cuda_data_dir=$CONDA_PREFIX/' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+
+# Install dependencies.
+python -m pip install --upgrade pip
+python -m pip install -r requirements.txt
+python -m pip install -e "."
+```
+
+### MacOS
+
+⚠️⚠️⚠️ MacOS binaries are for the M1 architecture are not currently available from
+official sources. You can try experimental development workflow leveraging the
+[tensorflow metal plugin](https://developer.apple.com/metal/tensorflow-plugin/)
+and a [community maintained build](https://github.com/sun1638650145/Libraries-and-Extensions-for-TensorFlow-for-Apple-Silicon)
+of `tensorflow-text`. These binaries are not provided by Google, so proceed at
+your own risk.
+
+#### Experimental instructions for Arm (M1)
+
+```shell
+# Create and activate conda environment.
+conda create -n keras-nlp python=3.9
+conda activate keras-nlp
+
+# Install dependencies.
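+# (tensorflow-deps comes from Apple's conda channel; the matching
+# tensorflow-text wheel is pulled from the community build referenced in
+# requirements-macos-m1.txt.)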
+conda install -c apple tensorflow-deps=2.9
+python -m pip install --upgrade pip
+python -m pip install -r requirements-macos-m1.txt
+python -m pip install -e "."
+```
+
+#### Instructions for x86 (Intel)
+
+```shell
+# Create and activate conda environment.
+conda create -n keras-nlp python=3.9
+conda activate keras-nlp
+
+# Install dependencies.
+python -m pip install --upgrade pip
+python -m pip install -r requirements.txt
+python -m pip install -e "."
+```
+
+### Windows
+
+For the best experience developing on windows, please install
+[WSL](https://learn.microsoft.com/en-us/windows/wsl/install), and proceed with
+the linux installation instruction above.
+
+To run the format and lint scripts, make sure you clone the repo with Linux
+style line endings and change any line separator settings in your editor.
+This is automatically done if you clone using git inside WSL.
+
+Note that will not support Windows Shell/PowerShell for any scripts in this
+repository.
 
 ## Testing changes
 
@@ -143,18 +208,3 @@ the following commands manually every time you want to format your code:
 If after running these the CI flow is still failing, try updating `flake8`,
 `isort` and `black`. This can be done by running `pip install --upgrade black`,
 `pip install --upgrade flake8`, and `pip install --upgrade isort`.
-
-## Developing on Windows
-
-For Windows development, we recommend using WSL (Windows Subsystem for Linux),
-so you can run the shell scripts in this repository. We will not support
-Windows Shell/PowerShell. You can refer
-[to these instructions](https://docs.microsoft.com/en-us/windows/wsl/install)
-for WSL installation.
-
-Note that if you are using Windows Subsystem for Linux (WSL), make sure you
-clone the repo with Linux style LF line endings and change the default setting
-for line separator in your Text Editor before running the format
-or lint scripts. This is automatically done if you clone using git inside WSL.
-If there is conflict due to the line endings you might see an error
-like - `: invalid option`.
diff --git a/examples/bert/README.md b/examples/bert/README.md
index 3663763096..23098408d3 100644
--- a/examples/bert/README.md
+++ b/examples/bert/README.md
@@ -16,11 +16,6 @@ need to be trained for much longer on a much larger dataset.
 OUTPUT_DIR=~/bert_test_output
 DATA_URL=https://storage.googleapis.com/tensorflow/keras-nlp/examples/bert
 
-# Create a virtual env and install dependencies.
-mkdir $OUTPUT_DIR
-python3 -m venv $OUTPUT_DIR/env && source $OUTPUT_DIR/env/bin/activate
-pip install -e ".[tests,examples]"
-
 # Download example data.
 wget ${DATA_URL}/bert_vocab_uncased.txt -O $OUTPUT_DIR/bert_vocab_uncased.txt
 wget ${DATA_URL}/wiki_example_data.txt -O $OUTPUT_DIR/wiki_example_data.txt
diff --git a/keras_nlp/integration_tests/basic_usage_test.py b/keras_nlp/integration_tests/basic_usage_test.py
index e9a2f85717..c6667b271a 100644
--- a/keras_nlp/integration_tests/basic_usage_test.py
+++ b/keras_nlp/integration_tests/basic_usage_test.py
@@ -13,13 +13,17 @@
 # limitations under the License.
 
 import tensorflow as tf
+from absl.testing import parameterized
 from tensorflow import keras
 
 import keras_nlp
 
 
-class BasicUsageTest(tf.test.TestCase):
-    def test_quick_start(self):
+class BasicUsageTest(tf.test.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(
+        ("jit_compile_false", False), ("jit_compile_true", True)
+    )
+    def test_quick_start(self, jit_compile):
         """This matches the quick start example in our base README."""
 
         # Tokenize some inputs with a binary label.
@@ -47,7 +51,7 @@ def test_quick_start(self):
         model = keras.Model(inputs, outputs)
 
         # Run a single batch of gradient descent.
-        model.compile(loss="binary_crossentropy", jit_compile=True)
+        model.compile(loss="binary_crossentropy", jit_compile=jit_compile)
         loss = model.train_on_batch(x, y)
 
         # Make sure we have a valid loss.
diff --git a/keras_nlp/layers/mlm_mask_generator.py b/keras_nlp/layers/mlm_mask_generator.py
index 6576dc50bf..fb668adc51 100644
--- a/keras_nlp/layers/mlm_mask_generator.py
+++ b/keras_nlp/layers/mlm_mask_generator.py
@@ -13,9 +13,15 @@
 # limitations under the License.
 
 import tensorflow as tf
-import tensorflow_text as tf_text
 from tensorflow import keras
 
+from keras_nlp.utils.tf_utils import assert_tf_text_installed
+
+try:
+    import tensorflow_text as tf_text
+except ImportError:
+    tf_text = None
+
 
 class MLMMaskGenerator(keras.layers.Layer):
     """Layer that applies language model masking.
@@ -96,6 +102,8 @@ def __init__(
         random_token_rate=0.1,
         **kwargs,
     ):
+        assert_tf_text_installed(self.__class__.__name__)
+
         super().__init__(**kwargs)
         self.vocabulary_size = vocabulary_size
         self.unselectable_token_ids = unselectable_token_ids
diff --git a/keras_nlp/layers/multi_segment_packer.py b/keras_nlp/layers/multi_segment_packer.py
index 2f3c2eaf56..ff03a9f114 100644
--- a/keras_nlp/layers/multi_segment_packer.py
+++ b/keras_nlp/layers/multi_segment_packer.py
@@ -15,9 +15,15 @@
 """BERT token packing layer."""
 
 import tensorflow as tf
-import tensorflow_text as tf_text
 from tensorflow import keras
 
+from keras_nlp.utils.tf_utils import assert_tf_text_installed
+
+try:
+    import tensorflow_text as tf_text
+except ImportError:
+    tf_text = None
+
 
 class MultiSegmentPacker(keras.layers.Layer):
     """Packs multiple sequences into a single fixed width model input.
@@ -106,6 +112,8 @@ def __init__(
         truncator="round_robin",
         **kwargs,
    ):
+        assert_tf_text_installed(self.__class__.__name__)
+
         super().__init__(**kwargs)
         self.sequence_length = sequence_length
         if truncator not in ("round_robin", "waterfall"):
diff --git a/keras_nlp/metrics/rouge_base.py b/keras_nlp/metrics/rouge_base.py
index 22d4adf3b8..ffa5a242fc 100644
--- a/keras_nlp/metrics/rouge_base.py
+++ b/keras_nlp/metrics/rouge_base.py
@@ -20,7 +20,7 @@
 import tensorflow as tf
 from tensorflow import keras
 
-from keras_nlp.utils.tensor_utils import tensor_to_string_list
+from keras_nlp.utils.tf_utils import tensor_to_string_list
 
 try:
     import rouge_score
@@ -62,8 +62,8 @@ def __init__(
 
         if rouge_score is None:
             raise ImportError(
-                "ROUGE metric requires the `rouge_score` package. "
-                "Please install it with `pip install rouge-score`."
+                f"{self.__class__.__name__} requires the `rouge_score` "
+                "package. Please install it with `pip install rouge-score`."
             )
 
         if not tf.as_dtype(self.dtype).is_floating:
diff --git a/keras_nlp/tokenizers/byte_tokenizer.py b/keras_nlp/tokenizers/byte_tokenizer.py
index eb0057a8dc..1b411ad6a6 100644
--- a/keras_nlp/tokenizers/byte_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_tokenizer.py
@@ -16,9 +16,14 @@
 
 import numpy as np
 import tensorflow as tf
-import tensorflow_text as tf_text
 
 from keras_nlp.tokenizers import tokenizer
+from keras_nlp.utils.tf_utils import assert_tf_text_installed
+
+try:
+    import tensorflow_text as tf_text
+except ImportError:
+    tf_text = None
 
 
 class ByteTokenizer(tokenizer.Tokenizer):
@@ -150,6 +155,8 @@ def __init__(
         replacement_char: int = 65533,
         **kwargs,
     ):
+        assert_tf_text_installed(self.__class__.__name__)
+
         # Check dtype and provide a default.
         if "dtype" not in kwargs or kwargs["dtype"] is None:
             kwargs["dtype"] = tf.int32
diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py
index d5828892b1..b33be0b8fe 100644
--- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py
@@ -17,10 +17,15 @@
 from typing import List
 
 import tensorflow as tf
-import tensorflow_text as tf_text
 
 from keras_nlp.tokenizers import tokenizer
-from keras_nlp.utils.tensor_utils import tensor_to_string_list
+from keras_nlp.utils.tf_utils import assert_tf_text_installed
+from keras_nlp.utils.tf_utils import tensor_to_string_list
+
+try:
+    import tensorflow_text as tf_text
+except ImportError:
+    tf_text = None
 
 
 class SentencePieceTokenizer(tokenizer.Tokenizer):
@@ -96,6 +101,8 @@ def __init__(
         sequence_length: int = None,
         **kwargs,
     ) -> None:
+        assert_tf_text_installed(self.__class__.__name__)
+
         # Check dtype and provide a default.
         if "dtype" not in kwargs or kwargs["dtype"] is None:
             kwargs["dtype"] = tf.int32
diff --git a/keras_nlp/tokenizers/unicode_character_tokenizer.py b/keras_nlp/tokenizers/unicode_character_tokenizer.py
index d23763dc89..0e25e843a2 100644
--- a/keras_nlp/tokenizers/unicode_character_tokenizer.py
+++ b/keras_nlp/tokenizers/unicode_character_tokenizer.py
@@ -13,9 +13,14 @@
 # limitations under the License.
 
 import tensorflow as tf
-import tensorflow_text as tf_text
 
 from keras_nlp.tokenizers import tokenizer
+from keras_nlp.utils.tf_utils import assert_tf_text_installed
+
+try:
+    import tensorflow_text as tf_text
+except ImportError:
+    tf_text = None
 
 
 class UnicodeCharacterTokenizer(tokenizer.Tokenizer):
@@ -199,6 +204,8 @@ def __init__(
         vocabulary_size: int = None,
         **kwargs,
     ) -> None:
+        assert_tf_text_installed(self.__class__.__name__)
+
         # Check dtype and provide a default.
         if "dtype" not in kwargs or kwargs["dtype"] is None:
             kwargs["dtype"] = tf.int32
diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py
index a789946301..cebcc0cbdc 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -16,9 +16,14 @@
 from typing import List
 
 import tensorflow as tf
-import tensorflow_text as tf_text
 
 from keras_nlp.tokenizers import tokenizer
+from keras_nlp.utils.tf_utils import assert_tf_text_installed
+
+try:
+    import tensorflow_text as tf_text
+except ImportError:
+    tf_text = None
 
 # Matches whitespace and control characters.
 WHITESPACE_REGEX = r"|".join(
@@ -183,6 +188,8 @@ def __init__(
         oov_token: str = "[UNK]",
         **kwargs,
     ) -> None:
+        assert_tf_text_installed(self.__class__.__name__)
+
         # Check dtype and provide a default.
if "dtype" not in kwargs or kwargs["dtype"] is None: kwargs["dtype"] = tf.int32 diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tf_utils.py similarity index 81% rename from keras_nlp/utils/tensor_utils.py rename to keras_nlp/utils/tf_utils.py index 26fc815f11..7c958b6070 100644 --- a/keras_nlp/utils/tensor_utils.py +++ b/keras_nlp/utils/tf_utils.py @@ -14,6 +14,11 @@ import tensorflow as tf +try: + import tensorflow_text +except ImportError: + tensorflow_text = None + def _decode_strings_to_utf8(inputs): """Recursively decodes to list of strings with 'utf-8' encoding.""" @@ -45,3 +50,12 @@ def tensor_to_string_list(inputs): if inputs.shape.rank != 0: list_outputs = list_outputs.tolist() return _decode_strings_to_utf8(list_outputs) + + +def assert_tf_text_installed(symbol_name): + """Detokenize and convert tensor to nested lists of python strings.""" + if tensorflow_text is None: + raise ImportError( + f"{symbol_name} requires the `tensorflow-text` package. " + "Please install with `pip install tensorflow-text`." + ) diff --git a/keras_nlp/utils/tensor_utils_test.py b/keras_nlp/utils/tf_utils_test.py similarity index 95% rename from keras_nlp/utils/tensor_utils_test.py rename to keras_nlp/utils/tf_utils_test.py index d9941f750f..3f639143e6 100644 --- a/keras_nlp/utils/tensor_utils_test.py +++ b/keras_nlp/utils/tf_utils_test.py @@ -14,7 +14,7 @@ import tensorflow as tf -from keras_nlp.utils.tensor_utils import tensor_to_string_list +from keras_nlp.utils.tf_utils import tensor_to_string_list class TensorToStringListTest(tf.test.TestCase): diff --git a/requirements-common.txt b/requirements-common.txt new file mode 100644 index 0000000000..b877a72263 --- /dev/null +++ b/requirements-common.txt @@ -0,0 +1,15 @@ +# Tooling. +packaging +black>=22 +flake8 +isort +pytest +pytest-cov +# Optional deps. +rouge-score +sentencepiece +# Examples deps. +nltk +tensorflow_datasets +wikiextractor +keras-tuner diff --git a/requirements-macos-m1.txt b/requirements-macos-m1.txt new file mode 100644 index 0000000000..c19f6aef46 --- /dev/null +++ b/requirements-macos-m1.txt @@ -0,0 +1,15 @@ +# WARNING: KerasNLP has no official support for MacOS M1 at this time. The +# following will pull required depenencies from the following external sources. +# - https://developer.apple.com/metal/tensorflow-plugin/ +# - https://github.com/sun1638650145/Libraries-and-Extensions-for-TensorFlow-for-Apple-Silicon/ +# These are not provided by Google, please review both of these dependencies +# before proceeding. + +# Core deps. +tensorflow-macos==2.9 +https://github.com/sun1638650145/Libraries-and-Extensions-for-TensorFlow-for-Apple-Silicon/releases/download/v2.9/tensorflow_text-2.9.0-cp39-cp39-macosx_11_0_arm64.whl +# The metal plugin breaks many tests, so is not enabled by default. +# tensorflow-metal==0.5.1 + +# Common deps. +-r requirements-common.txt diff --git a/requirements-nightly.txt b/requirements-nightly.txt new file mode 100644 index 0000000000..2d5daf9259 --- /dev/null +++ b/requirements-nightly.txt @@ -0,0 +1,6 @@ +# Core deps. +tf-nightly +tensorflow-text-nightly + +# Common deps. +-r requirements-common.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..5db761ea16 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +# Core deps. +tensorflow==2.10 +tensorflow-text==2.10 + +# Common deps. 
+-r requirements-common.txt
diff --git a/setup.py b/setup.py
index 9a2da4b43d..00f8942cb1 100644
--- a/setup.py
+++ b/setup.py
@@ -38,8 +38,10 @@
         "absl-py",
         "numpy",
         "packaging",
-        "tensorflow",
-        "tensorflow-text",
+        # Don't require tensorflow on MacOS; tensorflow-macos will not
+        # satisfy the requirement.
+        "tensorflow; platform_system != 'Darwin'",
+        "tensorflow-text; platform_system != 'Darwin'",
     ],
     extras_require={
         "tests": [

From 39d8ae126438789ad20bd3b63ae717f510adbfb8 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Thu, 10 Nov 2022 13:29:51 -0800
Subject: [PATCH 7/7] Conditionally import tf text (#452)

In keeping with other layers, we should not rely on tf text being
installed to import the library (this is useful for building keras.io
for example).
---
 keras_nlp/tokenizers/byte_pair_tokenizer.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 3e094b5d27..3788747a05 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -24,10 +24,15 @@
 from typing import List
 
 import tensorflow as tf
-import tensorflow_text as tf_text
 from tensorflow import keras
 
 from keras_nlp.tokenizers import tokenizer
+from keras_nlp.utils.tf_utils import assert_tf_text_installed
+
+try:
+    import tensorflow_text as tf_text
+except ImportError:
+    tf_text = None
 
 # As python and TF handles special spaces differently, we need to
 # manually handle special spaces during string split.
@@ -220,6 +225,8 @@ def __init__(
         sequence_length=None,
         **kwargs,
     ) -> None:
+        assert_tf_text_installed(self.__class__.__name__)
+
         # Check dtype and provide a default.
         if "dtype" not in kwargs or kwargs["dtype"] is None:
             kwargs["dtype"] = tf.int32
@@ -241,7 +248,8 @@ def __init__(
         else:
             raise ValueError(
                 "Vocabulary must be an file path or dictionary mapping string "
-                f"token to int ids. Received: `type(vocabulary)={type(vocabulary)}`."
+                "token to int ids. Received: "
+                f"`type(vocabulary)={type(vocabulary)}`."
             )
         if isinstance(merges, str):
             self.merges = [bp.rstrip() for bp in tf.io.gfile.GFile(merges)]
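
After this series, `keras_nlp` imports cleanly even when `tensorflow-text`
is absent, and `BytePairTokenizer` only raises when it is actually
constructed. A minimal end-to-end sketch of the resulting API, reusing the
toy vocabulary and merge list from the class docstring (the ids are only
meaningful for that toy vocabulary):

```python
import tensorflow as tf

import keras_nlp

vocab = {"butter": 1, "fly": 2}
merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]

# Integer ids by default; dense and padded once `sequence_length` is set.
tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocab, merge, sequence_length=2
)

# Graph-compatible, so it can run inside a tf.data pipeline.
ds = tf.data.Dataset.from_tensor_slices(["butterfly", "butter"])
print(next(iter(ds.batch(2).map(tokenizer))))  # [[1, 2], [1, 0]]

# With dtype=tf.string (patch 3), the merged subword tokens come back
# instead of ids.
string_tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocab, merge, dtype=tf.string
)
print(string_tokenizer("butterfly"))  # [b'butter' b'fly']
```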