Skip to content

Commit

Permalink
Support String Output for BytePairTokenizer (keras-team#438)
Browse files Browse the repository at this point in the history
* Support string output for BytePairTokenizer

* Add unit test

* Minor edit
  • Loading branch information
abheesht17 authored and mattdangerw committed Nov 10, 2022
1 parent 2ebf24f commit 8ecb002
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
37 changes: 19 additions & 18 deletions keras_nlp/tokenizers/byte_pair_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" Byte-pair encoder implementation.
"""Byte-pair encoder implementation.
This file implements the same logic as openai BPE:
https://github.com/openai/gpt-2/blob/master/src/encoder.py,
Expand Down Expand Up @@ -159,12 +159,12 @@ def create_static_hashtable(keys, values, default):
class BytePairTokenizer(tokenizer.Tokenizer):
"""Byte-pair encoding tokenizer layer.
This BPE tokenizer provides the same funtionality as official GPT2
This BPE tokenizer provides the same functionality as the official GPT-2
tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
which describes BPE merge rules, it should provide the same output
as openai implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
Different from openai, this implementation is graph-compatible, so you can
use it within a tf.data pipeline.
as the OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
Unlike the OpenAI implementation, this one is graph-compatible, so you can
use it within a `tf.data` pipeline.
If input is a batch of strings (rank > 0):
By default, the layer will output a `tf.RaggedTensor` where the last
Expand All @@ -187,7 +187,7 @@ class BytePairTokenizer(tokenizer.Tokenizer):
Examples:
Use in-momery vocabulary and merge list.
Use in-memory vocabulary and merge list.
>>> vocab = {"butter": 1, "fly": 2}
>>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
Expand Down Expand Up @@ -244,7 +244,7 @@ def __init__(
kwargs["dtype"] = tf.int32
else:
dtype = tf.dtypes.as_dtype(kwargs["dtype"])
if not dtype.is_integer:
if not dtype.is_integer and dtype != tf.string:
raise ValueError(
"Output dtype must be an integer type or a string. "
f"Received: `dtype={dtype}`"
Expand Down Expand Up @@ -484,28 +484,29 @@ def process_unseen_tokens():
lambda: cache_lookup,
)

# Encode merged tokens.
tokenized_words = tf.strings.split(tokenized_words, sep=" ")
encoding = self.token_to_id_map.lookup(tokenized_words)
tokens = tf.strings.split(tokenized_words, sep=" ")
if self.compute_dtype != tf.string:
# Encode merged tokens.
tokens = self.token_to_id_map.lookup(tokens)

# Unflatten to match input.
encoding = tf.RaggedTensor.from_row_splits(
encoding.flat_values,
tf.gather(encoding.row_splits, token_row_splits),
tokens = tf.RaggedTensor.from_row_splits(
tokens.flat_values,
tf.gather(tokens.row_splits, token_row_splits),
)

# Convert to a dense output if `sequence_length` is set.
if self.sequence_length:
output_shape = encoding.shape.as_list()
output_shape = tokens.shape.as_list()
output_shape[-1] = self.sequence_length
encoding = encoding.to_tensor(shape=output_shape)
tokens = tokens.to_tensor(shape=output_shape)

# Squeeze out the leading batch dimension if the input is a scalar.
if scalar_input:
encoding = tf.squeeze(encoding, 0)
tf.ensure_shape(encoding, shape=[self.sequence_length])
tokens = tf.squeeze(tokens, 0)
tf.ensure_shape(tokens, shape=[self.sequence_length])

return encoding
return tokens

def detokenize(self, inputs):
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
Expand Down
14 changes: 14 additions & 0 deletions keras_nlp/tokenizers/byte_pair_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def test_tokenize_list_input(self):
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, expected)

def test_tokenize_string_output(self):
    """Check that a tokenizer built with `dtype=tf.string` returns string tokens."""
    string_tokenizer = BytePairTokenizer(
        vocabulary=VOCAB_PATH, merges=MERGE_PATH, dtype=tf.string
    )
    sentences = ["quick brown fox.", "slow black bear."]
    # One ragged row of token strings is expected per input sentence.
    expected_tokens = tf.ragged.constant(
        [
            ["quick", "Ġbrown", "Ġfox", "."],
            ["slow", "Ġblack", "Ġbear", "."],
        ]
    )
    self.assertAllEqual(string_tokenizer(sentences), expected_tokens)

def test_tokenize_scalar_input(self):
input_data = "brown."
encoded = self.tokenizer.tokenize(input_data)
Expand Down

0 comments on commit 8ecb002

Please sign in to comment.