From 7654761f90342a3354e80a03930d5a4fa48b6787 Mon Sep 17 00:00:00 2001 From: Aflah <72096386+aflah02@users.noreply.github.com> Date: Fri, 8 Apr 2022 01:50:25 +0530 Subject: [PATCH] Fixing rank 1 outputs for WordPieceTokenizer (#92) * Fixed Rank Issue * Testing * Testing * Fixed Test * Fixed Typo * Fixing Typo * debug * Rank0 Set * Removed Debug Statements * Added Docstring * Added Unit Test and Minor Changes in Doc String * Ran format.sh and lint.sh * Post Review Changes --- keras_nlp/tokenizers/word_piece_tokenizer.py | 11 +++++++++++ keras_nlp/tokenizers/word_piece_tokenizer_test.py | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index fab193174a..f643925a16 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -256,6 +256,12 @@ def get_config(self) -> Dict[str, Any]: return config def tokenize(self, inputs): + # Check if Input is Scalar or Not + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + scalar_input = tf.convert_to_tensor(inputs).shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) # Optionally normalize and split inputs. 
if self._lowercase: inputs = tf_text.case_fold_utf8(inputs) @@ -282,6 +288,11 @@ def tokenize(self, inputs): output_shape = tokens.shape.as_list() output_shape[-1] = self._sequence_length tokens = tokens.to_tensor(shape=output_shape) + # Convert to a dense output if input is scalar + if scalar_input: + tokens = tf.squeeze(tokens, 0) + tf.ensure_shape(tokens, shape=[self._sequence_length]) + return tokens def detokenize(self, inputs): diff --git a/keras_nlp/tokenizers/word_piece_tokenizer_test.py b/keras_nlp/tokenizers/word_piece_tokenizer_test.py index ab7fb1ea60..9a54f857d5 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer_test.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer_test.py @@ -163,6 +163,18 @@ def test_functional_model(self): model_output = model(input_data) self.assertAllEqual(model_output, ["the quick brown fox"]) + def test_batching_ragged_tensors(self): + tokenizer = WordPieceTokenizer( + vocabulary=["[UNK]", "a", "b", "c", "d", "e", "f"] + ) + dataset = tf.data.Dataset.from_tensor_slices(["a b c", "d e", "a f e"]) + dataset = dataset.map(tokenizer) + dataset = dataset.apply( + tf.data.experimental.dense_to_ragged_batch(batch_size=1) + ) + element = dataset.take(1).get_single_element().numpy() + self.assertAllEqual(element, [[1, 2, 3]]) + def test_from_file(self): vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt") input_data = ["the quick brown fox."]