
Handle [MASK] token in DebertaV3Tokenizer #759

Merged: 15 commits, merged on Feb 24, 2023
4 changes: 3 additions & 1 deletion keras_nlp/models/deberta_v3/deberta_v3_classifier_test.py
@@ -52,7 +52,9 @@ def setUp(self):
unk_piece="[UNK]",
)
self.preprocessor = DebertaV3Preprocessor(
tokenizer=DebertaV3Tokenizer(proto=bytes_io.getvalue()),
tokenizer=DebertaV3Tokenizer(
proto=bytes_io.getvalue(), mask_token_id=10
),
sequence_length=12,
)
self.backbone = DebertaV3Backbone(
@@ -51,7 +51,7 @@ def setUp(self):
self.proto = bytes_io.getvalue()

self.preprocessor = DebertaV3Preprocessor(
tokenizer=DebertaV3Tokenizer(proto=self.proto),
tokenizer=DebertaV3Tokenizer(proto=self.proto, mask_token_id=10),
sequence_length=12,
)

20 changes: 15 additions & 5 deletions keras_nlp/models/deberta_v3/deberta_v3_presets.py
@@ -25,7 +25,9 @@
"max_sequence_length": 512,
"bucket_size": 256,
},
"preprocessor_config": {},
"preprocessor_config": {
"mask_token_id": 128000,
Member: I think we need to remove this from all the presets now, right? It is breaking tests.

},
"metadata": {
"description": (
"12-layer DeBERTaV3 model where case is maintained. "
@@ -51,7 +53,9 @@
"max_sequence_length": 512,
"bucket_size": 256,
},
"preprocessor_config": {},
"preprocessor_config": {
"mask_token_id": 128000,
},
"metadata": {
"description": (
"6-layer DeBERTaV3 model where case is maintained. "
@@ -77,7 +81,9 @@
"max_sequence_length": 512,
"bucket_size": 256,
},
"preprocessor_config": {},
"preprocessor_config": {
"mask_token_id": 128000,
},
"metadata": {
"description": (
"12-layer DeBERTaV3 model where case is maintained. "
@@ -103,7 +109,9 @@
"max_sequence_length": 512,
"bucket_size": 256,
},
"preprocessor_config": {},
"preprocessor_config": {
"mask_token_id": 128000,
},
"metadata": {
"description": (
"24-layer DeBERTaV3 model where case is maintained. "
@@ -129,7 +137,9 @@
"max_sequence_length": 512,
"bucket_size": 256,
},
"preprocessor_config": {},
"preprocessor_config": {
"mask_token_id": 250101,
},
"metadata": {
"description": (
"12-layer DeBERTaV3 model where case is maintained. "
8 changes: 8 additions & 0 deletions keras_nlp/models/deberta_v3/deberta_v3_presets_test.py
@@ -53,6 +53,14 @@ def test_preprocessor_output(self):
expected_outputs = [1, 279, 1538, 2]
self.assertAllEqual(outputs, expected_outputs)

def test_preprocessor_mask_token(self):
preprocessor = DebertaV3Preprocessor.from_preset(
"deberta_v3_extra_small_en",
sequence_length=4,
)
self.assertEqual(preprocessor.tokenizer.id_to_token(128000), "[MASK]")
self.assertEqual(preprocessor.tokenizer.token_to_id("[MASK]"), 128000)

@parameterized.named_parameters(
("preset_weights", True), ("random_weights", False)
)
68 changes: 65 additions & 3 deletions keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py
@@ -48,6 +48,11 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
`bytes` object with a serialized SentencePiece proto. See the
[SentencePiece repository](https://github.com/google/sentencepiece)
for more details on the format.
mask_token_id: The token ID (int) of the mask token (`[MASK]`). If
`None`, the SentencePiece vocabulary is expected to have the mask
token. Preset DeBERTa vocabularies do not have the mask token in the
provided vocabulary files, which is why this workaround is
necessary.

Member: I think that, given most users will not need an MLM task, we should actually make this optional when "bringing your own data." Something like...

  • Use one of our presets. self.mask_token_id is set and works as expected.
  • Pass your own local copy of a DeBERTa spm file and don't set anything. self.mask_token_id is None, and everything works except DebertaMaskedLM, which throws a friendly error message. We can cover the error in #732 ("Solve #721 Deberta masklm model").
  • Optional (but already working here). Use your own custom spm file with a "[MASK]" token. self.mask_token_id is set and works as expected.

Does that make sense to you?

Author (@abheesht17, Feb 23, 2023): (1) and (3) were already taken care of. I've pushed changes which solve all three cases and resolve the other comment.

Member: Thanks! Looks good. Left some thoughts below on how we could maybe make the subclass changes a bit easier by modifying the super class.
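Below is a minimal usage sketch of the three paths the reviewer describes, written against the API in this PR (the mask_token_id constructor argument and the presets above). The file paths are hypothetical, and the second case shows the reviewer's proposed behavior rather than a guarantee about the merged code:

```python
from keras_nlp.models import DebertaV3Tokenizer

# Case 1: a preset vocabulary. mask_token_id comes from the preset's
# preprocessor_config, so "[MASK]" resolves even though the underlying
# .spm file has no mask piece.
tokenizer = DebertaV3Tokenizer.from_preset("deberta_v3_extra_small_en")
print(tokenizer.token_to_id("[MASK]"))  # 128000, per the preset config above

# Case 2: your own copy of a DeBERTa .spm file, no mask_token_id passed.
# Under the reviewer's proposal, mask_token_id stays None and only the
# masked-LM task would raise a friendly error.
tokenizer = DebertaV3Tokenizer(proto="path/to/deberta_v3.spm")  # hypothetical path

# Case 3: a custom .spm file that already contains a "[MASK]" piece.
# The id is looked up directly from the vocabulary.
tokenizer = DebertaV3Tokenizer(proto="path/to/custom_with_mask.spm")  # hypothetical path
print(tokenizer.mask_token_id)
```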

Examples:

@@ -65,15 +70,29 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
```
"""

def __init__(self, proto, **kwargs):
def __init__(self, proto, mask_token_id=None, **kwargs):
super().__init__(proto=proto, **kwargs)

# Maintain a private copy of `mask_token_id` for config purposes.
self._mask_token_id = mask_token_id

# Maintain a private copy of the original vocabulary; the parent class's
# `get_vocabulary()` function calls `self.vocabulary_size()`, which
# throws up a segmentation fault.
Member: What is the segmentation fault here? I'm not sure I totally follow. Ideally we don't have to store a copy of the vocabulary. This would be a not-totally-insignificant waste of memory!

Author (@abheesht17): Made edits!

Calling super().get_vocabulary() in __init__ causes a seg fault because SentencePieceTokenizer calls self.vocabulary_size() here: https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/tokenizers/sentence_piece_tokenizer.py#L161-L165. Since we change the vocabulary_size() function in DebertaV3Tokenizer to return a value greater than the SPM vocabulary size, this causes a seg fault.
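A minimal pure-Python sketch (no SentencePiece involved; class names are illustrative) of the dispatch problem described above: the base method sizes its lookup with self.vocabulary_size(), so a subclass that reports a larger logical vocabulary makes the base method index past the real one.

```python
class BaseTokenizer:
    def __init__(self):
        self._vocab = ["[PAD]", "a", "b"]

    def vocabulary_size(self):
        return len(self._vocab)

    def get_vocabulary(self):
        # Dispatches to the subclass override of vocabulary_size().
        return [self._vocab[i] for i in range(self.vocabulary_size())]


class MaskAwareTokenizer(BaseTokenizer):
    def vocabulary_size(self):
        # Reports one extra slot for an out-of-vocabulary "[MASK]" id.
        return super().vocabulary_size() + 1


MaskAwareTokenizer().get_vocabulary()
# IndexError here in pure Python; with the C++ SentencePiece op the
# out-of-range lookup surfaces as a segmentation fault instead.
```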

self._original_vocabulary = super().get_vocabulary()

# Check for necessary special tokens.
cls_token = "[CLS]"
sep_token = "[SEP]"
pad_token = "[PAD]"
for token in [cls_token, pad_token, sep_token]:
if token not in self.get_vocabulary():
mask_token = "[MASK]"

in_vocab_special_tokens = [cls_token, pad_token, sep_token]
if mask_token_id is None:
in_vocab_special_tokens = in_vocab_special_tokens + [mask_token]

for token in in_vocab_special_tokens:
if token not in self._original_vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
Expand All @@ -83,6 +102,49 @@ def __init__(self, proto, **kwargs):
self.cls_token_id = self.token_to_id(cls_token)
self.sep_token_id = self.token_to_id(sep_token)
self.pad_token_id = self.token_to_id(pad_token)
self.mask_token_id = mask_token_id
if mask_token_id is None:
self.mask_token_id = self.token_to_id(mask_token)

def vocabulary_size(self):
vocabulary_size = super().vocabulary_size()

# This is to avoid an error when `super().get_vocabulary()` is called
# in `__init__()`.
if not hasattr(self, "mask_token_id"):
return vocabulary_size

if self.mask_token_id >= vocabulary_size:
return self.mask_token_id + 1
return vocabulary_size

def get_vocabulary(self):
vocabulary = self._original_vocabulary
if self.mask_token_id >= len(vocabulary):
vocabulary = vocabulary + [None] * (
self.mask_token_id - len(vocabulary) + 1
)
vocabulary[self.mask_token_id] = "[MASK]"
return vocabulary

def id_to_token(self, id):
if id == self.mask_token_id:
return "[MASK]"
return super().id_to_token(id)

def token_to_id(self, token):
if token == "[MASK]":
return self.mask_token_id
return int(self._sentence_piece.string_to_id(token).numpy())

def get_config(self):
config = super().get_config()
config.update(
{
"mask_token_id": self._mask_token_id,
}
)
return config

@classproperty
def presets(cls):
13 changes: 11 additions & 2 deletions keras_nlp/models/deberta_v3/deberta_v3_tokenizer_test.py
@@ -47,7 +47,7 @@ def setUp(self):
)
self.proto = bytes_io.getvalue()

self.tokenizer = DebertaV3Tokenizer(proto=self.proto)
self.tokenizer = DebertaV3Tokenizer(proto=self.proto, mask_token_id=10)

def test_tokenize(self):
input_data = "the quick brown fox"
@@ -65,7 +65,16 @@ def test_detokenize(self):
self.assertEqual(output, tf.constant(["the quick brown fox"]))

def test_vocabulary_size(self):
self.assertEqual(self.tokenizer.vocabulary_size(), 10)
self.assertEqual(self.tokenizer.vocabulary_size(), 11)

def test_get_vocabulary(self):
self.assertEqual(self.tokenizer.get_vocabulary()[10], "[MASK]")

def test_id_to_token(self):
self.assertEqual(self.tokenizer.id_to_token(10), "[MASK]")

def test_token_to_id(self):
self.assertEqual(self.tokenizer.token_to_id("[MASK]"), 10)

def test_errors_missing_special_tokens(self):
bytes_io = io.BytesIO()
2 changes: 1 addition & 1 deletion keras_nlp/tokenizers/sentence_piece_tokenizer.py
@@ -159,7 +159,7 @@ def vocabulary_size(self) -> int:
return int(self._sentence_piece.vocab_size().numpy())

def get_vocabulary(self) -> List[str]:
"""Get the size of the tokenizer vocabulary."""
"""Get the tokenizer vocabulary."""
Member: Is there any downside to making the super class impl use self._sentence_piece.vocab_size() instead of self.vocabulary_size() here? Then we don't need all this indirection on the subclass.

return tensor_to_string_list(
self._sentence_piece.id_to_string(tf.range(self.vocabulary_size()))
)
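A sketch of the reviewer's suggestion above, as an assumption about a possible follow-up rather than a change made in this PR: size the lookup with the underlying SentencePiece op directly, so a subclass may report a larger logical vocabulary_size() without affecting this method.

```python
# Hypothetical revision of SentencePieceTokenizer.get_vocabulary().
def get_vocabulary(self) -> List[str]:
    """Get the tokenizer vocabulary."""
    # Use the raw SentencePiece vocab size rather than self.vocabulary_size(),
    # which subclasses such as DebertaV3Tokenizer override to a larger value.
    return tensor_to_string_list(
        self._sentence_piece.id_to_string(
            tf.range(int(self._sentence_piece.vocab_size().numpy()))
        )
    )
```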