Fix lowercase bug in wordpiece tokenizer (#1543)
* Fix lowercase bug

* Add a comment to explain

* Change mask builder

* Revert "Change mask builder"

This reverts commit 5c9f61e.
abuelnasr0 authored Apr 4, 2024
1 parent 4b6970c commit 825b192
Showing 2 changed files with 27 additions and 2 deletions.
14 changes: 12 additions & 2 deletions keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -151,8 +151,6 @@ def pretokenize(
         text = tf.expand_dims(text, 0)
     if split_on_cjk and split:
         text = tf.strings.regex_replace(text, CJK_REGEX, r" \0 ")
-    if lowercase:
-        text = tf_text.case_fold_utf8(text)
     if strip_accents:
         # Normalize unicode to NFD, which splits out accent mark characters.
         text = tf_text.normalize_utf8(text, "NFD")
@@ -187,6 +185,18 @@ def pretokenize(
             delim_regex_pattern=split_pattern,
             keep_delim_regex_pattern=keep_split_pattern,
         )
+    if lowercase:
+        if special_tokens_pattern is not None:
+            # Do not lowercase special tokens in string space. They often
+            # contain capital letters, e.g. `"[CLS]"`.
+            mask = (
+                tf.strings.regex_replace(text, special_tokens_pattern, "६")
+                == "६"
+            )
+            text = tf.where(mask, text, tf_text.case_fold_utf8(text))
+        else:
+            text = tf_text.case_fold_utf8(text)
+
     return text
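
The core of the fix is the masking trick above: any string that is exactly a special token is left untouched, and everything else is case-folded. Below is a minimal standalone sketch of the same idea, assuming `tensorflow` and `tensorflow_text` are installed; the token list and regex pattern are illustrative, not the ones the tokenizer builds internally.

    import tensorflow as tf
    import tensorflow_text as tf_text

    # Illustrative pattern covering the special tokens used in the test below.
    special_tokens_pattern = r"\[UNK\]|\[MASK\]|\[SEP\]|\[PAD\]|\[CLS\]"
    tokens = tf.constant(["[CLS]", "THE", "QUICK", "[SEP]"])

    # Replace any special token with a sentinel character ("६") that should not
    # occur in ordinary text; entries equal to the sentinel were special tokens,
    # so they keep their original casing while everything else is case-folded.
    mask = tf.strings.regex_replace(tokens, special_tokens_pattern, "६") == "६"
    lowered = tf.where(mask, tokens, tf_text.case_fold_utf8(tokens))
    print(lowered)  # -> [CLS], the, quick, [SEP] (special tokens keep their casing)
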
15 changes: 15 additions & 0 deletions keras_nlp/tokenizers/word_piece_tokenizer_test.py
@@ -110,6 +110,21 @@ def test_special_tokens_int_dtype(self):
         output = tokenizer(input_data)
         self.assertAllEqual(output, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])
 
+    def test_special_tokens_with_lowercase(self):
+        input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] THE QUICK BROWN FOX."]
+        special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
+        vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."]
+        vocab_data = [*special_tokens, *vocab_data]
+
+        tokenizer = WordPieceTokenizer(
+            vocabulary=vocab_data,
+            lowercase=True,
+            special_tokens=special_tokens,
+            special_tokens_in_strings=True,
+        )
+        output = tokenizer(input_data)
+        self.assertAllEqual(output, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])
+
     def test_cjk_tokens(self):
         input_data = ["ah半推zz"]
         vocab_data = ["[UNK]", "推", "敐", "乐", "半", "偷", "匕", "ah", "zz"]
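
The expected ids in the new test follow from the vocabulary order: each token's id is its position in the list, so the five special tokens map to 0-4 and the lowercased word pieces of "THE QUICK BROWN FOX." ("the", "qu", "##ick", "br", "##own", "fox", ".") map to 5-11. A quick sanity check, assuming ids are simply list positions as in the test's vocabulary:

    special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
    vocab = [*special_tokens, "the", "qu", "##ick", "br", "##own", "fox", "."]
    # Map each token to the id it would receive.
    print({token: i for i, token in enumerate(vocab)})
    # {'[UNK]': 0, ..., '[CLS]': 4, 'the': 5, ..., '.': 11}
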
