From e4507da4af641b920efa7c9c6b28dcf92fadd1df Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 2 Feb 2023 17:37:08 +0530
Subject: [PATCH 1/6] init commit

---
 keras_nlp/tokenizers/byte_pair_tokenizer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 8d7e886c1b..73aa53c5af 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -87,7 +87,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-def split_strings_for_bpe(inputs):
+def split_strings_for_bpe(inputs, add_prefix_space):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -230,6 +230,7 @@ def __init__(
         vocabulary,
         merges,
         sequence_length=None,
+        add_prefix_space=False,
         **kwargs,
     ) -> None:
         assert_tf_text_installed(self.__class__.__name__)
@@ -268,6 +269,7 @@ def __init__(
                 f"Received: `type(merges)={type(merges)}`"
             )
         self.sequence_length = sequence_length
+        self.add_prefix_space = add_prefix_space
 
         # Create byte <=> unicode mapping. This is useful for handling
         # whitespace tokens.
@@ -455,7 +457,7 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)
 
-        raw_tokens = split_strings_for_bpe(inputs)
+        raw_tokens = split_strings_for_bpe(inputs, self.add_prefix_space)
         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values

From 4475b949fb192586744f735cfcac15ca7db23177 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 2 Feb 2023 18:17:32 +0530
Subject: [PATCH 2/6] updated

---
 keras_nlp/tokenizers/byte_pair_tokenizer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 73aa53c5af..2550e14c41 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -87,7 +87,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-def split_strings_for_bpe(inputs, add_prefix_space):
+def split_strings_for_bpe(inputs):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -450,6 +450,10 @@ def loop_condition(_, mask):
         return merged_words
 
     def tokenize(self, inputs):
+
+        if self.add_prefix_space:
+            inputs = " " + inputs
+
         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
             inputs = tf.convert_to_tensor(inputs)
 
@@ -457,7 +461,7 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)
 
-        raw_tokens = split_strings_for_bpe(inputs, self.add_prefix_space)
+        raw_tokens = split_strings_for_bpe(inputs)
         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values

From 4cd29601d22d68a080e77d6e68b2f4fdea1e7e52 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 2 Feb 2023 18:38:44 +0530
Subject: [PATCH 3/6] formatting + docstring change

---
 keras_nlp/layers/masked_lm_mask_generator.py | 6 +-----
 keras_nlp/tokenizers/byte_pair_tokenizer.py  | 8 +++++---
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/keras_nlp/layers/masked_lm_mask_generator.py b/keras_nlp/layers/masked_lm_mask_generator.py
index 069f35b8f6..ab03c530e9 100644
--- a/keras_nlp/layers/masked_lm_mask_generator.py
+++ b/keras_nlp/layers/masked_lm_mask_generator.py
@@ -147,11 +147,7 @@ def call(self, inputs):
             # convert dense to ragged.
             inputs = tf.RaggedTensor.from_tensor(inputs)
 
-        (
-            token_ids,
-            mask_positions,
-            mask_ids,
-        ) = tf_text.mask_language_model(
+        (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model(
             inputs,
             item_selector=self._random_selector,
             mask_values_chooser=self._mask_values_chooser,
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 2550e14c41..65178fb943 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -196,6 +196,8 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             should have one merge rule per line.
         sequence_length: int, defaults to None. If set, the output will be
             padded or truncated to the `sequence_length`.
+        add_prefix_space: bool, defaults to False. Whether or not to add an initial space
+            to the input. This allows to treat the leading word just as any other word.
 
     Examples:
 
@@ -451,12 +453,12 @@ def loop_condition(_, mask):
     def tokenize(self, inputs):
 
-        if self.add_prefix_space:
-            inputs = " " + inputs
-
         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
             inputs = tf.convert_to_tensor(inputs)
 
+        if self.add_prefix_space:
+            inputs = tf.strings.join([" ", inputs])
+
         scalar_input = inputs.shape.rank == 0
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)

From eb9e7fb21e68f25e2999f97aceb490f0d41580ea Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 2 Feb 2023 18:50:13 +0530
Subject: [PATCH 4/6] bumping black version

---
 keras_nlp/layers/masked_lm_mask_generator.py | 6 +++++-
 keras_nlp/tokenizers/byte_pair_tokenizer.py  | 1 -
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/layers/masked_lm_mask_generator.py b/keras_nlp/layers/masked_lm_mask_generator.py
index ab03c530e9..069f35b8f6 100644
--- a/keras_nlp/layers/masked_lm_mask_generator.py
+++ b/keras_nlp/layers/masked_lm_mask_generator.py
@@ -147,7 +147,11 @@ def call(self, inputs):
             # convert dense to ragged.
             inputs = tf.RaggedTensor.from_tensor(inputs)
 
-        (token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model(
+        (
+            token_ids,
+            mask_positions,
+            mask_ids,
+        ) = tf_text.mask_language_model(
             inputs,
             item_selector=self._random_selector,
             mask_values_chooser=self._mask_values_chooser,
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 65178fb943..832a233d35 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -452,7 +452,6 @@ def loop_condition(_, mask):
         return merged_words
 
     def tokenize(self, inputs):
-
         if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
             inputs = tf.convert_to_tensor(inputs)
 

From 398f97964dbe891fa6ac3d6bce4a15647c8ebc7c Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Thu, 2 Feb 2023 21:30:59 +0530
Subject: [PATCH 5/6] adding unit test

---
 .../tokenizers/byte_pair_tokenizer_test.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
index 56c2e1365e..d3cd6eb886 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -66,6 +66,24 @@ def test_tokenize_string_output(self):
         )
         self.assertAllEqual(call_output, expected)
 
+    def test_tokenize_prefix_space(self):
+        input_data = ["brown.", "black."]
+        tokenizer = BytePairTokenizer(
+            vocabulary=VOCAB_PATH,
+            merges=MERGE_PATH,
+            dtype=tf.string,
+            add_prefix_space=True,
+        )
+        call_output = tokenizer(input_data)
+
+        expected = tf.ragged.constant(
+            [
+                ["Ġbrown", "."],
+                ["Ġblack", "."],
+            ]
+        )
+        self.assertAllEqual(call_output, expected)
+
     def test_tokenize_scalar_input(self):
         input_data = "brown."
         encoded = self.tokenizer.tokenize(input_data)

From a3c3514a706a4854e406e9d59bd39ba7877f1d9f Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Fri, 3 Feb 2023 04:19:22 +0530
Subject: [PATCH 6/6] minor docstring change

---
 keras_nlp/tokenizers/byte_pair_tokenizer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 832a233d35..602a9e310e 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -196,8 +196,11 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             should have one merge rule per line.
         sequence_length: int, defaults to None. If set, the output will be
             padded or truncated to the `sequence_length`.
-        add_prefix_space: bool, defaults to False. Whether or not to add an initial space
-            to the input. This allows to treat the leading word just as any other word.
+        add_prefix_space: bool, defaults to False. Whether or not to add an
+            initial space to the input. This tokenizer is whitespace aware,
+            and will tokenize a word with a leading space differently. Adding
+            a prefix space to the first word will cause it to be tokenized
+            equivalently to all subsequent words in the sequence.
 
     Examples:
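
Note: for reviewers, a minimal usage sketch of the new `add_prefix_space` flag, mirroring the
`test_tokenize_prefix_space` unit test added in PATCH 5/6. The `vocab.json` and `merges.txt`
paths are placeholders for a GPT-2-style vocabulary and merges file, not files added by this
patch series.

    import tensorflow as tf
    import keras_nlp

    # Placeholder paths; substitute a real GPT-2-style vocabulary and merges file.
    tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
        vocabulary="vocab.json",
        merges="merges.txt",
        dtype=tf.string,
        add_prefix_space=True,
    )

    # With add_prefix_space=True, tokenize() joins a leading " " onto each input,
    # so the first word is tokenized the same way as a word following whitespace.
    print(tokenizer(["brown.", "black."]))
    # Expected, per the unit test above: [["Ġbrown", "."], ["Ġblack", "."]]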