From 494c56aba397a0547ac19ff9a1200f47fbe34a02 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Mon, 22 Aug 2022 01:45:05 -0700 Subject: [PATCH 1/6] byte pair encoder implementation --- keras_nlp/tokenizers/byte_pair_tokenizer.py | 385 ++++++++++++++++++ .../tokenizers/byte_pair_tokenizer_test.py | 44 ++ 2 files changed, 429 insertions(+) create mode 100644 keras_nlp/tokenizers/byte_pair_tokenizer.py create mode 100644 keras_nlp/tokenizers/byte_pair_tokenizer_test.py diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py new file mode 100644 index 0000000000..303ce09acf --- /dev/null +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -0,0 +1,385 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lib2to3.pgen2 import token +from typing import Dict +from typing import List +from typing import Iterable +from venv import create + +import tensorflow as tf +import tensorflow_text as tf_text +import json + +from keras_nlp.tokenizers import tokenizer + +def bytes_to_unicode(): + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + #removes mapping an int to a whitespace character + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return bs, cs #int to string mapping + +class BytePairTokenizerCache(): + def __init__(self): + self.key2id = tf.lookup.experimental.DenseHashTable( + tf.string, tf.int64, -1, "a ", "b " + ) + self.id2value = tf.lookup.experimental.MutableHashTable( + tf.int64, tf.string, "" + ) + self.id = tf.Variable(0, dtype=tf.int64) + def lookup(self, keys): + """Look up a tensor of tokens.""" + ids = self.key2id.lookup(keys) + result = self.id2value.lookup(ids) + # Ensure output shape for graph mode. + result.set_shape([None]) + return result + + def insert(self, keys, values): + """Insert a tensor of tokens to bp words mapping""" + size = tf.cast(tf.shape(keys)[0], tf.int64) + ids = tf.range(self.id, self.id+size) + self.id.assign(self.id+size) + + self.key2id.insert(keys, ids) + self.id2value.insert(ids, values) + return ids + +def create_static_hashtable(keys, values, default): + hashtable = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer( + tf.convert_to_tensor(keys), + tf.convert_to_tensor(values), + ), + default_value=default + ) + return hashtable + + +class BytePairTokenizer(tokenizer.Tokenizer): + def __init__( + self, + vocabulary, + merges, + sequence_length: int = None, + **kwargs, + ) -> None: + # Check dtype and provide a default. + if "dtype" not in kwargs or kwargs["dtype"] is None: + kwargs["dtype"] = tf.int32 + else: + dtype = tf.dtypes.as_dtype(kwargs["dtype"]) + if not dtype.is_integer and dtype != tf.string: + raise ValueError( + "Output dtype must be an integer type or a string. 
" + f"Received: dtype={dtype}" + ) + + super().__init__(**kwargs) + + if isinstance(vocabulary, str): + with open(vocabulary, "r") as f: + self.vocabulary = json.load(f) + elif isinstance(vocabulary, dict): + # Make a copy. + self.vocabulary = vocabulary.copy() + else: + raise ValueError( + "Vocabulary must be an file path or dictionary mapping byte " + f"pairs to token ids. Received: vocabulary={vocabulary}." + ) + if isinstance(merges, str): + self.merges = [ + bp.rstrip() for bp in tf.io.gfile.GFile(merges) + ] + elif isinstance(merges, Iterable): + self.merges = list(merges) + else: + raise ValueError( + "Merges must be a file path or a list of merges. Recieved: " + f"merges={merges}." + ) + self.sequence_length = sequence_length + + # TODO: use dtype to cast output + self.pat = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""" + + # Map byte to unicode. + bs, cs = bytes_to_unicode() + self.byte2unicode = create_static_hashtable(bs, cs, default='') + + # Caching. + self.cache = BytePairTokenizerCache() + + # BytePair encodings. + self.byte_pair_encoder = create_static_hashtable( + [x[0] for x in self.vocabulary.items()], + [x[1] for x in self.vocabulary.items()], + default=-1 + ) + + # Merging rankings. + self.max_bpe_rank = len(self.merges)+1 + self.bpe_ranks = create_static_hashtable( + self.merges, + list(range(len(self.merges))), + default=self.max_bpe_rank + ) + + def get_vocabulary(self) -> List[str]: + """Get the tokenizer vocabulary as a list of strings tokens.""" + return self.vocabulary.keys() + + def vocabulary_size(self) -> int: + """Get the size of the tokenizer vocabulary.""" + return len(self.vocabulary) + + def id_to_token(self, id: int) -> str: + """Convert an integer id to a string token.""" + # This will be slow, but keep memory usage down compared to building a + # . Assuming the main use case is looking up a few special tokens + # early in the vocab, this should be fine. + for token in range(self.vocabulary): + if self.vocabulary[token] == id: + return token + return None + + def token_to_id(self, token: str) -> int: + """Convert a string token to an integer id.""" + return self.vocabulary[token] + + def get_config(self): + config = super().get_config() + config.update( + { + # Ideally a vocabulary would be saved as a plain text asset in + # the saved model. We have no good way to support this + # currently, so we save the vocabulary in the config. + "vocabulary": self.vocabulary, + "merges": self.merges, + "sequence_length": self.sequence_length, + } + ) + return config + + def tokenize(self, inputs): + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + scalar_input = inputs.shape.rank == 0 + + # Regex match tokens. + raw_tokens = tf_text.regex_split(inputs, self.pat, self.pat) + token_row_splits = raw_tokens.row_splits + flatten_tokens = raw_tokens.flat_values + + # Check cache. + cache_lookup = self.cache.lookup(flatten_tokens) + cache_mask = cache_lookup == "" + + if tf.math.count_nonzero(tf.boolean_mask(cache_mask, flatten_tokens != "")) == 0: + # All elements are in cache. + result = cache_lookup + else: + # Create byte pair merges and add to cache. + unseen_tokens = tf.boolean_mask(flatten_tokens, cache_mask) + self._byte_pair_encoding(unseen_tokens) + result = self.cache.lookup(flatten_tokens) + + # Encode merged tokens. + result = tf.strings.split(result, sep=" ") + encoding = self.byte_pair_encoder.lookup(result) + + # Unflatten to match input. 
+ encoding = tf.RaggedTensor.from_row_splits( + encoding.flat_values, + tf.gather( + encoding.row_splits, + token_row_splits + ) + ) + + # Convert to a dense output if `sequence_length` is set. + if self.sequence_length: + output_shape = encoding.shape.as_list() + output_shape[-1] = self.sequence_length + encoding = encoding.to_tensor(shape=output_shape) + # Convert to a dense output if input in scalar + if scalar_input: + encoding = tf.squeeze(encoding, 0) + tf.ensure_shape(encoding, shape=[self.sequence_length]) + + return encoding + + # Helper functions go here. + + def _encode_tokens(self, tokens): + """Map token bytes to unicode using `byte2unicode`.""" + #TODO: This could be optimized. + # Encode token bytes. + token_bytes = tf.strings.bytes_split(tokens) + flatten_bytes = token_bytes.flat_values + flatten_bytes = tf.squeeze( + tf.cast( + tf.io.decode_raw(flatten_bytes, tf.uint8), tf.int32 + ) + ) + flatten_unicode = self.byte2unicode.lookup(flatten_bytes) + token_unicode = tf.RaggedTensor.from_row_lengths( + values=flatten_unicode, + row_lengths=token_bytes.row_lengths() + ) + return token_unicode + + def _remove_empty_strings(self, tensor): + """Remove empty strings in a tensor""" + non_empty_mask = tensor != "" + flatten_indexes = tf.where(non_empty_mask) + flatten_result = tf.gather_nd(tensor, flatten_indexes) + row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, tf.int64), axis=1) + result = tf.RaggedTensor.from_row_lengths( + values=flatten_result, + row_lengths=row_lengths, + ) + return result + + def _find_top_pair_and_merge(self, words, top_pair_first, top_pair_second): + """Merges the top pair in word.""" + # Get shifted word tokens. + word_pair_first = words[:, :-1] + word_pair_second = words[:, 1:] + + # Get top pair occurances. + top_pair_first = tf.expand_dims(top_pair_first, axis=1) + top_pair_second = tf.expand_dims(top_pair_second, axis=1) + top_pair_starts = tf.math.logical_and( + word_pair_first==top_pair_first, + word_pair_second==top_pair_second + ) + + # Fixing off by one indexing. + num_words = tf.shape(top_pair_starts)[0] + front_mask = tf.logical_not( + tf.concat( + [tf.fill([num_words, 1], False), top_pair_starts], 1 + ) + ) + back_mask = tf.concat( + [tf.fill([num_words, 1], False), top_pair_starts], 1 + ) + + # Filter word tokens to keep. + front = tf.where(front_mask, words, "") + # Filter `top_pair_second` tokens to merge. + back = tf.concat( + [tf.where(back_mask[:, 1:], word_pair_second, ""), tf.fill([num_words, 1], "")], 1 + ) + # Merge and clean up empty strings. + joined = tf.strings.join([front, back]) + return self._remove_empty_strings(joined) + + def _get_pairs(self, words): + return words[:, :-1], words[:, 1:] + + @tf.function + def _byte_pair_merge_loop_body(self, words, mask): + """Iterative merging process for byte pair encoding algorithm.""" + # Get all word pairs. + first, second = self._get_pairs(words) + + # Mask empty. + non_empty_mask = second.nested_row_lengths()[0] != 0 + mask = tf.logical_and(mask, non_empty_mask) + if tf.math.count_nonzero(mask) == 0: + return [words, mask] + non_empty_idxs = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask) + tmp_first = tf.ragged.boolean_mask(first, mask) + tmp_second = tf.ragged.boolean_mask(second, mask) + + # Get top word pair. + pair_hash = tf.strings.join([tmp_first, tmp_second], separator=" ") + pair_rank = self.bpe_ranks.lookup(pair_hash) + + # Get BPE pair ranks. 
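+        # A word whose best pair only reaches `max_bpe_rank` has no matching
+        # entry left in the merge table, so it is finished and is masked out
+        # below.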
+ min_pair_rank = tf.reduce_min(pair_rank, axis=1) + not_found_mask = min_pair_rank != self.max_bpe_rank + mask = tf.tensor_scatter_nd_update( + mask, tf.expand_dims(non_empty_idxs, axis=1), not_found_mask + ) + if tf.math.count_nonzero(mask) == 0: + return [words, mask] + + masked_pair_rank = tf.ragged.boolean_mask(pair_rank, not_found_mask) + min_pair_rank_idx = tf.math.argmin( + masked_pair_rank.to_tensor(self.max_bpe_rank), axis=1 + ) + + # Get words and pairs to process. + p_words = tf.ragged.boolean_mask(words, mask) + p_first = tf.ragged.boolean_mask(first, mask) + p_second = tf.ragged.boolean_mask(second, mask) + p_min_rank_first = tf.gather(p_first, min_pair_rank_idx, batch_dims=1) + p_min_rank_second = tf.gather(p_second, min_pair_rank_idx, batch_dims=1) + + # Process merges of top pairs. + p_words = self._find_top_pair_and_merge(p_words, p_min_rank_first, p_min_rank_second) + + # Update words. + p_idxs = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask) + tensor_words = words.to_tensor(default_value="") + tensor_p_words = p_words.to_tensor( + default_value="", + shape=[tf.shape(p_idxs)[0], tf.shape(tensor_words)[1]] + ) + words = tf.tensor_scatter_nd_update( + tensor_words, + tf.expand_dims(p_idxs, axis=1), + tensor_p_words, + ) + words = self._remove_empty_strings(words) + return [words, mask] + + def _byte_pair_encoding(self, tokens): + """Process unseen tokens and add to cache.""" + words = self._encode_tokens(tokens) + num_words = tf.shape(words)[0] + + # Merge bytes. + loop_condition = lambda _, mask : tf.math.count_nonzero(mask) > 0 + initial_mask = tf.fill((num_words,), True) + merged_words, _ = tf.while_loop( + loop_condition, + self._byte_pair_merge_loop_body, + loop_vars=[words, initial_mask], + shape_invariants=[ + tf.TensorShape([None, None]), + tf.TensorShape([None]), + ] + ) + + merged_words_hash = tf.strings.reduce_join(merged_words, axis=1, separator=" ") + self.cache.insert(tokens, merged_words_hash) + diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py new file mode 100644 index 0000000000..94e4787279 --- /dev/null +++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py @@ -0,0 +1,44 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +vocab = {"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "q": 80, "r": 81, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "\u00a1": 94, "\u00a2": 95, "\u00a3": 96, "\u00a4": 97, "\u00a5": 98, "\u00a6": 99, "\u00a7": 100, "\u00a8": 101, "\u00a9": 102, "\u00aa": 103, "\u00ab": 104, "\u00ac": 105, "\u00ae": 106, "\u00af": 107, "\u00b0": 108, "\u00b1": 109, "\u00b2": 110, "\u00b3": 111, "\u00b4": 112, "\u00b5": 113, "\u00b6": 114, "\u00b7": 115, "\u00b8": 116, "\u00b9": 117, "\u00ba": 118, "\u00bb": 119, "\u00bc": 120, "\u00bd": 121, "\u00be": 122, "\u00bf": 123, "\u00c0": 124, "\u00c1": 125, "\u00c2": 126, "\u00c3": 127, "\u00c4": 128, "\u00c5": 129, "\u00c6": 130, "\u00c7": 131, "\u00c8": 132, "\u00c9": 133, "\u00ca": 134, "\u00cb": 135, "\u00cc": 136, "\u00cd": 137, "\u00ce": 138, "\u00cf": 139, "\u00d0": 140, "\u00d1": 141, "\u00d2": 142, "\u00d3": 143, "\u00d4": 144, "\u00d5": 145, "\u00d6": 146, "\u00d7": 147, "\u00d8": 148, "\u00d9": 149, "\u00da": 150, "\u00db": 151, "\u00dc": 152, "\u00dd": 153, "\u00de": 154, "\u00df": 155, "\u00e0": 156, "\u00e1": 157, "\u00e2": 158, "\u00e3": 159, "\u00e4": 160, "\u00e5": 161, "\u00e6": 162, "\u00e7": 163, "\u00e8": 164, "\u00e9": 165, "\u00ea": 166, "\u00eb": 167, "\u00ec": 168, "\u00ed": 169, "\u00ee": 170, "\u00ef": 171, "\u00f0": 172, "\u00f1": 173, "\u00f2": 174, "\u00f3": 175, "\u00f4": 176, "\u00f5": 177, "\u00f6": 178, "\u00f7": 179, "\u00f8": 180, "\u00f9": 181, "\u00fa": 182, "\u00fb": 183, "\u00fc": 184, "\u00fd": 185, "\u00fe": 186, "\u00ff": 187, "\u0100": 188, "\u0101": 189, "\u0102": 190, "\u0103": 191, "\u0104": 192, "\u0105": 193, "\u0106": 194, "\u0107": 195, "\u0108": 196, "\u0109": 197, "\u010a": 198, "\u010b": 199, "\u010c": 200, "\u010d": 201, "\u010e": 202, "\u010f": 203, "\u0110": 204, "\u0111": 205, "\u0112": 206, "\u0113": 207, "\u0114": 208, "\u0115": 209, "\u0116": 210, "\u0117": 211, "\u0118": 212, "\u0119": 213, "\u011a": 214, "\u011b": 215, "\u011c": 216, "\u011d": 217, "\u011e": 218, "\u011f": 219, "\u0120": 220, "\u0121": 221, "\u0122": 222, "\u0123": 223, "\u0124": 224, "\u0125": 225, "\u0126": 226, "\u0127": 227, "\u0128": 228, "\u0129": 229, "\u012a": 230, "\u012b": 231, "\u012c": 232, "\u012d": 233, "\u012e": 234, "\u012f": 235, "\u0130": 236, "\u0131": 237, "\u0132": 238, "\u0133": 239, "\u0134": 240, "\u0135": 241, "\u0136": 242, "\u0137": 243, "\u0138": 244, "\u0139": 245, "\u013a": 246, "\u013b": 247, "\u013c": 248, "\u013d": 249, "\u013e": 250, "\u013f": 251, "\u0140": 252, "\u0141": 253, "\u0142": 254, "\u0143": 255, "\u0120t": 256, "\u0120a": 257, "he": 258, "in": 259, "re": 260, "on": 261, "\u0120the": 262, "er": 263, "\u0120s": 264, "at": 265, 
"\u0120w": 266, "\u0120o": 267, "en": 268, "\u0120c": 269, "it": 270, "is": 271, "an": 272, "or": 273, "es": 274, "\u0120b": 275, "ed": 276, "\u0120f": 277, "ing": 278, "\u0120p": 279, "ou": 280, "\u0120an": 281, "al": 282, "ar": 283, "\u0120to": 284, "\u0120m": 285, "\u0120of": 286, "\u0120in": 287, "\u0120d": 288, "\u0120h": 289, "\u0120and": 290, "ic": 291, "as": 292, "le": 293, "\u0120th": 294, "ion": 295, "om": 296, "ll": 297, "ent": 298, "\u0120n": 299, "\u0120l": 300, "st": 301, "\u0120re": 302, "ve": 303, "\u0120e": 304, "ro": 305, "ly": 306, "\u0120be": 307, "\u0120g": 308, "\u0120T": 309, "ct": 310, "\u0120S": 311, "id": 312, "ot": 313, "\u0120I": 314, "ut": 315, "et": 316, "\u0120A": 317, "\u0120is": 318, "\u0120on": 319, "im": 320, "am": 321, "ow": 322, "ay": 323, "ad": 324, "se": 325, "\u0120that": 326, "\u0120C": 327, "ig": 328, "\u0120for": 329, "ac": 330, "\u0120y": 331, "ver": 332, "ur": 333, "\u0120u": 334, "ld": 335, "\u0120st": 336, "\u0120M": 337, "'s": 338, "\u0120he": 339, "\u0120it": 340, "ation": 341, "ith": 342, "ir": 343, "ce": 344, "\u0120you": 345, "il": 346, "\u0120B": 347, "\u0120wh": 348, "ol": 349, "\u0120P": 350, "\u0120with": 351, "\u01201": 352, "ter": 353, "ch": 354, "\u0120as": 355, "\u0120we": 356, "\u0120(": 357, "nd": 358, "ill": 359, "\u0120D": 360, "if": 361, "\u01202": 362, "ag": 363, "ers": 364, "ke": 365, "\u0120\"": 366, "\u0120H": 367, "em": 368, "\u0120con": 369, "\u0120W": 370, "\u0120R": 371, "her": 372, "\u0120was": 373, "\u0120r": 374, "od": 375, "\u0120F": 376, "ul": 377, "ate": 378, "\u0120at": 379, "ri": 380, "pp": 381, "ore": 382, "\u0120The": 383, "\u0120se": 384, "us": 385, "\u0120pro": 386, "\u0120ha": 387, "um": 388, "\u0120are": 389, "\u0120de": 390, "ain": 391, "and": 392, "\u0120or": 393, "igh": 394} + +class BytePairTokenizerTest(tf.test.TestCase): + + def setUp(self): + super().setUp() + self.vocabulary = { + "t":1, "h":2, "e":3, " ":4, "the":5, + "b":6, "r":7, "o":8, "w":9, "n":10, "brown":11, + ".":12 + } + + def test_tokenize(self): + input_data = ["brown."] + tokenizer = BytePairTokenizer( + vocabulary = self.vocabulary, + merges = ["b r", "br o", "bro w", "brow n"] + ) + call_output = tokenizer(input_data) + tokenize_output = tokenizer.tokenize(input_data) + self.assertIsInstance(call_output, tf.RaggedTensor) + self.assertAllEqual(call_output, [[11,12]]) + self.assertAllEqual(tokenize_output, [[11,12]]) \ No newline at end of file From c780a389fde680410789551601b5f2ac8a1111fd Mon Sep 17 00:00:00 2001 From: jessechancy Date: Mon, 22 Aug 2022 01:50:07 -0700 Subject: [PATCH 2/6] style fixes --- keras_nlp/tokenizers/byte_pair_tokenizer.py | 121 +++++++++--------- .../tokenizers/byte_pair_tokenizer_test.py | 30 +++-- 2 files changed, 81 insertions(+), 70 deletions(-) diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 303ce09acf..b1171548ff 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from lib2to3.pgen2 import token -from typing import Dict -from typing import List +import json from typing import Iterable -from venv import create +from typing import List import tensorflow as tf import tensorflow_text as tf_text -import json from keras_nlp.tokenizers import tokenizer + def bytes_to_unicode(): bs = ( list(range(ord("!"), ord("~") + 1)) @@ -32,16 +30,17 @@ def bytes_to_unicode(): ) cs = bs[:] n = 0 - #removes mapping an int to a whitespace character + # removes mapping an int to a whitespace character for b in range(2**8): if b not in bs: bs.append(b) cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] - return bs, cs #int to string mapping + return bs, cs # int to string mapping + -class BytePairTokenizerCache(): +class BytePairTokenizerCache: def __init__(self): self.key2id = tf.lookup.experimental.DenseHashTable( tf.string, tf.int64, -1, "a ", "b " @@ -50,6 +49,7 @@ def __init__(self): tf.int64, tf.string, "" ) self.id = tf.Variable(0, dtype=tf.int64) + def lookup(self, keys): """Look up a tensor of tokens.""" ids = self.key2id.lookup(keys) @@ -61,20 +61,21 @@ def lookup(self, keys): def insert(self, keys, values): """Insert a tensor of tokens to bp words mapping""" size = tf.cast(tf.shape(keys)[0], tf.int64) - ids = tf.range(self.id, self.id+size) - self.id.assign(self.id+size) + ids = tf.range(self.id, self.id + size) + self.id.assign(self.id + size) self.key2id.insert(keys, ids) self.id2value.insert(ids, values) return ids + def create_static_hashtable(keys, values, default): hashtable = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer( - tf.convert_to_tensor(keys), - tf.convert_to_tensor(values), - ), - default_value=default + tf.lookup.KeyValueTensorInitializer( + tf.convert_to_tensor(keys), + tf.convert_to_tensor(values), + ), + default_value=default, ) return hashtable @@ -112,9 +113,7 @@ def __init__( f"pairs to token ids. Received: vocabulary={vocabulary}." ) if isinstance(merges, str): - self.merges = [ - bp.rstrip() for bp in tf.io.gfile.GFile(merges) - ] + self.merges = [bp.rstrip() for bp in tf.io.gfile.GFile(merges)] elif isinstance(merges, Iterable): self.merges = list(merges) else: @@ -125,11 +124,13 @@ def __init__( self.sequence_length = sequence_length # TODO: use dtype to cast output - self.pat = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""" + self.pat = ( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""" + ) # Map byte to unicode. bs, cs = bytes_to_unicode() - self.byte2unicode = create_static_hashtable(bs, cs, default='') + self.byte2unicode = create_static_hashtable(bs, cs, default="") # Caching. self.cache = BytePairTokenizerCache() @@ -138,15 +139,15 @@ def __init__( self.byte_pair_encoder = create_static_hashtable( [x[0] for x in self.vocabulary.items()], [x[1] for x in self.vocabulary.items()], - default=-1 + default=-1, ) # Merging rankings. - self.max_bpe_rank = len(self.merges)+1 + self.max_bpe_rank = len(self.merges) + 1 self.bpe_ranks = create_static_hashtable( self.merges, list(range(len(self.merges))), - default=self.max_bpe_rank + default=self.max_bpe_rank, ) def get_vocabulary(self) -> List[str]: @@ -199,8 +200,13 @@ def tokenize(self, inputs): # Check cache. cache_lookup = self.cache.lookup(flatten_tokens) cache_mask = cache_lookup == "" - - if tf.math.count_nonzero(tf.boolean_mask(cache_mask, flatten_tokens != "")) == 0: + + if ( + tf.math.count_nonzero( + tf.boolean_mask(cache_mask, flatten_tokens != "") + ) + == 0 + ): # All elements are in cache. 
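+            # Fast path: every incoming token already has a cached merge
+            # string, so the byte-pair merge loop can be skipped entirely.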
result = cache_lookup else: @@ -208,18 +214,15 @@ def tokenize(self, inputs): unseen_tokens = tf.boolean_mask(flatten_tokens, cache_mask) self._byte_pair_encoding(unseen_tokens) result = self.cache.lookup(flatten_tokens) - + # Encode merged tokens. result = tf.strings.split(result, sep=" ") encoding = self.byte_pair_encoder.lookup(result) # Unflatten to match input. encoding = tf.RaggedTensor.from_row_splits( - encoding.flat_values, - tf.gather( - encoding.row_splits, - token_row_splits - ) + encoding.flat_values, + tf.gather(encoding.row_splits, token_row_splits), ) # Convert to a dense output if `sequence_length` is set. @@ -238,19 +241,16 @@ def tokenize(self, inputs): def _encode_tokens(self, tokens): """Map token bytes to unicode using `byte2unicode`.""" - #TODO: This could be optimized. + # TODO: This could be optimized. # Encode token bytes. token_bytes = tf.strings.bytes_split(tokens) flatten_bytes = token_bytes.flat_values flatten_bytes = tf.squeeze( - tf.cast( - tf.io.decode_raw(flatten_bytes, tf.uint8), tf.int32 - ) + tf.cast(tf.io.decode_raw(flatten_bytes, tf.uint8), tf.int32) ) flatten_unicode = self.byte2unicode.lookup(flatten_bytes) token_unicode = tf.RaggedTensor.from_row_lengths( - values=flatten_unicode, - row_lengths=token_bytes.row_lengths() + values=flatten_unicode, row_lengths=token_bytes.row_lengths() ) return token_unicode @@ -276,16 +276,14 @@ def _find_top_pair_and_merge(self, words, top_pair_first, top_pair_second): top_pair_first = tf.expand_dims(top_pair_first, axis=1) top_pair_second = tf.expand_dims(top_pair_second, axis=1) top_pair_starts = tf.math.logical_and( - word_pair_first==top_pair_first, - word_pair_second==top_pair_second + word_pair_first == top_pair_first, + word_pair_second == top_pair_second, ) - + # Fixing off by one indexing. num_words = tf.shape(top_pair_starts)[0] front_mask = tf.logical_not( - tf.concat( - [tf.fill([num_words, 1], False), top_pair_starts], 1 - ) + tf.concat([tf.fill([num_words, 1], False), top_pair_starts], 1) ) back_mask = tf.concat( [tf.fill([num_words, 1], False), top_pair_starts], 1 @@ -295,7 +293,11 @@ def _find_top_pair_and_merge(self, words, top_pair_first, top_pair_second): front = tf.where(front_mask, words, "") # Filter `top_pair_second` tokens to merge. back = tf.concat( - [tf.where(back_mask[:, 1:], word_pair_second, ""), tf.fill([num_words, 1], "")], 1 + [ + tf.where(back_mask[:, 1:], word_pair_second, ""), + tf.fill([num_words, 1], ""), + ], + 1, ) # Merge and clean up empty strings. joined = tf.strings.join([front, back]) @@ -309,7 +311,7 @@ def _byte_pair_merge_loop_body(self, words, mask): """Iterative merging process for byte pair encoding algorithm.""" # Get all word pairs. first, second = self._get_pairs(words) - + # Mask empty. non_empty_mask = second.nested_row_lengths()[0] != 0 mask = tf.logical_and(mask, non_empty_mask) @@ -331,12 +333,12 @@ def _byte_pair_merge_loop_body(self, words, mask): ) if tf.math.count_nonzero(mask) == 0: return [words, mask] - + masked_pair_rank = tf.ragged.boolean_mask(pair_rank, not_found_mask) min_pair_rank_idx = tf.math.argmin( masked_pair_rank.to_tensor(self.max_bpe_rank), axis=1 ) - + # Get words and pairs to process. p_words = tf.ragged.boolean_mask(words, mask) p_first = tf.ragged.boolean_mask(first, mask) @@ -345,30 +347,34 @@ def _byte_pair_merge_loop_body(self, words, mask): p_min_rank_second = tf.gather(p_second, min_pair_rank_idx, batch_dims=1) # Process merges of top pairs. 
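+        # Replace the highest-priority pair in each unfinished word with its
+        # merged token.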
- p_words = self._find_top_pair_and_merge(p_words, p_min_rank_first, p_min_rank_second) + p_words = self._find_top_pair_and_merge( + p_words, p_min_rank_first, p_min_rank_second + ) # Update words. p_idxs = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask) tensor_words = words.to_tensor(default_value="") tensor_p_words = p_words.to_tensor( - default_value="", - shape=[tf.shape(p_idxs)[0], tf.shape(tensor_words)[1]] + default_value="", + shape=[tf.shape(p_idxs)[0], tf.shape(tensor_words)[1]], ) words = tf.tensor_scatter_nd_update( - tensor_words, - tf.expand_dims(p_idxs, axis=1), + tensor_words, + tf.expand_dims(p_idxs, axis=1), tensor_p_words, ) words = self._remove_empty_strings(words) return [words, mask] - + def _byte_pair_encoding(self, tokens): """Process unseen tokens and add to cache.""" words = self._encode_tokens(tokens) num_words = tf.shape(words)[0] # Merge bytes. - loop_condition = lambda _, mask : tf.math.count_nonzero(mask) > 0 + def loop_condition(words, mask): + return tf.math.count_nonzero(mask) > 0 + initial_mask = tf.fill((num_words,), True) merged_words, _ = tf.while_loop( loop_condition, @@ -377,9 +383,10 @@ def _byte_pair_encoding(self, tokens): shape_invariants=[ tf.TensorShape([None, None]), tf.TensorShape([None]), - ] + ], ) - merged_words_hash = tf.strings.reduce_join(merged_words, axis=1, separator=" ") + merged_words_hash = tf.strings.reduce_join( + merged_words, axis=1, separator=" " + ) self.cache.insert(tokens, merged_words_hash) - diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py index 94e4787279..8d03c62b80 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py @@ -12,33 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json import tensorflow as tf -from tensorflow import keras from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -vocab = {"!": 0, "\"": 1, "#": 2, "$": 3, "%": 4, "&": 5, "'": 6, "(": 7, ")": 8, "*": 9, "+": 10, ",": 11, "-": 12, ".": 13, "/": 14, "0": 15, "1": 16, "2": 17, "3": 18, "4": 19, "5": 20, "6": 21, "7": 22, "8": 23, "9": 24, ":": 25, ";": 26, "<": 27, "=": 28, ">": 29, "?": 30, "@": 31, "A": 32, "B": 33, "C": 34, "D": 35, "E": 36, "F": 37, "G": 38, "H": 39, "I": 40, "J": 41, "K": 42, "L": 43, "M": 44, "N": 45, "O": 46, "P": 47, "Q": 48, "R": 49, "S": 50, "T": 51, "U": 52, "V": 53, "W": 54, "X": 55, "Y": 56, "Z": 57, "[": 58, "\\": 59, "]": 60, "^": 61, "_": 62, "`": 63, "a": 64, "b": 65, "c": 66, "d": 67, "e": 68, "f": 69, "g": 70, "h": 71, "i": 72, "j": 73, "k": 74, "l": 75, "m": 76, "n": 77, "o": 78, "p": 79, "q": 80, "r": 81, "s": 82, "t": 83, "u": 84, "v": 85, "w": 86, "x": 87, "y": 88, "z": 89, "{": 90, "|": 91, "}": 92, "~": 93, "\u00a1": 94, "\u00a2": 95, "\u00a3": 96, "\u00a4": 97, "\u00a5": 98, "\u00a6": 99, "\u00a7": 100, "\u00a8": 101, "\u00a9": 102, "\u00aa": 103, "\u00ab": 104, "\u00ac": 105, "\u00ae": 106, "\u00af": 107, "\u00b0": 108, "\u00b1": 109, "\u00b2": 110, "\u00b3": 111, "\u00b4": 112, "\u00b5": 113, "\u00b6": 114, "\u00b7": 115, "\u00b8": 116, "\u00b9": 117, "\u00ba": 118, "\u00bb": 119, "\u00bc": 120, "\u00bd": 121, "\u00be": 122, "\u00bf": 123, "\u00c0": 124, "\u00c1": 125, "\u00c2": 126, "\u00c3": 127, "\u00c4": 128, "\u00c5": 129, "\u00c6": 130, "\u00c7": 131, "\u00c8": 132, "\u00c9": 133, "\u00ca": 134, "\u00cb": 135, "\u00cc": 136, "\u00cd": 137, "\u00ce": 138, "\u00cf": 139, "\u00d0": 140, "\u00d1": 141, "\u00d2": 142, "\u00d3": 143, "\u00d4": 144, "\u00d5": 145, "\u00d6": 146, "\u00d7": 147, "\u00d8": 148, "\u00d9": 149, "\u00da": 150, "\u00db": 151, "\u00dc": 152, "\u00dd": 153, "\u00de": 154, "\u00df": 155, "\u00e0": 156, "\u00e1": 157, "\u00e2": 158, "\u00e3": 159, "\u00e4": 160, "\u00e5": 161, "\u00e6": 162, "\u00e7": 163, "\u00e8": 164, "\u00e9": 165, "\u00ea": 166, "\u00eb": 167, "\u00ec": 168, "\u00ed": 169, "\u00ee": 170, "\u00ef": 171, "\u00f0": 172, "\u00f1": 173, "\u00f2": 174, "\u00f3": 175, "\u00f4": 176, "\u00f5": 177, "\u00f6": 178, "\u00f7": 179, "\u00f8": 180, "\u00f9": 181, "\u00fa": 182, "\u00fb": 183, "\u00fc": 184, "\u00fd": 185, "\u00fe": 186, "\u00ff": 187, "\u0100": 188, "\u0101": 189, "\u0102": 190, "\u0103": 191, "\u0104": 192, "\u0105": 193, "\u0106": 194, "\u0107": 195, "\u0108": 196, "\u0109": 197, "\u010a": 198, "\u010b": 199, "\u010c": 200, "\u010d": 201, "\u010e": 202, "\u010f": 203, "\u0110": 204, "\u0111": 205, "\u0112": 206, "\u0113": 207, "\u0114": 208, "\u0115": 209, "\u0116": 210, "\u0117": 211, "\u0118": 212, "\u0119": 213, "\u011a": 214, "\u011b": 215, "\u011c": 216, "\u011d": 217, "\u011e": 218, "\u011f": 219, "\u0120": 220, "\u0121": 221, "\u0122": 222, "\u0123": 223, "\u0124": 224, "\u0125": 225, "\u0126": 226, "\u0127": 227, "\u0128": 228, "\u0129": 229, "\u012a": 230, "\u012b": 231, "\u012c": 232, "\u012d": 233, "\u012e": 234, "\u012f": 235, "\u0130": 236, "\u0131": 237, "\u0132": 238, "\u0133": 239, "\u0134": 240, "\u0135": 241, "\u0136": 242, "\u0137": 243, "\u0138": 244, "\u0139": 245, "\u013a": 246, "\u013b": 247, "\u013c": 248, "\u013d": 249, "\u013e": 250, "\u013f": 251, "\u0140": 252, "\u0141": 253, "\u0142": 254, "\u0143": 255, "\u0120t": 256, "\u0120a": 257, "he": 258, "in": 259, "re": 260, "on": 261, "\u0120the": 262, "er": 263, "\u0120s": 264, "at": 265, "\u0120w": 
266, "\u0120o": 267, "en": 268, "\u0120c": 269, "it": 270, "is": 271, "an": 272, "or": 273, "es": 274, "\u0120b": 275, "ed": 276, "\u0120f": 277, "ing": 278, "\u0120p": 279, "ou": 280, "\u0120an": 281, "al": 282, "ar": 283, "\u0120to": 284, "\u0120m": 285, "\u0120of": 286, "\u0120in": 287, "\u0120d": 288, "\u0120h": 289, "\u0120and": 290, "ic": 291, "as": 292, "le": 293, "\u0120th": 294, "ion": 295, "om": 296, "ll": 297, "ent": 298, "\u0120n": 299, "\u0120l": 300, "st": 301, "\u0120re": 302, "ve": 303, "\u0120e": 304, "ro": 305, "ly": 306, "\u0120be": 307, "\u0120g": 308, "\u0120T": 309, "ct": 310, "\u0120S": 311, "id": 312, "ot": 313, "\u0120I": 314, "ut": 315, "et": 316, "\u0120A": 317, "\u0120is": 318, "\u0120on": 319, "im": 320, "am": 321, "ow": 322, "ay": 323, "ad": 324, "se": 325, "\u0120that": 326, "\u0120C": 327, "ig": 328, "\u0120for": 329, "ac": 330, "\u0120y": 331, "ver": 332, "ur": 333, "\u0120u": 334, "ld": 335, "\u0120st": 336, "\u0120M": 337, "'s": 338, "\u0120he": 339, "\u0120it": 340, "ation": 341, "ith": 342, "ir": 343, "ce": 344, "\u0120you": 345, "il": 346, "\u0120B": 347, "\u0120wh": 348, "ol": 349, "\u0120P": 350, "\u0120with": 351, "\u01201": 352, "ter": 353, "ch": 354, "\u0120as": 355, "\u0120we": 356, "\u0120(": 357, "nd": 358, "ill": 359, "\u0120D": 360, "if": 361, "\u01202": 362, "ag": 363, "ers": 364, "ke": 365, "\u0120\"": 366, "\u0120H": 367, "em": 368, "\u0120con": 369, "\u0120W": 370, "\u0120R": 371, "her": 372, "\u0120was": 373, "\u0120r": 374, "od": 375, "\u0120F": 376, "ul": 377, "ate": 378, "\u0120at": 379, "ri": 380, "pp": 381, "ore": 382, "\u0120The": 383, "\u0120se": 384, "us": 385, "\u0120pro": 386, "\u0120ha": 387, "um": 388, "\u0120are": 389, "\u0120de": 390, "ain": 391, "and": 392, "\u0120or": 393, "igh": 394} - class BytePairTokenizerTest(tf.test.TestCase): - def setUp(self): super().setUp() self.vocabulary = { - "t":1, "h":2, "e":3, " ":4, "the":5, - "b":6, "r":7, "o":8, "w":9, "n":10, "brown":11, - ".":12 + "t": 1, + "h": 2, + "e": 3, + " ": 4, + "the": 5, + "b": 6, + "r": 7, + "o": 8, + "w": 9, + "n": 10, + "brown": 11, + ".": 12, } - + def test_tokenize(self): input_data = ["brown."] tokenizer = BytePairTokenizer( - vocabulary = self.vocabulary, - merges = ["b r", "br o", "bro w", "brow n"] + vocabulary=self.vocabulary, + merges=["b r", "br o", "bro w", "brow n"], ) call_output = tokenizer(input_data) tokenize_output = tokenizer.tokenize(input_data) self.assertIsInstance(call_output, tf.RaggedTensor) - self.assertAllEqual(call_output, [[11,12]]) - self.assertAllEqual(tokenize_output, [[11,12]]) \ No newline at end of file + self.assertAllEqual(call_output, [[11, 12]]) + self.assertAllEqual(tokenize_output, [[11, 12]]) From 3ca9da864b5869fdfce45b3ce8e99f49bb39a4ae Mon Sep 17 00:00:00 2001 From: jessechancy Date: Tue, 23 Aug 2022 16:16:50 -0700 Subject: [PATCH 3/6] byte pair detokenize method --- keras_nlp/tokenizers/byte_pair_tokenizer.py | 57 ++++++++++++++----- .../tokenizers/byte_pair_tokenizer_test.py | 29 ++++++++++ 2 files changed, 71 insertions(+), 15 deletions(-) diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index b1171548ff..cac52fd6ee 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -37,13 +37,18 @@ def bytes_to_unicode(): cs.append(2**8 + n) n += 1 cs = [chr(n) for n in cs] + bs = [n.to_bytes(1, "little") for n in bs] return bs, cs # int to string mapping class BytePairTokenizerCache: def __init__(self): self.key2id 
= tf.lookup.experimental.DenseHashTable( - tf.string, tf.int64, -1, "a ", "b " + tf.string, + tf.int64, + -1, + "a ", + "b ", # These tokens will never appear as keys. ) self.id2value = tf.lookup.experimental.MutableHashTable( tf.int64, tf.string, "" @@ -131,16 +136,24 @@ def __init__( # Map byte to unicode. bs, cs = bytes_to_unicode() self.byte2unicode = create_static_hashtable(bs, cs, default="") + self.unicode2byte = create_static_hashtable(cs, bs, default="") # Caching. self.cache = BytePairTokenizerCache() # BytePair encodings. + byte_pairs = [x[0] for x in self.vocabulary.items()] + byte_pair_encoding_idxs = [x[1] for x in self.vocabulary.items()] self.byte_pair_encoder = create_static_hashtable( - [x[0] for x in self.vocabulary.items()], - [x[1] for x in self.vocabulary.items()], + byte_pairs, + byte_pair_encoding_idxs, default=-1, ) + self.byte_pair_decoder = create_static_hashtable( + byte_pair_encoding_idxs, + byte_pairs, + default="", + ) # Merging rankings. self.max_bpe_rank = len(self.merges) + 1 @@ -191,6 +204,8 @@ def tokenize(self, inputs): inputs = tf.convert_to_tensor(inputs) scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) # Regex match tokens. raw_tokens = tf_text.regex_split(inputs, self.pat, self.pat) @@ -237,22 +252,34 @@ def tokenize(self, inputs): return encoding + def detokenize(self, inputs): + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + unicode_text = tf.strings.reduce_join( + self.byte_pair_decoder.lookup(inputs), axis=1 + ) + split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8") + byte_text = tf.strings.reduce_join( + self.unicode2byte.lookup(split_unicode_text) + ) + + if not scalar_input: + byte_text = tf.expand_dims(byte_text, 0) + + return byte_text + # Helper functions go here. def _encode_tokens(self, tokens): """Map token bytes to unicode using `byte2unicode`.""" - # TODO: This could be optimized. - # Encode token bytes. - token_bytes = tf.strings.bytes_split(tokens) - flatten_bytes = token_bytes.flat_values - flatten_bytes = tf.squeeze( - tf.cast(tf.io.decode_raw(flatten_bytes, tf.uint8), tf.int32) - ) - flatten_unicode = self.byte2unicode.lookup(flatten_bytes) - token_unicode = tf.RaggedTensor.from_row_lengths( - values=flatten_unicode, row_lengths=token_bytes.row_lengths() - ) - return token_unicode + split_bytes = tf.strings.bytes_split(tokens) + split_unicode = self.byte2unicode.lookup(split_bytes) + return split_unicode def _remove_empty_strings(self, tensor): """Remove empty strings in a tensor""" diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py index 8d03c62b80..84f742b854 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py @@ -46,3 +46,32 @@ def test_tokenize(self): self.assertIsInstance(call_output, tf.RaggedTensor) self.assertAllEqual(call_output, [[11, 12]]) self.assertAllEqual(tokenize_output, [[11, 12]]) + + def test_tokenize_scalar(self): + input_data = "brown." 
+ tokenizer = BytePairTokenizer( + vocabulary=self.vocabulary, + merges=["b r", "br o", "bro w", "brow n"], + ) + tokenize_output = tokenizer.tokenize(input_data) + self.assertAllEqual(tokenize_output, [11, 12]) + + def test_tokenize_single_output(self): + # Test that output doesn't collapse to zero dimensions with one output + input_data = "brown" + tokenizer = BytePairTokenizer( + vocabulary=self.vocabulary, + merges=["b r", "br o", "bro w", "brow n"], + ) + tokenize_output = tokenizer.tokenize(input_data) + self.assertAllEqual(tokenize_output, [11]) + + def test_detokenize(self): + input_data = ["brown."] + tokenizer = BytePairTokenizer( + vocabulary=self.vocabulary, + merges=["b r", "br o", "bro w", "brow n"], + ) + tokenized_data = tokenizer.tokenize(input_data) + output_data = tokenizer.detokenize(tokenized_data) + self.assertAllEqual(input_data, output_data) From 66c013cee666874e06f05da37191d70d08840a90 Mon Sep 17 00:00:00 2001 From: jessechancy Date: Wed, 24 Aug 2022 10:43:02 -0700 Subject: [PATCH 4/6] add to init --- keras_nlp/tokenizers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_nlp/tokenizers/__init__.py b/keras_nlp/tokenizers/__init__.py index 048ea15d0a..aef896a4e8 100644 --- a/keras_nlp/tokenizers/__init__.py +++ b/keras_nlp/tokenizers/__init__.py @@ -25,3 +25,4 @@ from keras_nlp.tokenizers.word_piece_tokenizer_trainer import ( compute_word_piece_vocabulary, ) +from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer \ No newline at end of file From 22bbd0c370a4d1ebc64ea999d2bcc06048b9ee71 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 10 Oct 2022 17:47:19 -0700 Subject: [PATCH 5/6] Update BPE with simplified implementation --- keras_nlp/tokenizers/__init__.py | 2 +- keras_nlp/tokenizers/byte_pair_tokenizer.py | 208 ++++++++++---------- 2 files changed, 110 insertions(+), 100 deletions(-) diff --git a/keras_nlp/tokenizers/__init__.py b/keras_nlp/tokenizers/__init__.py index aef896a4e8..6c0d197236 100644 --- a/keras_nlp/tokenizers/__init__.py +++ b/keras_nlp/tokenizers/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer from keras_nlp.tokenizers.sentence_piece_tokenizer_trainer import ( @@ -25,4 +26,3 @@ from keras_nlp.tokenizers.word_piece_tokenizer_trainer import ( compute_word_piece_vocabulary, ) -from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer \ No newline at end of file diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index cac52fd6ee..28a7a822c1 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -42,36 +42,40 @@ def bytes_to_unicode(): class BytePairTokenizerCache: + """Cache that stores the encoded result of seen tokens.""" + def __init__(self): - self.key2id = tf.lookup.experimental.DenseHashTable( - tf.string, - tf.int64, - -1, - "a ", - "b ", # These tokens will never appear as keys. - ) + # `tf.lookup.experimental.MutableHashTable` does not support string to + # string mapping. So we first convert to string to an integer key, and + # use the integer key to find the value. 
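+        # `tf.fingerprint` yields 8 bytes per key; `factors` holds the powers
+        # of 256 that `get_key` uses to pack those bytes into a single int64.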
+ self.factors = tf.pow(256, tf.range(0, 8, dtype=tf.int64)) self.id2value = tf.lookup.experimental.MutableHashTable( tf.int64, tf.string, "" ) - self.id = tf.Variable(0, dtype=tf.int64) + + def get_key(self, keys): + """Get the hash key for given inputs.""" + # `tf.fingerprint` converts token to a array of uint8 of length 8, we + # need to convert it to a uint64. + return tf.squeeze( + tf.matmul( + tf.cast(tf.fingerprint(keys), dtype=tf.int64), + self.factors[:, tf.newaxis], + ), + -1, + ) def lookup(self, keys): - """Look up a tensor of tokens.""" - ids = self.key2id.lookup(keys) + """Look up the encoded outputs of given tokens.""" + ids = self.get_key(keys) result = self.id2value.lookup(ids) # Ensure output shape for graph mode. result.set_shape([None]) return result def insert(self, keys, values): - """Insert a tensor of tokens to bp words mapping""" - size = tf.cast(tf.shape(keys)[0], tf.int64) - ids = tf.range(self.id, self.id + size) - self.id.assign(self.id + size) - - self.key2id.insert(keys, ids) - self.id2value.insert(ids, values) - return ids + """Insert token <=> encoded outputs pairs.""" + self.id2value.insert(self.get_key(keys), values) def create_static_hashtable(keys, values, default): @@ -86,6 +90,24 @@ def create_static_hashtable(keys, values, default): class BytePairTokenizer(tokenizer.Tokenizer): + """Bype-pair encoder. + + This BPE encoder provides the same funtionality as official GPT2 tokenizer. + Given the same `vocabulary` and `merges`, it should provide the same output + as fairseq implementation (https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe.py). + Different from fairseq, this implementation is graph-compatible, so you can + use it within a tf.data pipeline. + + Args: + vocabulary: string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, + it should be the file path to merge rules. The merge rule file + should have one merge rule per line. + sequence_length: int, defaults to None. If set, the output will be + padded or truncated to the `sequence_length`. + """ + def __init__( self, vocabulary, @@ -110,12 +132,11 @@ def __init__( with open(vocabulary, "r") as f: self.vocabulary = json.load(f) elif isinstance(vocabulary, dict): - # Make a copy. self.vocabulary = vocabulary.copy() else: raise ValueError( - "Vocabulary must be an file path or dictionary mapping byte " - f"pairs to token ids. Received: vocabulary={vocabulary}." + "Vocabulary must be an file path or dictionary mapping string " + f"token to int ids. Received type: {type(vocabulary)}." ) if isinstance(merges, str): self.merges = [bp.rstrip() for bp in tf.io.gfile.GFile(merges)] @@ -123,39 +144,44 @@ def __init__( self.merges = list(merges) else: raise ValueError( - "Merges must be a file path or a list of merges. Recieved: " - f"merges={merges}." + "Merges must be a file path or a list of merge rules. " + f"Received type: {type(merges)}." ) self.sequence_length = sequence_length - # TODO: use dtype to cast output + # String splitting regex pattern. self.pat = ( r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""" ) - # Map byte to unicode. - bs, cs = bytes_to_unicode() - self.byte2unicode = create_static_hashtable(bs, cs, default="") - self.unicode2byte = create_static_hashtable(cs, bs, default="") + # Create byte <=> unicode mapping. This is useful for handling + # whitespace tokens. 
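+        # Every byte, including whitespace, maps to a printable unicode
+        # character, so the merge loop only ever operates on visible symbols.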
+ byte_list, unicode_list = bytes_to_unicode() + self.byte2unicode = create_static_hashtable( + byte_list, unicode_list, default="" + ) + self.unicode2byte = create_static_hashtable( + unicode_list, byte_list, default="" + ) - # Caching. self.cache = BytePairTokenizerCache() - # BytePair encodings. + # Create mapping between string tokens to int ids, and vice versa. byte_pairs = [x[0] for x in self.vocabulary.items()] byte_pair_encoding_idxs = [x[1] for x in self.vocabulary.items()] - self.byte_pair_encoder = create_static_hashtable( + self.token_to_id_map = create_static_hashtable( byte_pairs, byte_pair_encoding_idxs, default=-1, ) - self.byte_pair_decoder = create_static_hashtable( + self.id_to_token_map = create_static_hashtable( byte_pair_encoding_idxs, byte_pairs, default="", ) - # Merging rankings. + # Create ranking of merge rules, this is the same as order of merge + # pairs in `self.merges`. self.max_bpe_rank = len(self.merges) + 1 self.bpe_ranks = create_static_hashtable( self.merges, @@ -232,7 +258,7 @@ def tokenize(self, inputs): # Encode merged tokens. result = tf.strings.split(result, sep=" ") - encoding = self.byte_pair_encoder.lookup(result) + encoding = self.token_to_id_map.lookup(result) # Unflatten to match input. encoding = tf.RaggedTensor.from_row_splits( @@ -261,7 +287,7 @@ def detokenize(self, inputs): inputs = tf.expand_dims(inputs, 0) unicode_text = tf.strings.reduce_join( - self.byte_pair_decoder.lookup(inputs), axis=1 + self.id_to_token_map.lookup(inputs), axis=1 ) split_unicode_text = tf.strings.unicode_split(unicode_text, "UTF-8") byte_text = tf.strings.reduce_join( @@ -273,8 +299,6 @@ def detokenize(self, inputs): return byte_text - # Helper functions go here. - def _encode_tokens(self, tokens): """Map token bytes to unicode using `byte2unicode`.""" split_bytes = tf.strings.bytes_split(tokens) @@ -293,51 +317,15 @@ def _remove_empty_strings(self, tensor): ) return result - def _find_top_pair_and_merge(self, words, top_pair_first, top_pair_second): - """Merges the top pair in word.""" - # Get shifted word tokens. - word_pair_first = words[:, :-1] - word_pair_second = words[:, 1:] - - # Get top pair occurances. - top_pair_first = tf.expand_dims(top_pair_first, axis=1) - top_pair_second = tf.expand_dims(top_pair_second, axis=1) - top_pair_starts = tf.math.logical_and( - word_pair_first == top_pair_first, - word_pair_second == top_pair_second, - ) - - # Fixing off by one indexing. - num_words = tf.shape(top_pair_starts)[0] - front_mask = tf.logical_not( - tf.concat([tf.fill([num_words, 1], False), top_pair_starts], 1) - ) - back_mask = tf.concat( - [tf.fill([num_words, 1], False), top_pair_starts], 1 - ) - - # Filter word tokens to keep. - front = tf.where(front_mask, words, "") - # Filter `top_pair_second` tokens to merge. - back = tf.concat( - [ - tf.where(back_mask[:, 1:], word_pair_second, ""), - tf.fill([num_words, 1], ""), - ], - 1, - ) - # Merge and clean up empty strings. - joined = tf.strings.join([front, back]) - return self._remove_empty_strings(joined) - - def _get_pairs(self, words): - return words[:, :-1], words[:, 1:] - @tf.function def _byte_pair_merge_loop_body(self, words, mask): - """Iterative merging process for byte pair encoding algorithm.""" + """Iterative merging process for byte pair encoding algorithm. + + The end condition is either the word has been fully merged (list has + only one byte string), or it can no longer perform a merge. + """ # Get all word pairs. 
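+        # `first[i, j]` and `second[i, j]` are the adjacent sub-words of word
+        # `i` that are candidates for the next merge.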
- first, second = self._get_pairs(words) + first, second = words[:, :-1], words[:, 1:] # Mask empty. non_empty_mask = second.nested_row_lengths()[0] != 0 @@ -348,13 +336,15 @@ def _byte_pair_merge_loop_body(self, words, mask): tmp_first = tf.ragged.boolean_mask(first, mask) tmp_second = tf.ragged.boolean_mask(second, mask) - # Get top word pair. - pair_hash = tf.strings.join([tmp_first, tmp_second], separator=" ") - pair_rank = self.bpe_ranks.lookup(pair_hash) + # Get byte pair ranking in merge rules. + pairs = tf.strings.join([tmp_first, tmp_second], separator=" ") + pair_rank = self.bpe_ranks.lookup(pairs) # Get BPE pair ranks. min_pair_rank = tf.reduce_min(pair_rank, axis=1) not_found_mask = min_pair_rank != self.max_bpe_rank + + # Tokens cannot be further merged are marked as finished. mask = tf.tensor_scatter_nd_update( mask, tf.expand_dims(non_empty_idxs, axis=1), not_found_mask ) @@ -367,28 +357,45 @@ def _byte_pair_merge_loop_body(self, words, mask): ) # Get words and pairs to process. - p_words = tf.ragged.boolean_mask(words, mask) - p_first = tf.ragged.boolean_mask(first, mask) - p_second = tf.ragged.boolean_mask(second, mask) - p_min_rank_first = tf.gather(p_first, min_pair_rank_idx, batch_dims=1) - p_min_rank_second = tf.gather(p_second, min_pair_rank_idx, batch_dims=1) - - # Process merges of top pairs. - p_words = self._find_top_pair_and_merge( - p_words, p_min_rank_first, p_min_rank_second + unfinished_words = tf.ragged.boolean_mask(words, mask) + + pair_left = tf.gather(unfinished_words, min_pair_rank_idx, batch_dims=1) + pair_right = tf.gather( + unfinished_words, min_pair_rank_idx + 1, batch_dims=1 + ) + + merged_pairs = tf.strings.join([pair_left, pair_right]) + empty_strs = tf.fill(tf.shape(merged_pairs), "") + + unfinished_indices = tf.cast( + tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask), dtype=tf.int64 + ) + merge_update_indices_left = tf.concat( + [ + unfinished_indices[:, tf.newaxis], + min_pair_rank_idx[:, tf.newaxis], + ], + axis=1, + ) + merge_update_indices_right = tf.concat( + [ + unfinished_indices[:, tf.newaxis], + min_pair_rank_idx[:, tf.newaxis] + 1, + ], + axis=1, ) - # Update words. - p_idxs = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask) tensor_words = words.to_tensor(default_value="") - tensor_p_words = p_words.to_tensor( - default_value="", - shape=[tf.shape(p_idxs)[0], tf.shape(tensor_words)[1]], + tensor_words = tf.tensor_scatter_nd_update( + tensor_words, + merge_update_indices_left, + merged_pairs, ) + words = tf.tensor_scatter_nd_update( tensor_words, - tf.expand_dims(p_idxs, axis=1), - tensor_p_words, + merge_update_indices_right, + empty_strs, ) words = self._remove_empty_strings(words) return [words, mask] @@ -396,7 +403,10 @@ def _byte_pair_merge_loop_body(self, words, mask): def _byte_pair_encoding(self, tokens): """Process unseen tokens and add to cache.""" words = self._encode_tokens(tokens) - num_words = tf.shape(words)[0] + if isinstance(words, tf.RaggedTensor): + num_words = words.bounding_shape(0) + else: + num_words = tf.shape(words)[0] # Merge bytes. 
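+        # Keep applying the best-ranked merge rule to every word until no word
+        # has a mergeable pair left.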
def loop_condition(words, mask): From e1b5acf9f0b4efa593d03d77998f41283080a672 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 13 Oct 2022 18:32:26 -0700 Subject: [PATCH 6/6] handle lookahead and special whitespace tokens --- keras_nlp/tokenizers/byte_pair_tokenizer.py | 41 +++++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 28a7a822c1..4acc55188e 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -150,9 +150,13 @@ def __init__( self.sequence_length = sequence_length # String splitting regex pattern. - self.pat = ( - r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+""" - ) + self.special_space = r"\x{a0}\x{2009}\x{202f}\x{3000}" + + self.pat1 = r"""'s|'t|'re|'ve|'m|'ll|'d + |[\s{special_space}]+[\n\r\t\f६{special_space}]| ?\p{L}+| ?[\p{N}]+ + | ?[^\s\p{L}\p{N}{special_space}]+""" + self.pat1 = self.pat1.replace("{special_space}", self.special_space) + self.pat2 = rf"""[\s६{self.special_space}]$""" # Create byte <=> unicode mapping. This is useful for handling # whitespace tokens. @@ -233,8 +237,22 @@ def tokenize(self, inputs): if scalar_input: inputs = tf.expand_dims(inputs, 0) - # Regex match tokens. - raw_tokens = tf_text.regex_split(inputs, self.pat, self.pat) + # As re2 does not support lookahead match, we are using an alternative + # to insert a special token "६" before leading space of non-space + # characters and after the trailing space, e.g., " keras" will be + # "६ keras". + inputs = tf.strings.regex_replace( + inputs, rf"( )([^\s{self.special_space}])", r"६\1\2" + ) + inputs = tf.strings.regex_replace( + inputs, rf"(\s{self.special_space})$", r"\1६" + ) + raw_tokens = tf_text.regex_split(inputs, self.pat1, self.pat1) + # Second pass splits out the last whilespace char or "६". + raw_tokens = tf_text.regex_split(raw_tokens, self.pat2, self.pat2) + if raw_tokens.shape.rank > 2: + raw_tokens = raw_tokens.merge_dims(1, 2) + raw_tokens = self._remove_whitespace_placeholder(raw_tokens) token_row_splits = raw_tokens.row_splits flatten_tokens = raw_tokens.flat_values @@ -305,9 +323,9 @@ def _encode_tokens(self, tokens): split_unicode = self.byte2unicode.lookup(split_bytes) return split_unicode - def _remove_empty_strings(self, tensor): - """Remove empty strings in a tensor""" - non_empty_mask = tensor != "" + def _remove_strings(self, tensor, string_to_remove): + """Remove certain strings from input tensor.""" + non_empty_mask = tensor != string_to_remove flatten_indexes = tf.where(non_empty_mask) flatten_result = tf.gather_nd(tensor, flatten_indexes) row_lengths = tf.reduce_sum(tf.cast(non_empty_mask, tf.int64), axis=1) @@ -317,6 +335,13 @@ def _remove_empty_strings(self, tensor): ) return result + def _remove_empty_strings(self, tensor): + """Remove empty strings in a tensor""" + return self._remove_strings(tensor, "") + + def _remove_whitespace_placeholder(self, tensor): + return self._remove_strings(tensor, "६") + @tf.function def _byte_pair_merge_loop_body(self, words, mask): """Iterative merging process for byte pair encoding algorithm.
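For reference, below is a minimal usage sketch of the tokenizer introduced by this series. It is illustrative only and not part of any patch above: it reuses the toy vocabulary and merge rules from byte_pair_tokenizer_test.py and the keras_nlp.tokenizers.BytePairTokenizer export added in PATCH 4/6, and the expected outputs in the comments mirror the unit test assertions.

import tensorflow as tf

from keras_nlp.tokenizers import BytePairTokenizer

# Toy vocabulary and merge rules copied from the unit tests.
vocabulary = {
    "t": 1, "h": 2, "e": 3, " ": 4, "the": 5,
    "b": 6, "r": 7, "o": 8, "w": 9, "n": 10, "brown": 11, ".": 12,
}
merges = ["b r", "br o", "bro w", "brow n"]

tokenizer = BytePairTokenizer(vocabulary=vocabulary, merges=merges)

tokenizer(["brown."])                                 # <tf.RaggedTensor [[11, 12]]>
tokenizer.detokenize(tokenizer.tokenize(["brown."]))  # ["brown."]

# PATCH 5/6 documents the tokenizer as graph-compatible, so it should also be
# usable inside a tf.data pipeline.
ds = tf.data.Dataset.from_tensor_slices(["brown.", "brown"]).map(tokenizer)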