Handle [MASK] token in DebertaV3Tokenizer #759
Conversation
Thanks for tracking this down! Just have some high level comments for now.
```
        # Maintain a private copy of the original vocabulary; the parent class's
        # `get_vocabulary()` function calls `self.vocabulary_size()`, which
        # throws up a segmentation fault.
```
What is the segmentation fault here? I'm not sure I totally follow. Ideally we don't have to store a copy of the vocabulary. This would be a not-totally-insignificant waste of memory!
Made edits!

Calling `super().get_vocabulary()` in `__init__` causes a seg fault because `SentencePieceTokenizer` calls `self.vocabulary_size()` here: https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/tokenizers/sentence_piece_tokenizer.py#L161-L165. Since we change the `vocabulary_size()` function in `DebertaV3Tokenizer` to return a value greater than the SPM vocabulary size, the lookup goes out of range and segfaults.
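For illustration, here is a minimal pure-Python analogue of the dispatch issue (hypothetical classes, not the keras-nlp code): the parent builds its vocabulary from `self.vocabulary_size()`, so a subclass that reports a larger size pushes the lookup past the real vocabulary.

```python
class Parent:
    """Stands in for SentencePieceTokenizer in this sketch."""

    _pieces = ["<pad>", "<unk>", "the", "quick"]

    def vocabulary_size(self):
        return len(self._pieces)

    def get_vocabulary(self):
        # Dispatches to the (possibly overridden) vocabulary_size().
        return [self._pieces[i] for i in range(self.vocabulary_size())]


class Child(Parent):
    """Stands in for DebertaV3Tokenizer, which reports one extra id for the mask token."""

    def vocabulary_size(self):
        return super().vocabulary_size() + 1


try:
    Child().get_vocabulary()
except IndexError as err:
    # Pure Python raises IndexError; the native SentencePiece lookup in the real
    # tokenizer goes out of range the same way and crashes the process instead.
    print("out-of-range lookup:", err)
```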
```
@@ -48,6 +48,11 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.
        mask_token_id: The token ID (int) of the mask token (`[MASK]`). If
```
I think given that most users will not need an MLM task, we should actually make this optional when "bringing your own data." Something like...

- Use one of our presets. `self.mask_token_id` is set and works as expected.
- Pass your own local copy of a deberta spm file and don't set anything. `self.mask_token_id` is `None`, and everything works except `DebertaMaskedLM`, which throws a friendly error message. We can cover the error in #732 (Solve #721, Deberta masklm model).
- Optional (but already working here). Use your own custom spm file with a `"[MASK]"` token. `self.mask_token_id` is set and works as expected.

Does that make sense to you?
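To make those three cases concrete, here is a small hypothetical helper (illustrative only, not the code in this PR) that decides the mask id from a plain list of vocabulary strings:

```python
from typing import List, Optional


def resolve_mask_token_id(vocab: List[str]) -> Optional[int]:
    """Hypothetical sketch of the three cases above, over a plain string vocab."""
    if "[MASK]" in vocab:
        # Case 3: a custom spm file that already defines "[MASK]".
        return vocab.index("[MASK]")
    # Case 1 (presets) is handled elsewhere: the deberta presets place "[MASK]"
    # one id past the end of the spm vocabulary.
    # Case 2: a local deberta spm without "[MASK]"; leave it unset so
    # DebertaMaskedLM can raise a friendly error.
    return None


print(resolve_mask_token_id(["<pad>", "the", "[MASK]"]))  # case 3 -> 2
print(resolve_mask_token_id(["<pad>", "the"]))            # case 2 -> None
```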
(1) and (3) were already taken care of. I've pushed changes which solve all three cases and resolve the other comment.
Thanks! Looks good. Left some thoughts below on how we could maybe make the subclass changes a bit easier by modifying the super class.
```
        return (
            original_vocabulary
            + [None] * (self._mask_token_id - super().vocabulary_size())
```
Should we do something like `"[PLACEHOLDER]"` here? Or whatever deberta does? Seems like a bug waiting to happen to sneak `None` into a list of strings.
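A rough sketch of that padding idea (assuming the gap runs from the end of the spm vocabulary up to `self._mask_token_id` and that `"[MASK]"` itself is appended last; the hunk above is truncated, and `pad_vocabulary` is a hypothetical name):

```python
from typing import List


def pad_vocabulary(original_vocabulary: List[str], mask_token_id: int) -> List[str]:
    """Fill the ids between the spm vocabulary and the mask token with a placeholder."""
    gap = mask_token_id - len(original_vocabulary)
    return original_vocabulary + ["[PLACEHOLDER]"] * gap + ["[MASK]"]


# A toy vocabulary of 3 pieces with the mask token at id 4:
print(pad_vocabulary(["<pad>", "a", "b"], 4))
# ['<pad>', 'a', 'b', '[PLACEHOLDER]', '[MASK]']
```

Keeping every entry a string avoids sneaking `None` into the list, as noted above.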
```
@@ -159,7 +159,7 @@ def vocabulary_size(self) -> int:
        return int(self._sentence_piece.vocab_size().numpy())

    def get_vocabulary(self) -> List[str]:
-       """Get the size of the tokenizer vocabulary."""
+       """Get the tokenizer vocabulary."""
```
Is there any downside to making the super class impl use `self._sentence_piece.vocab_size()` instead of `self.vocabulary_size()` here? Then we don't need all this indirection on the subclass.
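Roughly what that suggested change could look like, assuming `self._sentence_piece` is a `tensorflow_text.SentencepieceTokenizer` as in keras-nlp (a sketch of the idea, not the exact upstream implementation):

```python
import tensorflow as tf


def get_vocabulary(self):
    """Get the tokenizer vocabulary."""
    # Read the size from the underlying spm model directly, so a subclass
    # override of vocabulary_size() cannot push the lookup out of range.
    size = int(self._sentence_piece.vocab_size().numpy())
    pieces = self._sentence_piece.id_to_string(tf.range(size))
    return [piece.decode("utf-8") for piece in pieces.numpy()]
```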
Thanks! This looks good. Though I think we need to update the presets to fix tests.
```
@@ -25,7 +25,9 @@
        "max_sequence_length": 512,
        "bucket_size": 256,
    },
-   "preprocessor_config": {},
+   "preprocessor_config": {
+       "mask_token_id": 128000,
```
I think we need to remove this from all the presets now right? It is breaking tests.
```
        self.mask_token_id = super().vocabulary_size()

    def vocabulary_size(self):
        return max(super().vocabulary_size(), self.mask_token_id + 1)
```
I would write this a little longer just for clarity...

```python
# Account for appended mask token if necessary.
sentencepiece_size = super().vocabulary_size()
if sentencepiece_size == self.mask_token_id:
    return sentencepiece_size + 1
return sentencepiece_size
```
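As a quick sanity check on the two forms (using the 128000 value from the preset diff above): with an spm vocabulary of 128000 pieces and `mask_token_id == 128000`, both the `max(...)` one-liner and the longer version return 128001, and for a custom spm that already contains `"[MASK]"` (so `mask_token_id < sentencepiece_size`), both return the spm size unchanged.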
Thanks!
Fix for #732 (comment).