diff --git a/keras_nlp/metrics/__init__.py b/keras_nlp/metrics/__init__.py
index 36c05c1604..4c09787807 100644
--- a/keras_nlp/metrics/__init__.py
+++ b/keras_nlp/metrics/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from keras_nlp.metrics.bleu import Bleu
 from keras_nlp.metrics.edit_distance import EditDistance
 from keras_nlp.metrics.perplexity import Perplexity
 from keras_nlp.metrics.rouge_l import RougeL
diff --git a/keras_nlp/metrics/bleu.py b/keras_nlp/metrics/bleu.py
new file mode 100644
index 0000000000..1626762361
--- /dev/null
+++ b/keras_nlp/metrics/bleu.py
@@ -0,0 +1,386 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BLEU metric implementation."""
+
+import collections
+import math
+
+import tensorflow as tf
+from tensorflow import keras
+
+from keras_nlp.utils.tensor_utils import tensor_to_list
+
+REPLACE_SUBSTRINGS = [
+    ("<skipped>", ""),
+    ("-\n", ""),
+    ("\n", " "),
+    ("&quot;", '"'),
+    ("&amp;", "&"),
+    ("&lt;", "<"),
+    ("&gt;", ">"),
+]
+
+
+REGEX_PATTERNS = [
+    # language-dependent part (assuming Western languages)
+    (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
+    # tokenize period and comma unless preceded by a digit
+    (r"([^0-9])([\.,])", r"\1 \2 "),
+    # tokenize period and comma unless followed by a digit
+    (r"([\.,])([^0-9])", r" \1 \2"),
+    # tokenize dash when preceded by a digit
+    (r"([0-9])(-)", r"\1 \2 "),
+    # if the last character is "." or ",", add a space
+    (r"([\.,])$", r" \1 "),
+    # one space only between words
+    (r"\s+", r" "),
+]
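For intuition, here is a minimal sketch of what these two substitution tables do, using plain `str.replace` and Python's `re` module in place of `tf.strings.regex_replace`. The function name and sample string are illustrative only:

```python
import re

def tokenizer_13a_sketch(text):
    # Literal substitutions (unescape HTML entities, drop line breaks).
    for old, new in [("<skipped>", ""), ("-\n", ""), ("\n", " "),
                     ("&quot;", '"'), ("&amp;", "&"), ("&lt;", "<"),
                     ("&gt;", ">")]:
        text = text.replace(old, new)
    # Regex substitutions (split punctuation off from words).
    for pattern, replacement in [
        (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
        (r"([^0-9])([\.,])", r"\1 \2 "),
        (r"([\.,])([^0-9])", r" \1 \2"),
        (r"([0-9])(-)", r"\1 \2 "),
        (r"([\.,])$", r" \1 "),
        (r"\s+", r" "),
    ]:
        text = re.sub(pattern, replacement, text)
    return text.split()

print(tokenizer_13a_sketch("Hello, world."))  # ['Hello', ',', 'world', '.']
```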
+
+
+class Bleu(keras.metrics.Metric):
+    """BLEU metric.
+
+    This class implements the BLEU metric. BLEU is generally used to evaluate
+    machine translation systems. By default, this implementation replicates
+    SacreBLEU, but user-defined tokenizers can be passed to deal with other
+    languages.
+
+    For BLEU score, we count the number of matching n-grams in the candidate
+    translation and the reference text. We find the "clipped count" of
+    matching n-grams so as not to give a high score to a
+    (reference, prediction) pair with redundant, repeated tokens.
+    Additionally, BLEU tends to reward shorter predictions, which is why a
+    brevity penalty is applied to penalise them. For more details, see the
+    following article:
+    https://cloud.google.com/translate/automl/docs/evaluate#bleu.
+
+    Note on input shapes:
+    For unbatched inputs, `y_pred` should be a tensor of shape `()`, and
+    `y_true` should be a tensor of shape `(num_references,)`. For batched
+    inputs, `y_pred` should be a tensor of shape `(batch_size,)`, and `y_true`
+    should be a tensor of shape `(batch_size, num_references)`. In case of
+    batched inputs, `y_true` can also be a ragged tensor of shape
+    `(batch_size, None)` if different samples have different numbers of
+    references.
+
+    Args:
+        tokenizer: callable. A function that takes a string `tf.RaggedTensor`
+            (of any shape), and tokenizes the strings in the tensor. If the
+            tokenizer is not specified, the default tokenizer is used. The
+            default tokenizer replicates the behaviour of SacreBLEU's
+            `"tokenizer_13a"` tokenizer
+            (https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py).
+        max_order: int. The maximum n-gram order to use. For example, if
+            `max_order` is set to 3, unigrams, bigrams, and trigrams will be
+            considered. Defaults to 4.
+        smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU
+            score. Adds 1 to the matched n-gram count (i.e., numerator) and 1
+            to the total n-gram count (i.e., denominator) for every order
+            while calculating precision. Defaults to False.
+        dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
+            not specified, it defaults to tf.float32.
+        name: string. Name of the metric instance.
+        **kwargs: Other keyword arguments.
+
+    References:
+        - [Papineni et al., 2002](https://aclanthology.org/P02-1040/)
+        - [SacreBLEU](https://github.com/mjpost/sacrebleu)
+        - [Lin et al., 2004](https://aclanthology.org/P04-1077/)
+    """
+
+    def __init__(
+        self,
+        tokenizer=None,
+        max_order=4,
+        smooth=False,
+        dtype=None,
+        name="bleu",
+        **kwargs,
+    ):
+        super().__init__(name=name, dtype=dtype, **kwargs)
+
+        if not tf.as_dtype(self.dtype).is_floating:
+            raise ValueError(
+                "`dtype` must be a floating point type. "
+                f"Received: dtype={dtype}"
+            )
+
+        self.tokenizer = tokenizer
+        self.max_order = max_order
+        self.smooth = smooth
+
+        self._matches = self.add_weight(
+            shape=(self.max_order,),
+            name="bleu_matches",
+            initializer="zeros",
+            dtype=self.dtype,
+        )
+        self._possible_matches = self.add_weight(
+            shape=(self.max_order,),
+            name="bleu_possible_matches",
+            initializer="zeros",
+            dtype=self.dtype,
+        )
+        self._translation_length = self.add_weight(
+            name="bleu_translation_length",
+            initializer="zeros",
+            dtype=self.dtype,
+        )
+        self._reference_length = self.add_weight(
+            name="bleu_reference_length",
+            initializer="zeros",
+            dtype=self.dtype,
+        )
+        self._bleu = self.add_weight(
+            name="bleu",
+            initializer="zeros",
+            dtype=self.dtype,
+        )
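As a quick orientation (not part of the patch), here is a minimal usage sketch of the metric as a standalone object; the strings are made up. An exact match should score 1.0 with the default settings, since every n-gram precision is 1 and the brevity penalty is 1:

```python
import tensorflow as tf

from keras_nlp.metrics import Bleu

bleu = Bleu()
y_true = tf.constant([["He eats a sweet apple."]])  # (batch, num_references)
y_pred = tf.constant(["He eats a sweet apple."])    # (batch,)
bleu.update_state(y_true, y_pred)
print(float(bleu.result()))  # 1.0 for an exact match
```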
+
+    def _tokenizer(self, inputs):
+        """
+        Tokenizes the input strings. By default, replicates the behaviour of
+        SacreBLEU's default tokenizer, namely, `tokenizer_13a`.
+        """
+        if self.tokenizer:
+            return self.tokenizer(inputs)
+
+        for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS:
+            inputs = tf.strings.regex_replace(
+                input=inputs,
+                pattern=pattern,
+                rewrite=replacement,
+                replace_global=True,
+                name=None,
+            )
+        inputs = tf.strings.split(inputs)
+        return inputs
+
+    def _get_ngrams(self, segment, max_order):
+        """Extracts all n-grams up to a given maximum order from a segment.
+
+        Uses Python ops. Adapted from
+        https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.
+
+        Args:
+            segment: list. Text segment from which n-grams will be
+                extracted.
+            max_order: int. Maximum length in tokens of the n-grams returned
+                by this method.
+        """
+        ngram_counts = collections.Counter()
+        for order in range(1, max_order + 1):
+            for i in range(0, len(segment) - order + 1):
+                ngram = tuple(segment[i : i + order])
+                ngram_counts[ngram] += 1
+        return ngram_counts
+
+    def _corpus_bleu(
+        self,
+        reference_corpus,
+        translation_corpus,
+        matches_by_order,
+        possible_matches_by_order,
+        translation_length,
+        reference_length,
+        max_order=4,
+        smooth=False,
+    ):
+        """Corpus BLEU implementation using Python ops.
+
+        Computes the BLEU score of translated segments against one or more
+        references. Adapted from
+        https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.
+
+        Args:
+            reference_corpus: list of lists of references for each
+                translation. Each reference should be tokenized into a list
+                of tokens.
+            translation_corpus: list of translations to score. Each
+                translation should be tokenized into a list of tokens.
+            matches_by_order: list of floats containing the initial number
+                of matches for each order.
+            possible_matches_by_order: list of floats containing the initial
+                number of possible matches for each order.
+            translation_length: float. Initial number of tokens in all the
+                translations.
+            reference_length: float. Initial number of tokens in all the
+                references.
+            max_order: int. Maximum n-gram order to use when computing
+                BLEU score.
+            smooth: boolean. Whether or not to apply Lin et al. 2004
+                smoothing.
+        """
+        for (references, translation) in zip(
+            reference_corpus, translation_corpus
+        ):
+            reference_length += min(len(r) for r in references)
+            translation_length += len(translation)
+
+            merged_ref_ngram_counts = collections.Counter()
+            for reference in references:
+                merged_ref_ngram_counts |= self._get_ngrams(
+                    reference, max_order
+                )
+            translation_ngram_counts = self._get_ngrams(translation, max_order)
+            overlap = translation_ngram_counts & merged_ref_ngram_counts
+            for ngram in overlap:
+                matches_by_order[len(ngram) - 1] += overlap[ngram]
+            for order in range(1, max_order + 1):
+                possible_matches = len(translation) - order + 1
+                if possible_matches > 0:
+                    possible_matches_by_order[order - 1] += possible_matches
+
+        precisions = [0] * max_order
+        for i in range(0, max_order):
+            if smooth:
+                precisions[i] = (matches_by_order[i] + 1.0) / (
+                    possible_matches_by_order[i] + 1.0
+                )
+            else:
+                if possible_matches_by_order[i] > 0:
+                    precisions[i] = (
+                        float(matches_by_order[i])
+                        / possible_matches_by_order[i]
+                    )
+                else:
+                    precisions[i] = 0.0
+
+        if min(precisions) > 0:
+            p_log_sum = sum(
+                (1.0 / max_order) * math.log(p) for p in precisions
+            )
+            geo_mean = math.exp(p_log_sum)
+        else:
+            geo_mean = 0
+
+        ratio = float(translation_length) / reference_length
+
+        if ratio > 1.0:
+            bp = 1.0
+        else:
+            bp = math.exp(1 - 1.0 / ratio)
+
+        bleu = geo_mean * bp
+
+        return (
+            bleu,
+            matches_by_order,
+            possible_matches_by_order,
+            translation_length,
+            reference_length,
+        )
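To see why the "clipped count" matters, here is a pure-Python sketch of the unigram case using the same `Counter` intersection trick as `_corpus_bleu` above; the sentence pair is the textbook clipping example, not from this PR:

```python
import collections

def get_ngrams(segment, max_order):
    # Mirrors `_get_ngrams` above (pure Python, for illustration).
    counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(len(segment) - order + 1):
            counts[tuple(segment[i : i + order])] += 1
    return counts

reference = "the cat is on the mat".split()
translation = "the the the the the the the".split()

ref_counts = get_ngrams(reference, 1)      # {('the',): 2, ('cat',): 1, ...}
trans_counts = get_ngrams(translation, 1)  # {('the',): 7}
overlap = trans_counts & ref_counts        # clipped to {('the',): 2}
precision = sum(overlap.values()) / sum(trans_counts.values())
print(precision)  # 2/7 ~= 0.286; clipping prevents a perfect unigram score
```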
+
+    def _calculate_bleu_score(self, references, translation):
+        references = tensor_to_list(references)
+        translation = tensor_to_list(translation)
+
+        matches = self._matches.numpy()
+        possible_matches = self._possible_matches.numpy()
+        translation_length = self._translation_length.numpy()
+        reference_length = self._reference_length.numpy()
+
+        (
+            bleu_score,
+            matches,
+            possible_matches,
+            translation_length,
+            reference_length,
+        ) = self._corpus_bleu(
+            reference_corpus=references,
+            translation_corpus=translation,
+            matches_by_order=matches,
+            possible_matches_by_order=possible_matches,
+            translation_length=translation_length,
+            reference_length=reference_length,
+            max_order=self.max_order,
+            smooth=self.smooth,
+        )
+        return (
+            tf.constant(bleu_score, dtype=self.dtype),
+            tf.constant(matches, dtype=self.dtype),
+            tf.constant(possible_matches, dtype=self.dtype),
+            tf.constant(translation_length, dtype=self.dtype),
+            tf.constant(reference_length, dtype=self.dtype),
+        )
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+        def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
+            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+                inputs = tf.convert_to_tensor(inputs)
+
+            if inputs.shape.rank == base_rank:
+                return inputs[tf.newaxis]
+            elif inputs.shape.rank == base_rank + 1:
+                return inputs
+            elif inputs.shape.rank == base_rank + 2:
+                if tf.shape(inputs)[-1] != 1:
+                    raise ValueError(
+                        f"{tensor_name} is of rank {inputs.shape.rank}. The "
+                        f"last dimension must be of size 1."
+                    )
+                return tf.squeeze(inputs, axis=-1)
+            else:
+                raise ValueError(
+                    f"{tensor_name} must be of rank {base_rank}, "
+                    f"{base_rank+1} or {base_rank+2}. Found rank: "
+                    f"{inputs.shape.rank}"
+                )
+
+        y_true = validate_and_fix_rank(y_true, "y_true", 1)
+        y_pred = validate_and_fix_rank(y_pred, "y_pred", 0)
+
+        # Tokenize the inputs.
+        y_true = self._tokenizer(y_true)
+        y_pred = self._tokenizer(y_pred)
+
+        (
+            bleu_score,
+            matches,
+            possible_matches,
+            translation_length,
+            reference_length,
+        ) = tf.py_function(
+            func=self._calculate_bleu_score,
+            inp=[y_true, y_pred],
+            Tout=[self.dtype, self.dtype, self.dtype, self.dtype, self.dtype],
+        )
+
+        self._matches.assign(matches)
+        self._possible_matches.assign(possible_matches)
+        self._translation_length.assign(translation_length)
+        self._reference_length.assign(reference_length)
+        self._bleu.assign(bleu_score)
+
+    def result(self):
+        return self._bleu
+
+    def reset_state(self):
+        self._matches.assign(
+            tf.zeros(shape=(self.max_order,), dtype=self.dtype)
+        )
+        self._possible_matches.assign(
+            tf.zeros(shape=(self.max_order,), dtype=self.dtype)
+        )
+        self._translation_length.assign(0.0)
+        self._reference_length.assign(0.0)
+        self._bleu.assign(0.0)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "tokenizer": self.tokenizer,
+                "max_order": self.max_order,
+                "smooth": self.smooth,
+            }
+        )
+        return config
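A note on shapes, as I read `validate_and_fix_rank` (the strings below are made up): `y_pred` may be rank 0, 1, or 2 with a trailing dimension of 1, and `y_true` one rank higher. The score itself is computed by the Python implementation through `tf.py_function`:

```python
import tensorflow as tf

from keras_nlp.metrics import Bleu

bleu = Bleu()

# Unbatched: scalar hypothesis, 1D references.
bleu.update_state(tf.constant(["ref a", "ref b"]), tf.constant("hypothesis"))

# Batched: 2D references (use a ragged tensor if counts differ per sample),
# 1D hypotheses.
bleu.update_state(
    tf.constant([["ref a"], ["ref b"]]),
    tf.constant(["hyp 1", "hyp 2"]),
)
```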
diff --git a/keras_nlp/metrics/bleu_test.py b/keras_nlp/metrics/bleu_test.py
new file mode 100644
index 0000000000..a2d314c639
--- /dev/null
+++ b/keras_nlp/metrics/bleu_test.py
@@ -0,0 +1,268 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Bleu."""
+
+import tensorflow as tf
+from tensorflow import keras
+
+from keras_nlp.metrics import Bleu
+from keras_nlp.tokenizers import ByteTokenizer
+
+
+class BleuTest(tf.test.TestCase):
+    def test_initialization(self):
+        bleu = Bleu()
+        result = bleu.result()
+
+        self.assertEqual(result, 0.0)
+
+    def test_scalar_input(self):
+        bleu = Bleu(smooth=True)
+        y_true = [
+            "He eats a sweet apple.",
+            "He is eating a tasty apple, isn't he?",
+        ]
+        y_pred = "He He He eats sweet apple which is a fruit."
+
+        bleu_val = bleu(y_true, y_pred)
+        self.assertAlmostEqual(bleu_val.numpy(), 0.212, delta=1e-3)
+
+    def test_1d_list_input(self):
+        bleu = Bleu()
+        y_true = [
+            ["He eats a sweet apple."],
+            ["Silicon Valley is one of my favourite shows!"],
+        ]
+        y_pred = [
+            "He He He eats sweet apple which is a fruit.",
+            "I love Silicon Valley, it's one of my favourite shows.",
+        ]
+
+        bleu_val = bleu(y_true, y_pred)
+        self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)
+
+    def test_2d_list_input(self):
+        bleu = Bleu()
+        y_true = [
+            [["He eats a sweet apple."]],
+            [["Silicon Valley is one of my favourite shows!"]],
+        ]
+        y_pred = [
+            ["He He He eats sweet apple which is a fruit."],
+            ["I love Silicon Valley, it's one of my favourite shows."],
+        ]
+
+        bleu_val = bleu(y_true, y_pred)
+        self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)
+
+    def test_1d_tensor_input(self):
+        bleu = Bleu()
+        y_true = tf.ragged.constant(
+            [
+                ["He eats a sweet apple."],
+                ["Silicon Valley is one of my favourite shows!"],
+            ]
+        )
+        y_pred = tf.constant(
+            [
+                "He He He eats sweet apple which is a fruit.",
+                "I love Silicon Valley, it's one of my favourite shows.",
+            ]
+        )
+
+        bleu_val = bleu(y_true, y_pred)
+        self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)
+
+    def test_2d_tensor_input(self):
+        bleu = Bleu()
+        y_true = tf.constant(
+            [
+                [["He eats a sweet apple."]],
+                [["Silicon Valley is one of my favourite shows!"]],
+            ]
+        )
+        y_pred = tf.constant(
+            [
+                ["He He He eats sweet apple which is a fruit."],
+                ["I love Silicon Valley, it's one of my favourite shows."],
+            ]
+        )
+
+        bleu_val = bleu(y_true, y_pred)
+        self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)
+
+    def test_custom_tokenizer(self):
+        byte_tokenizer = ByteTokenizer()
+        bleu = Bleu(tokenizer=byte_tokenizer)
+        y_true = tf.ragged.constant(
+            [
+                ["He eats a sweet apple."],
+                ["Silicon Valley is one of my favourite shows!"],
+            ]
+        )
+        y_pred = tf.constant(
+            [
+                "He He He eats sweet apple which is a fruit.",
+                "I love Silicon Valley, it's one of my favourite shows.",
+            ]
+        )
+
+        bleu_val = bleu(y_true, y_pred)
+        self.assertAlmostEqual(bleu_val.numpy(), 0.609, delta=1e-3)
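For reference, a custom tokenizer only needs to be a callable mapping a string tensor to a (possibly ragged) token tensor. A hypothetical whitespace-only tokenizer would look like this:

```python
import tensorflow as tf

from keras_nlp.metrics import Bleu

def whitespace_tokenizer(inputs):
    # Split on whitespace only; no punctuation handling.
    return tf.strings.split(inputs)

bleu = Bleu(tokenizer=whitespace_tokenizer)
```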
+
+    def test_different_order(self):
+        bleu = Bleu(max_order=5)
+        y_true = tf.ragged.constant(
+            [
+                ["He eats a sweet apple."],
+                ["Silicon Valley is one of my favourite shows!"],
+            ]
+        )
+        y_pred = tf.constant(
+            [
+                "He He He eats sweet apple which is a fruit.",
+                "I love Silicon Valley, it's one of my favourite shows.",
+            ]
+        )
+
+        bleu_val = bleu(y_true, y_pred)
+        self.assertAlmostEqual(bleu_val.numpy(), 0.188, delta=1e-3)
+
+    def test_model_compile(self):
+        inputs = keras.Input(shape=(), dtype="string")
+        outputs = tf.identity(inputs)
+        model = keras.Model(inputs, outputs)
+
+        model.compile(metrics=[Bleu()])
+
+        x = tf.constant(
+            [
+                "He He He eats sweet apple which is a fruit.",
+                "I love Silicon Valley, it's one of my favourite shows.",
+            ]
+        )
+        y = tf.constant(
+            [
+                ["He eats a sweet apple."],
+                ["Silicon Valley is one of my favourite shows!"],
+            ]
+        )
+
+        output = model.evaluate(x, y, return_dict=True)
+        self.assertAlmostEqual(output["bleu"], 0.243, delta=1e-3)
+
+    def test_reset_state(self):
+        bleu = Bleu()
+        y_true = tf.ragged.constant(
+            [
+                ["He eats a sweet apple."],
+                ["Silicon Valley is one of my favourite shows!"],
+            ]
+        )
+        y_pred = tf.constant(
+            [
+                "He He He eats sweet apple which is a fruit.",
+                "I love Silicon Valley, it's one of my favourite shows.",
+            ]
+        )
+
+        bleu.update_state(y_true, y_pred)
+        bleu_val = bleu.result()
+        self.assertNotEqual(bleu_val.numpy(), 0.0)
+
+        bleu.reset_state()
+        bleu_val = bleu.result()
+        self.assertEqual(bleu_val, 0.0)
+
+    def test_update_state(self):
+        bleu = Bleu()
+        y_true_1 = tf.ragged.constant(
+            [
+                ["He eats a sweet apple."],
+                ["Silicon Valley is one of my favourite shows!"],
+            ]
+        )
+        y_pred_1 = tf.constant(
+            [
+                "He He He eats sweet apple which is a fruit.",
+                "I love Silicon Valley, it's one of my favourite shows.",
+            ]
+        )
+
+        bleu.update_state(y_true_1, y_pred_1)
+        bleu_val = bleu.result()
+        self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)
+
+        y_true_2 = tf.constant(["Virat Kohli is the GOAT."])
+        y_pred_2 = tf.constant("Virat Kohli is the greatest of all time!")
+
+        bleu.update_state(y_true_2, y_pred_2)
+        bleu_val = bleu.result()
+        self.assertAlmostEqual(bleu_val.numpy(), 0.26, delta=1e-3)
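Note that `update_state` accumulates corpus-level counts (matches, possible matches, lengths), so the running result is a corpus BLEU over everything seen so far, not an average of per-batch scores. A minimal sketch with made-up strings:

```python
import tensorflow as tf

from keras_nlp.metrics import Bleu

bleu = Bleu()
bleu.update_state(
    tf.constant([["the cat sat on the mat."]]),
    tf.constant(["the cat sat on the mat."]),
)
score_1 = float(bleu.result())  # corpus BLEU over the first batch only
bleu.update_state(
    tf.constant([["a completely different reference."]]),
    tf.constant(["an unrelated hypothesis."]),
)
score_2 = float(bleu.result())  # corpus BLEU over both batches combined;
                                # generally not the mean of per-batch scores
```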
+
+    def test_merge_state_normalize(self):
+        bleu_1 = Bleu(smooth=True)
+        bleu_2 = Bleu(smooth=True)
+
+        y_true_1 = tf.ragged.constant(
+            [
+                ["He eats a sweet apple."],
+                ["Silicon Valley is one of my favourite shows!"],
+            ]
+        )
+        y_pred_1 = tf.constant(
+            [
+                "He He He eats sweet apple which is a fruit.",
+                "I love Silicon Valley, it's one of my favourite shows.",
+            ]
+        )
+
+        y_true_2 = tf.constant(["Virat Kohli is the GOAT."])
+        y_pred_2 = tf.constant("Virat Kohli is the greatest of all time!")
+
+        y_true_3 = tf.constant([["Watching Test cricket is so much fun."]])
+        y_pred_3 = tf.constant(["Test is the best format in cricket."])
+
+        bleu_1.update_state(y_true_1, y_pred_1)
+        bleu_1.update_state(y_true_2, y_pred_2)
+        bleu_val = bleu_1.result()
+        self.assertAlmostEqual(bleu_val.numpy(), 0.293, delta=1e-3)
+
+        bleu_2.update_state(y_true_3, y_pred_3)
+        bleu_val = bleu_2.result()
+        self.assertAlmostEqual(bleu_val.numpy(), 0.202, delta=1e-3)
+
+        merged_bleu = Bleu(smooth=True)
+        merged_bleu.merge_state([bleu_1, bleu_2])
+        bleu_val = merged_bleu.result()
+        self.assertAlmostEqual(bleu_val.numpy(), 0.495, delta=1e-3)
+
+    def test_get_config(self):
+        byte_tokenizer = ByteTokenizer()
+        bleu = Bleu(
+            tokenizer=byte_tokenizer,
+            max_order=8,
+            smooth=True,
+            dtype=tf.float64,
+            name="bleu_test",
+        )
+
+        config = bleu.get_config()
+        expected_config_subset = {
+            "tokenizer": byte_tokenizer,
+            "max_order": 8,
+            "smooth": True,
+        }
+        self.assertEqual(config, {**config, **expected_config_subset})
diff --git a/keras_nlp/metrics/rouge_base.py b/keras_nlp/metrics/rouge_base.py
index bd9f6aa674..51e4648654 100644
--- a/keras_nlp/metrics/rouge_base.py
+++ b/keras_nlp/metrics/rouge_base.py
@@ -48,6 +48,9 @@ class RougeBase(keras.metrics.Metric):
         not specified, it defaults to tf.float32.
         name: string. Name of the metric instance.
         **kwargs: Other keyword arguments.
+
+    References:
+        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)
     """
 
     def __init__(
diff --git a/keras_nlp/metrics/rouge_l.py b/keras_nlp/metrics/rouge_l.py
index a4c9a25855..a9d67633b3 100644
--- a/keras_nlp/metrics/rouge_l.py
+++ b/keras_nlp/metrics/rouge_l.py
@@ -38,6 +38,9 @@ class RougeL(RougeBase):
         name: string. Name of the metric instance.
         **kwargs: Other keyword arguments.
 
+    References:
+        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)
+
     Examples:
 
     1. Various Input Types.
diff --git a/keras_nlp/metrics/rouge_n.py b/keras_nlp/metrics/rouge_n.py
index b7a1522a92..68a9884dd1 100644
--- a/keras_nlp/metrics/rouge_n.py
+++ b/keras_nlp/metrics/rouge_n.py
@@ -40,6 +40,9 @@ class RougeN(RougeBase):
         name: string. Name of the metric instance.
         **kwargs: Other keyword arguments.
 
+    References:
+        - [Lin et al., 2004](https://aclanthology.org/W04-1013/)
+
     Examples:
 
     1. Various Input Types.
diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py
index 26fc815f11..bcab0f24d8 100644
--- a/keras_nlp/utils/tensor_utils.py
+++ b/keras_nlp/utils/tensor_utils.py
@@ -25,16 +25,11 @@ def _decode_strings_to_utf8(inputs):
         return [_decode_strings_to_utf8(x) for x in inputs]
 
 
-def tensor_to_string_list(inputs):
-    """Detokenize and convert tensor to nested lists of python strings.
-
-    This is a convenience method which converts each byte string to a python
-    string.
+def tensor_to_list(inputs):
+    """Converts a tensor to nested lists.
 
     Args:
         inputs: Input tensor, or dict/list/tuple of input tensors.
-        *args: Additional positional arguments.
-        **kwargs: Additional keyword arguments.
     """
     if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
         inputs = tf.convert_to_tensor(inputs)
@@ -44,4 +39,17 @@ def tensor_to_string_list(inputs):
         list_outputs = inputs.numpy()
     if inputs.shape.rank != 0:
         list_outputs = list_outputs.tolist()
+    return list_outputs
+
+
+def tensor_to_string_list(inputs):
+    """Detokenize and convert tensor to nested lists of python strings.
+
+    This is a convenience method which converts each byte string to a python
+    string.
+
+    Args:
+        inputs: Input tensor, or dict/list/tuple of input tensors.
+    """
+    list_outputs = tensor_to_list(inputs)
     return _decode_strings_to_utf8(list_outputs)
diff --git a/keras_nlp/utils/tensor_utils_test.py b/keras_nlp/utils/tensor_utils_test.py
index d9941f750f..bdb9d728c4 100644
--- a/keras_nlp/utils/tensor_utils_test.py
+++ b/keras_nlp/utils/tensor_utils_test.py
@@ -14,9 +14,27 @@
 
 import tensorflow as tf
 
+from keras_nlp.utils.tensor_utils import tensor_to_list
 from keras_nlp.utils.tensor_utils import tensor_to_string_list
 
 
+class TensorToListTest(tf.test.TestCase):
+    def test_ragged_input(self):
+        input_data = tf.ragged.constant([[1, 2], [4, 5, 6]])
+        list_output = tensor_to_list(input_data)
+        self.assertAllEqual(list_output, [[1, 2], [4, 5, 6]])
+
+    def test_dense_input(self):
+        input_data = tf.constant([[1, 2], [3, 4]])
+        list_output = tensor_to_list(input_data)
+        self.assertAllEqual(list_output, [[1, 2], [3, 4]])
+
+    def test_scalar_input(self):
+        input_data = tf.constant(1)
+        list_output = tensor_to_list(input_data)
+        self.assertEqual(list_output, 1)
+
+
 class TensorToStringListTest(tf.test.TestCase):
     def test_detokenize_to_strings_for_ragged(self):
         input_data = tf.ragged.constant([["▀▁▂▃", "samurai"]])