From 7654761f90342a3354e80a03930d5a4fa48b6787 Mon Sep 17 00:00:00 2001
From: Aflah <72096386+aflah02@users.noreply.github.com>
Date: Fri, 8 Apr 2022 01:50:25 +0530
Subject: [PATCH] Fixing rank 1 outputs for WordPieceTokenizer (#92)

* Fixed Rank Issue

* Testing

* Testing

* Fixed Test

* Fixed Typo

* Fixing Typo

* debug

* Rank0 Set

* Removed Debug Statements

* Added Docstring

* Added Unit Test and Minor Changes in Doc String

* Ran format.sh and lint.sh

* Post Review Changes
---
 keras_nlp/tokenizers/word_piece_tokenizer.py      | 11 +++++++++++
 keras_nlp/tokenizers/word_piece_tokenizer_test.py | 12 ++++++++++++
 2 files changed, 23 insertions(+)
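
Note (not part of the patch): a minimal usage sketch of the behavior this change enables, using a toy vocabulary that mirrors the new unit test below. The import path and vocabulary are taken from the diff; the rest is illustrative only.

import tensorflow as tf

from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer

# Toy vocabulary; index 0 is the "[UNK]" token.
tokenizer = WordPieceTokenizer(vocabulary=["[UNK]", "a", "b", "c", "d", "e", "f"])

# A scalar (rank 0) string input now returns a rank 1 token tensor
# instead of a batch-of-one, rank 2 output.
tokens = tokenizer.tokenize("a b c")
print(tokens.shape.rank)  # 1

# Because per-example outputs are rank 1, the tokenizer composes with
# tf.data: map it over a dataset and batch the ragged results afterwards.
ds = tf.data.Dataset.from_tensor_slices(["a b c", "d e"])
ds = ds.map(tokenizer)
ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=2))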

diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py
index fab193174a..f643925a16 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -256,6 +256,12 @@ def get_config(self) -> Dict[str, Any]:
         return config
 
     def tokenize(self, inputs):
+        # Check whether the input is a scalar (rank 0) tensor.
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
         # Optionally normalize and split inputs.
         if self._lowercase:
             inputs = tf_text.case_fold_utf8(inputs)
@@ -282,6 +288,11 @@ def tokenize(self, inputs):
             output_shape = tokens.shape.as_list()
             output_shape[-1] = self._sequence_length
             tokens = tokens.to_tensor(shape=output_shape)
+        # If the input was a scalar, squeeze back to a rank 1 (dense) output.
+        if scalar_input:
+            tokens = tf.squeeze(tokens, 0)
+            tokens = tf.ensure_shape(tokens, shape=[self._sequence_length])
+
         return tokens
 
     def detokenize(self, inputs):
diff --git a/keras_nlp/tokenizers/word_piece_tokenizer_test.py b/keras_nlp/tokenizers/word_piece_tokenizer_test.py
index ab7fb1ea60..9a54f857d5 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer_test.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer_test.py
@@ -163,6 +163,18 @@ def test_functional_model(self):
         model_output = model(input_data)
         self.assertAllEqual(model_output, ["the quick brown fox"])
 
+    def test_batching_ragged_tensors(self):
+        tokenizer = WordPieceTokenizer(
+            vocabulary=["[UNK]", "a", "b", "c", "d", "e", "f"]
+        )
+        dataset = tf.data.Dataset.from_tensor_slices(["a b c", "d e", "a f e"])
+        dataset = dataset.map(tokenizer)
+        dataset = dataset.apply(
+            tf.data.experimental.dense_to_ragged_batch(batch_size=1)
+        )
+        element = dataset.take(1).get_single_element().numpy()
+        self.assertAllEqual(element, [[1, 2, 3]])
+
     def test_from_file(self):
         vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
         input_data = ["the quick brown fox."]