From 395a29600cc86fd3e8a2e97f8d3153e2f2952cd6 Mon Sep 17 00:00:00 2001
From: Aflah <72096386+aflah02@users.noreply.github.com>
Date: Sat, 16 Apr 2022 12:33:07 +0530
Subject: [PATCH 01/11] Added Functions to Base Class

---
 keras_nlp/tokenizers/tokenizer.py | 38 +++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py
index baccd457d6..df8aecea88 100644
--- a/keras_nlp/tokenizers/tokenizer.py
+++ b/keras_nlp/tokenizers/tokenizer.py
@@ -14,6 +14,7 @@
 
 from typing import List
 
+import tensorflow as tf
 from tensorflow import keras
 
 
@@ -129,3 +130,40 @@ def call(self, *args, mode="tokenize", training=None, **kwargs):
             raise ValueError(
                 f"Unsupported tokenizer mode. Received: mode={mode}"
             )
+
+    def recursive_utf8_decoder(self, inputs, *args, **kwargs):
+        """Recursively decodes to list of strings with 'utf-8' encoding."""
+        if str(type(inputs)) == "<class 'bytes'>":
+            inputs = inputs.decode("utf-8")
+            return inputs
+        if str(type(inputs[0])) == "<class 'bytes'>":
+            for i in range(len(inputs)):
+                inputs[i] = inputs[i].decode("utf-8")
+            return inputs
+        else:
+            for i in range(len(inputs)):
+                inputs[i] = self.recursive_utf8_decoder(
+                    inputs[i], *args, **kwargs
+                )
+
+    def detokenize_to_strings(self, inputs, *args, **kwargs):
+        """Transform detokenized inputs to strings.
+        Args:
+            inputs: Input tensor, or dict/list/tuple of input tensors.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+        """
+        detokenized_input = self.detokenize(inputs)
+        scalar = detokenized_input.shape.rank == 0
+        if isinstance(detokenized_input, tf.RaggedTensor):
+            detokenized_input = detokenized_input.to_list()
+        elif isinstance(detokenized_input, tf.Tensor):
+            if scalar:
+                detokenized_input = detokenized_input.numpy()
+                return detokenized_input.decode("utf-8")
+            else:
+                detokenized_input = detokenized_input.numpy().tolist()
+        detokenized_input = list(
+            map(self.recursive_utf8_decoder, detokenized_input)
+        )
+        return detokenized_input
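Note on patch 01: the decoder recurses until it reaches `bytes` scalars and rebuilds each nesting level in place. The same idea as a standalone, minimal sketch (plain Python, no TensorFlow; `decode_nested` is a hypothetical name, not part of the patch):

    def decode_nested(inputs):
        # A bytes scalar decodes directly; a list is rebuilt element by element.
        if isinstance(inputs, bytes):
            return inputs.decode("utf-8")
        return [decode_nested(x) for x in inputs]

    print(decode_nested([[b"hello", b"fox"], [b"\xe2\x96\x80"]]))
    # [['hello', 'fox'], ['▀']]
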
From e4fead56388147101852a6af87f425dedf295f8c Mon Sep 17 00:00:00 2001
From: Aflah <72096386+aflah02@users.noreply.github.com>
Date: Sun, 24 Apr 2022 15:46:47 +0530
Subject: [PATCH 02/11] Tightened Logic started Work on Tests

---
 keras_nlp/tokenizers/tokenizer.py      | 28 +++++++++++---------------
 keras_nlp/tokenizers/tokenizer_test.py |  8 ++++++++
 2 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py
index df8aecea88..35f8bcfad6 100644
--- a/keras_nlp/tokenizers/tokenizer.py
+++ b/keras_nlp/tokenizers/tokenizer.py
@@ -131,39 +131,35 @@ def call(self, *args, mode="tokenize", training=None, **kwargs):
                 f"Unsupported tokenizer mode. Received: mode={mode}"
             )
 
-    def recursive_utf8_decoder(self, inputs, *args, **kwargs):
+    def __recursive_utf8_decoder(self, inputs, *args, **kwargs):
         """Recursively decodes to list of strings with 'utf-8' encoding."""
-        if str(type(inputs)) == "<class 'bytes'>":
+        # Handles the case when the input is a scalar
+        if isinstance(inputs, bytes):
             inputs = inputs.decode("utf-8")
             return inputs
-        if str(type(inputs[0])) == "<class 'bytes'>":
-            for i in range(len(inputs)):
-                inputs[i] = inputs[i].decode("utf-8")
-            return inputs
+        # Recursive calls for all other cases by iterating over elements
         else:
             for i in range(len(inputs)):
-                inputs[i] = self.recursive_utf8_decoder(
+                inputs[i] = self.__recursive_utf8_decoder(
                     inputs[i], *args, **kwargs
                 )
+            return inputs
 
     def detokenize_to_strings(self, inputs, *args, **kwargs):
         """Transform detokenized inputs to strings.
+
         Args:
             inputs: Input tensor, or dict/list/tuple of input tensors.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
         """
-        detokenized_input = self.detokenize(inputs)
+        detokenized_input = self.detokenize(inputs, *args, **kwargs)
         scalar = detokenized_input.shape.rank == 0
         if isinstance(detokenized_input, tf.RaggedTensor):
             detokenized_input = detokenized_input.to_list()
         elif isinstance(detokenized_input, tf.Tensor):
-            if scalar:
-                detokenized_input = detokenized_input.numpy()
-                return detokenized_input.decode("utf-8")
-            else:
-                detokenized_input = detokenized_input.numpy().tolist()
-        detokenized_input = list(
-            map(self.recursive_utf8_decoder, detokenized_input)
-        )
+            detokenized_input = detokenized_input.numpy()
+            if not scalar:
+                detokenized_input = detokenized_input.tolist()
+        detokenized_input = self.__recursive_utf8_decoder(detokenized_input)
         return detokenized_input

diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py
index 9df3b3e111..eab5dfaf41 100644
--- a/keras_nlp/tokenizers/tokenizer_test.py
+++ b/keras_nlp/tokenizers/tokenizer_test.py
@@ -17,7 +17,15 @@
 from keras_nlp.tokenizers.tokenizer import Tokenizer
 
 
+class PassThroughTokenizer(Tokenizer):
+    __test__ = False  # for pytest
+
+    def tokenize(self, inputs):
+        return inputs
+
+    def detokenize(self, inputs):
+        return inputs
+
+
 class SimpleTokenizer(Tokenizer):
     __test__ = False  # for pytest
From a7271ff8d546bb39a8c0ea93c3bd06588e836ffb Mon Sep 17 00:00:00 2001
From: Aflah <72096386+aflah02@users.noreply.github.com>
Date: Sun, 24 Apr 2022 16:22:11 +0530
Subject: [PATCH 03/11] Added tests

---
 keras_nlp/tokenizers/tokenizer_test.py | 53 ++++++++++++++++++++++++--
 1 file changed, 49 insertions(+), 4 deletions(-)

diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py
index eab5dfaf41..bf9070f31e 100644
--- a/keras_nlp/tokenizers/tokenizer_test.py
+++ b/keras_nlp/tokenizers/tokenizer_test.py
@@ -17,15 +17,39 @@
 from keras_nlp.tokenizers.tokenizer import Tokenizer
 
 
+
 class PassThroughTokenizer(Tokenizer):
     __test__ = False  # for pytest
 
-    def tokenize(self, inputs):
-        return inputs
+    def tokenize(self, inputs, sequence_length=None):
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        tokens = tf.strings.unicode_decode(
+            inputs, input_encoding="UTF-8", errors="ignore"
+        )
+
+        if sequence_length:
+            output_shape = tokens.shape.as_list()
+            output_shape[-1] = sequence_length
+            tokens = tokens.to_tensor(shape=output_shape)
+
+        if scalar_input:
+            tokens = tf.squeeze(tokens, 0)
+
+        return tokens
 
     def detokenize(self, inputs):
-        return inputs
-
+        inputs = tf.ragged.boolean_mask(inputs, tf.not_equal(inputs, 0))
+        encoded_string = tf.strings.unicode_encode(
+            inputs, output_encoding="UTF-8", errors="ignore"
+        )
+        return encoded_string
+
 
 class SimpleTokenizer(Tokenizer):
     __test__ = False  # for pytest
@@ -65,3 +89,24 @@ def test_functional_model(self):
     def test_missing_tokenize_raises(self):
         with self.assertRaises(NotImplementedError):
             Tokenizer()(["the quick brown fox"])
+
+    def test_detokenize_to_strings_for_ragged(self):
+        input_data = ["▀▁▂▃", "samurai"]
+        tokenizer = PassThroughTokenizer()
+        tokenize_output = tokenizer.tokenize(input_data)
+        detokenize_output = tokenizer.detokenize_to_strings(tokenize_output)
+        self.assertAllEqual(detokenize_output, ["▀▁▂▃", "samurai"])
+
+    def test_detokenize_to_strings_for_dense(self):
+        input_data = ["▀▁▂▃", "samurai"]
+        tokenizer = PassThroughTokenizer()
+        tokenize_output = tokenizer.tokenize(input_data, sequence_length=5)
+        detokenize_output = tokenizer.detokenize_to_strings(tokenize_output)
+        self.assertAllEqual(detokenize_output, ["▀▁▂▃", "samur"])
+
+    def test_detokenize_to_strings_for_scalar(self):
+        input_data = "▀▁▂▃"
+        tokenizer = PassThroughTokenizer()
+        tokenize_output = tokenizer.tokenize(input_data)
+        detokenize_output = tokenizer.detokenize_to_strings(tokenize_output)
+        self.assertEqual(detokenize_output, "▀▁▂▃")
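Note on patch 03: the tests rely on `tf.strings.unicode_decode` returning a ragged tensor of codepoints, and on fixed-width padding/truncation to produce the dense case; that is why "samurai" comes back as "samur" when `sequence_length=5`. A minimal sketch of that round trip (assuming TensorFlow 2.x; not part of the patch):

    import tensorflow as tf

    codepoints = tf.strings.unicode_decode(["▀▁▂▃", "samurai"], "UTF-8")
    # Ragged: rows hold 4 and 7 codepoints respectively.
    dense = codepoints.to_tensor(shape=[2, 5])  # zero-pads row 0, truncates row 1
    ragged = tf.ragged.boolean_mask(dense, tf.not_equal(dense, 0))
    print(tf.strings.unicode_encode(ragged, "UTF-8").numpy())
    # [b'\xe2\x96\x80\xe2\x96\x81\xe2\x96\x82\xe2\x96\x83' b'samur']
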
From d6dbe9d459f6eae38852777a1be82bb83aa0a178 Mon Sep 17 00:00:00 2001
From: Aflah <72096386+aflah02@users.noreply.github.com>
Date: Sun, 24 Apr 2022 16:24:55 +0530
Subject: [PATCH 04/11] Updated Docstring

---
 keras_nlp/tokenizers/tokenizer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py
index 35f8bcfad6..e4a687ef7e 100644
--- a/keras_nlp/tokenizers/tokenizer.py
+++ b/keras_nlp/tokenizers/tokenizer.py
@@ -146,7 +146,8 @@ def __recursive_utf8_decoder(self, inputs, *args, **kwargs):
             return inputs
 
     def detokenize_to_strings(self, inputs, *args, **kwargs):
-        """Transform detokenized inputs to strings.
+        """Detokenize, then convert the output tensor to nested lists
+        of python strings.
 
         Args:
             inputs: Input tensor, or dict/list/tuple of input tensors.

From 5ce809195edaea8816b9015fd9fa3d91318f1739 Mon Sep 17 00:00:00 2001
From: Aflah <72096386+aflah02@users.noreply.github.com>
Date: Tue, 26 Apr 2022 00:05:01 +0530
Subject: [PATCH 05/11] Fixing Tokenizer

---
 keras_nlp/tokenizers/tokenizer.py      | 16 ++++-------
 keras_nlp/tokenizers/tokenizer_test.py | 38 +++++---------------------
 2 files changed, 12 insertions(+), 42 deletions(-)

diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py
index e4a687ef7e..a2d3eed216 100644
--- a/keras_nlp/tokenizers/tokenizer.py
+++ b/keras_nlp/tokenizers/tokenizer.py
@@ -131,19 +131,14 @@ def call(self, *args, mode="tokenize", training=None, **kwargs):
                 f"Unsupported tokenizer mode. Received: mode={mode}"
             )
 
-    def __recursive_utf8_decoder(self, inputs, *args, **kwargs):
+    def _decode_strings_to_utf8(self, inputs):
         """Recursively decodes to list of strings with 'utf-8' encoding."""
         # Handles the case when the input is a scalar
         if isinstance(inputs, bytes):
-            inputs = inputs.decode("utf-8")
-            return inputs
+            return inputs.decode("utf-8")
         # Recursive calls for all other cases by iterating over elements
         else:
-            for i in range(len(inputs)):
-                inputs[i] = self.__recursive_utf8_decoder(
-                    inputs[i], *args, **kwargs
-                )
-            return inputs
+            return [self._decode_strings_to_utf8(x) for x in inputs]
 
     def detokenize_to_strings(self, inputs, *args, **kwargs):
         """Detokenize, then convert the output tensor to nested lists
@@ -155,12 +150,11 @@ def detokenize_to_strings(self, inputs, *args, **kwargs):
             **kwargs: Additional keyword arguments.
         """
         detokenized_input = self.detokenize(inputs, *args, **kwargs)
-        scalar = detokenized_input.shape.rank == 0
         if isinstance(detokenized_input, tf.RaggedTensor):
             detokenized_input = detokenized_input.to_list()
         elif isinstance(detokenized_input, tf.Tensor):
+            scalar = detokenized_input.shape.rank == 0
             detokenized_input = detokenized_input.numpy()
             if not scalar:
                 detokenized_input = detokenized_input.tolist()
-        detokenized_input = self.__recursive_utf8_decoder(detokenized_input)
-        return detokenized_input
+        return self._decode_strings_to_utf8(detokenized_input)

diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py
index bf9070f31e..ecb270fa46 100644
--- a/keras_nlp/tokenizers/tokenizer_test.py
+++ b/keras_nlp/tokenizers/tokenizer_test.py
@@ -21,33 +21,11 @@
 class PassThroughTokenizer(Tokenizer):
     __test__ = False  # for pytest
 
-    def tokenize(self, inputs, sequence_length=None):
-        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
-            inputs = tf.convert_to_tensor(inputs)
-
-        scalar_input = inputs.shape.rank == 0
-        if scalar_input:
-            inputs = tf.expand_dims(inputs, 0)
-
-        tokens = tf.strings.unicode_decode(
-            inputs, input_encoding="UTF-8", errors="ignore"
-        )
-
-        if sequence_length:
-            output_shape = tokens.shape.as_list()
-            output_shape[-1] = sequence_length
-            tokens = tokens.to_tensor(shape=output_shape)
-
-        if scalar_input:
-            tokens = tf.squeeze(tokens, 0)
-
-        return tokens
+    def tokenize(self, inputs):
+        return inputs
 
     def detokenize(self, inputs):
-        inputs = tf.ragged.boolean_mask(inputs, tf.not_equal(inputs, 0))
-        encoded_string = tf.strings.unicode_encode(
-            inputs, output_encoding="UTF-8", errors="ignore"
-        )
-        return encoded_string
+        return inputs
 
 
 class SimpleTokenizer(Tokenizer):
@@ -91,17 +69,15 @@ def test_missing_tokenize_raises(self):
             Tokenizer()(["the quick brown fox"])
 
     def test_detokenize_to_strings_for_ragged(self):
-        input_data = ["▀▁▂▃", "samurai"]
+        input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]])
         tokenizer = PassThroughTokenizer()
-        tokenize_output = tokenizer.tokenize(input_data)
-        detokenize_output = tokenizer.detokenize_to_strings(tokenize_output)
+        detokenize_output = tokenizer.detokenize_to_strings(input_data)
         self.assertAllEqual(detokenize_output, ["▀▁▂▃", "samurai"])
 
     def test_detokenize_to_strings_for_dense(self):
-        input_data = ["▀▁▂▃", "samurai"]
+        input_data = tf.constant([["▀▁▂▃", "samurai"]])
         tokenizer = PassThroughTokenizer()
-        tokenize_output = tokenizer.tokenize(input_data, sequence_length=5)
-        detokenize_output = tokenizer.detokenize_to_strings(tokenize_output)
+        detokenize_output = tokenizer.detokenize_to_strings(input_data)
"samur"]) def test_detokenize_to_strings_for_scalar(self): From e1b6df8cae807cee816694f78030f6ffcad7cd41 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Tue, 26 Apr 2022 00:09:23 +0530 Subject: [PATCH 06/11] Fixed Broken Tests --- keras_nlp/tokenizers/tokenizer_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py index ecb270fa46..8e9712194b 100644 --- a/keras_nlp/tokenizers/tokenizer_test.py +++ b/keras_nlp/tokenizers/tokenizer_test.py @@ -72,17 +72,16 @@ def test_detokenize_to_strings_for_ragged(self): input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]]) tokenizer = PassThroughTokenizer() detokenize_output = tokenizer.detokenize_to_strings(input_data) - self.assertAllEqual(detokenize_output, ["▀▁▂▃", "samurai"]) + self.assertAllEqual(detokenize_output, [['▀▁▂▃', 'samurai']]) def test_detokenize_to_strings_for_dense(self): input_data = tf.constant([["▀▁▂▃", "samurai"]]) tokenizer = PassThroughTokenizer() detokenize_output = tokenizer.detokenize_to_strings(input_data) - self.assertAllEqual(detokenize_output, ["▀▁▂▃", "samur"]) + self.assertAllEqual(detokenize_output, [['▀▁▂▃', 'samurai']]) def test_detokenize_to_strings_for_scalar(self): - input_data = "▀▁▂▃" + input_data = tf.constant("▀▁▂▃") tokenizer = PassThroughTokenizer() - tokenize_output = tokenizer.tokenize(input_data) - detokenize_output = tokenizer.detokenize_to_strings(tokenize_output) + detokenize_output = tokenizer.detokenize_to_strings(input_data) self.assertEqual(detokenize_output, "▀▁▂▃") From d615cb55fd930259dbba83583a96cc6236a96c56 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Tue, 26 Apr 2022 00:10:35 +0530 Subject: [PATCH 07/11] Ran format and lint --- keras_nlp/tokenizers/tokenizer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py index 8e9712194b..fd79067727 100644 --- a/keras_nlp/tokenizers/tokenizer_test.py +++ b/keras_nlp/tokenizers/tokenizer_test.py @@ -72,13 +72,13 @@ def test_detokenize_to_strings_for_ragged(self): input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]]) tokenizer = PassThroughTokenizer() detokenize_output = tokenizer.detokenize_to_strings(input_data) - self.assertAllEqual(detokenize_output, [['▀▁▂▃', 'samurai']]) + self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) def test_detokenize_to_strings_for_dense(self): input_data = tf.constant([["▀▁▂▃", "samurai"]]) tokenizer = PassThroughTokenizer() detokenize_output = tokenizer.detokenize_to_strings(input_data) - self.assertAllEqual(detokenize_output, [['▀▁▂▃', 'samurai']]) + self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) def test_detokenize_to_strings_for_scalar(self): input_data = tf.constant("▀▁▂▃") From 8d981211076024036eaff10f07e220296c7d757b Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 25 Apr 2022 12:16:10 -0700 Subject: [PATCH 08/11] Fix docstring summary to fit on single line Adds a little more description as well --- keras_nlp/tokenizers/tokenizer.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index a2d3eed216..c8872ffb3a 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -133,16 +133,21 @@ def call(self, *args, 
mode="tokenize", training=None, **kwargs): def _decode_strings_to_utf8(self, inputs): """Recursively decodes to list of strings with 'utf-8' encoding.""" - # Handles the case when the input is a scalar if isinstance(inputs, bytes): + # Handles the case when the input is a scalar string. return inputs.decode("utf-8") - # Recursive calls for all other cases by iterating over elements else: + # Recursively iterate when input is a list. return [self._decode_strings_to_utf8(x) for x in inputs] def detokenize_to_strings(self, inputs, *args, **kwargs): - """Detokenize, then convert the output tensor to nested lists - of python strings. + """Detokenize and convert tensor to nested lists of python strings. + + This is a convenience method layered on top of `detokenize()`. This + method will call `detokenize()` and transform the output string + tensors back to python strings, by first converting output tensors + to nested lists of elements, and then converting each byte string + to a python string. Args: inputs: Input tensor, or dict/list/tuple of input tensors. @@ -153,8 +158,7 @@ def detokenize_to_strings(self, inputs, *args, **kwargs): if isinstance(detokenized_input, tf.RaggedTensor): detokenized_input = detokenized_input.to_list() elif isinstance(detokenized_input, tf.Tensor): - scalar = detokenized_input.shape.rank == 0 detokenized_input = detokenized_input.numpy() - if not scalar: + if detokenized_input.shape.rank != 0: detokenized_input = detokenized_input.tolist() return self._decode_strings_to_utf8(detokenized_input) From f84b49dcc7d7ded8f9b9e4d825afdb70ba725768 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 25 Apr 2022 12:25:50 -0700 Subject: [PATCH 09/11] Remove trailing whitespace --- keras_nlp/tokenizers/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index c8872ffb3a..768d7a6811 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -143,7 +143,7 @@ def _decode_strings_to_utf8(self, inputs): def detokenize_to_strings(self, inputs, *args, **kwargs): """Detokenize and convert tensor to nested lists of python strings. - This is a convenience method layered on top of `detokenize()`. This + This is a convenience method layered on top of `detokenize()`. This method will call `detokenize()` and transform the output string tensors back to python strings, by first converting output tensors to nested lists of elements, and then converting each byte string From 511ccd684fe1aa89b71ca4083935fadcf98d5462 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Mon, 25 Apr 2022 12:47:17 -0700 Subject: [PATCH 10/11] fix --- keras_nlp/tokenizers/tokenizer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 768d7a6811..cc155a0360 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -154,11 +154,11 @@ def detokenize_to_strings(self, inputs, *args, **kwargs): *args: Additional positional arguments. **kwargs: Additional keyword arguments. 
""" - detokenized_input = self.detokenize(inputs, *args, **kwargs) - if isinstance(detokenized_input, tf.RaggedTensor): - detokenized_input = detokenized_input.to_list() - elif isinstance(detokenized_input, tf.Tensor): - detokenized_input = detokenized_input.numpy() - if detokenized_input.shape.rank != 0: - detokenized_input = detokenized_input.tolist() - return self._decode_strings_to_utf8(detokenized_input) + tensor_outputs = self.detokenize(inputs, *args, **kwargs) + if isinstance(tensor_outputs, tf.RaggedTensor): + list_outputs = tensor_outputs.to_list() + elif isinstance(tensor_outputs, tf.Tensor): + list_outputs = tensor_outputs.numpy() + if tensor_outputs.shape.rank != 0: + list_outputs = list_outputs.tolist() + return self._decode_strings_to_utf8(list_outputs) From 9d6188f501e46b3ee56e9c4998b6c3ccea81b1d6 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Tue, 3 May 2022 18:57:57 +0530 Subject: [PATCH 11/11] Ported tensor_to_string_list to tensor_utils --- keras_nlp/__init__.py | 1 + keras_nlp/tokenizers/tokenizer.py | 33 ------------------ keras_nlp/tokenizers/tokenizer_test.py | 28 --------------- keras_nlp/utils/tensor_utils.py | 47 ++++++++++++++++++++++++++ keras_nlp/utils/tensor_utils_test.py | 33 ++++++++++++++++++ 5 files changed, 81 insertions(+), 61 deletions(-) create mode 100644 keras_nlp/utils/tensor_utils.py create mode 100644 keras_nlp/utils/tensor_utils_test.py diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 06afe69361..39173c72ef 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -15,5 +15,6 @@ from keras_nlp import layers from keras_nlp import metrics from keras_nlp import tokenizers +from keras_nlp import utils __version__ = "0.1.1" diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index cc155a0360..baccd457d6 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -14,7 +14,6 @@ from typing import List -import tensorflow as tf from tensorflow import keras @@ -130,35 +129,3 @@ def call(self, *args, mode="tokenize", training=None, **kwargs): raise ValueError( f"Unsupported tokenizer mode. Received: mode={mode}" ) - - def _decode_strings_to_utf8(self, inputs): - """Recursively decodes to list of strings with 'utf-8' encoding.""" - if isinstance(inputs, bytes): - # Handles the case when the input is a scalar string. - return inputs.decode("utf-8") - else: - # Recursively iterate when input is a list. - return [self._decode_strings_to_utf8(x) for x in inputs] - - def detokenize_to_strings(self, inputs, *args, **kwargs): - """Detokenize and convert tensor to nested lists of python strings. - - This is a convenience method layered on top of `detokenize()`. This - method will call `detokenize()` and transform the output string - tensors back to python strings, by first converting output tensors - to nested lists of elements, and then converting each byte string - to a python string. - - Args: - inputs: Input tensor, or dict/list/tuple of input tensors. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. 
- """ - tensor_outputs = self.detokenize(inputs, *args, **kwargs) - if isinstance(tensor_outputs, tf.RaggedTensor): - list_outputs = tensor_outputs.to_list() - elif isinstance(tensor_outputs, tf.Tensor): - list_outputs = tensor_outputs.numpy() - if tensor_outputs.shape.rank != 0: - list_outputs = list_outputs.tolist() - return self._decode_strings_to_utf8(list_outputs) diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py index fd79067727..9df3b3e111 100644 --- a/keras_nlp/tokenizers/tokenizer_test.py +++ b/keras_nlp/tokenizers/tokenizer_test.py @@ -18,16 +18,6 @@ from keras_nlp.tokenizers.tokenizer import Tokenizer -class PassThroughTokenizer(Tokenizer): - __test__ = False # for pytest - - def tokenize(self, inputs): - return inputs - - def detokenize(self, inputs): - return inputs - - class SimpleTokenizer(Tokenizer): __test__ = False # for pytest @@ -67,21 +57,3 @@ def test_functional_model(self): def test_missing_tokenize_raises(self): with self.assertRaises(NotImplementedError): Tokenizer()(["the quick brown fox"]) - - def test_detokenize_to_strings_for_ragged(self): - input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]]) - tokenizer = PassThroughTokenizer() - detokenize_output = tokenizer.detokenize_to_strings(input_data) - self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) - - def test_detokenize_to_strings_for_dense(self): - input_data = tf.constant([["▀▁▂▃", "samurai"]]) - tokenizer = PassThroughTokenizer() - detokenize_output = tokenizer.detokenize_to_strings(input_data) - self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) - - def test_detokenize_to_strings_for_scalar(self): - input_data = tf.constant("▀▁▂▃") - tokenizer = PassThroughTokenizer() - detokenize_output = tokenizer.detokenize_to_strings(input_data) - self.assertEqual(detokenize_output, "▀▁▂▃") diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py new file mode 100644 index 0000000000..26fc815f11 --- /dev/null +++ b/keras_nlp/utils/tensor_utils.py @@ -0,0 +1,47 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + + +def _decode_strings_to_utf8(inputs): + """Recursively decodes to list of strings with 'utf-8' encoding.""" + if isinstance(inputs, bytes): + # Handles the case when the input is a scalar string. + return inputs.decode("utf-8") + else: + # Recursively iterate when input is a list. + return [_decode_strings_to_utf8(x) for x in inputs] + + +def tensor_to_string_list(inputs): + """Detokenize and convert tensor to nested lists of python strings. + + This is a convenience method which converts each byte string to a python + string. + + Args: + inputs: Input tensor, or dict/list/tuple of input tensors. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. 
+ """ + if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)): + inputs = tf.convert_to_tensor(inputs) + if isinstance(inputs, tf.RaggedTensor): + list_outputs = inputs.to_list() + elif isinstance(inputs, tf.Tensor): + list_outputs = inputs.numpy() + if inputs.shape.rank != 0: + list_outputs = list_outputs.tolist() + return _decode_strings_to_utf8(list_outputs) diff --git a/keras_nlp/utils/tensor_utils_test.py b/keras_nlp/utils/tensor_utils_test.py new file mode 100644 index 0000000000..2c9103f3fc --- /dev/null +++ b/keras_nlp/utils/tensor_utils_test.py @@ -0,0 +1,33 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from tensor_utils import tensor_to_string_list + + +class TensorToStringListTest(tf.test.TestCase): + def test_detokenize_to_strings_for_ragged(self): + input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]]) + detokenize_output = tensor_to_string_list(input_data) + self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) + + def test_detokenize_to_strings_for_dense(self): + input_data = tf.constant([["▀▁▂▃", "samurai"]]) + detokenize_output = tensor_to_string_list(input_data) + self.assertAllEqual(detokenize_output, [["▀▁▂▃", "samurai"]]) + + def test_detokenize_to_strings_for_scalar(self): + input_data = tf.constant("▀▁▂▃") + detokenize_output = tensor_to_string_list(input_data) + self.assertEqual(detokenize_output, "▀▁▂▃")