Handle [MASK] token in DebertaV3Tokenizer #759
@@ -48,6 +48,11 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
             `bytes` object with a serialized SentencePiece proto. See the
             [SentencePiece repository](https://github.com/google/sentencepiece)
             for more details on the format.
+        mask_token_id: The token ID (int) of the mask token (`[MASK]`). If
+            `None`, the SentencePiece vocabulary is expected to have the mask
+            token. Preset DeBERTa vocabularies do not have the mask token in the
+            provided vocabulary files, which is why this workaround is
+            necessary.
 
     Examples:
 

Review thread on the new `mask_token_id` argument:

Reviewer: I think given that most users will not need an MLM task, we should actually make this optional when "bringing your own data." Something like... Does that make sense to you?

Author: (1) and (3) were already taken care of. I've pushed changes which solve all three cases and resolve the other comment.

Reviewer: Thanks! Looks good. Left some thoughts below on how we could maybe make the subclass changes a bit easier by modifying the super class.
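Not part of the diff: a minimal usage sketch of the new `mask_token_id` argument, modeled loosely on the library's existing docstring examples. The tiny SentencePiece proto trained below is an assumption for illustration only; it contains `[PAD]`/`[CLS]`/`[SEP]`/`[UNK]` but no `[MASK]`, so the mask id is supplied explicitly.

```python
import io

import sentencepiece
import keras_nlp

# Train a throwaway word-level SentencePiece model without a "[MASK]" piece.
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=iter(["The quick brown fox jumped."]),
    model_writer=bytes_io,
    vocab_size=9,  # 4 special tokens + 5 words
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    pad_piece="[PAD]",
    bos_piece="[CLS]",
    eos_piece="[SEP]",
    unk_piece="[UNK]",
)

# "[MASK]" is not in the proto's vocabulary, so hand the tokenizer an
# explicit id one past the end of the SentencePiece vocabulary.
tokenizer = keras_nlp.models.DebertaV3Tokenizer(
    proto=bytes_io.getvalue(),
    mask_token_id=9,
)
print(tokenizer.token_to_id("[MASK]"))  # 9
print(tokenizer.vocabulary_size())      # 10
```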
@@ -65,15 +70,29 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
     ```
     """
 
-    def __init__(self, proto, **kwargs):
+    def __init__(self, proto, mask_token_id=None, **kwargs):
         super().__init__(proto=proto, **kwargs)
 
+        # Maintain a private copy of `mask_token_id` for config purposes.
+        self._mask_token_id = mask_token_id
+
+        # Maintain a private copy of the original vocabulary; the parent class's
+        # `get_vocabulary()` function calls `self.vocabulary_size()`, which
+        # throws up a segmentation fault.
+        self._original_vocabulary = super().get_vocabulary()
+
         # Check for necessary special tokens.
         cls_token = "[CLS]"
         sep_token = "[SEP]"
         pad_token = "[PAD]"
-        for token in [cls_token, pad_token, sep_token]:
-            if token not in self.get_vocabulary():
+        mask_token = "[MASK]"
+
+        in_vocab_special_tokens = [cls_token, pad_token, sep_token]
+        if mask_token_id is None:
+            in_vocab_special_tokens = in_vocab_special_tokens + [mask_token]
+
+        for token in in_vocab_special_tokens:
+            if token not in self._original_vocabulary:
                 raise ValueError(
                     f"Cannot find token `'{token}'` in the provided "
                     f"`vocabulary`. Please provide `'{token}'` in your "

Review thread on the `_original_vocabulary` copy:

Reviewer: What is the segmentation fault here? I'm not sure I totally follow. Ideally we don't have to store a copy of the vocabulary. This would be a not-totally-insignificant waste of memory!

Author: Made edits! Calling […]
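To illustrate the optional-mask behaviour discussed in the first thread, here is a hypothetical check (not from the PR) that reuses `bytes_io` from the earlier sketch: leaving `mask_token_id` unset with a vocabulary that lacks `[MASK]` should now trip the special-token check, while passing it explicitly skips that check.

```python
# The proto from the earlier sketch has no "[MASK]" piece.
try:
    keras_nlp.models.DebertaV3Tokenizer(proto=bytes_io.getvalue())
except ValueError as err:
    # "Cannot find token `'[MASK]'` in the provided `vocabulary`. ..."
    print(err)

# Supplying `mask_token_id` removes "[MASK]" from the in-vocab check.
keras_nlp.models.DebertaV3Tokenizer(proto=bytes_io.getvalue(), mask_token_id=9)
```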
@@ -83,6 +102,49 @@ def __init__(self, proto, **kwargs):
         self.cls_token_id = self.token_to_id(cls_token)
         self.sep_token_id = self.token_to_id(sep_token)
         self.pad_token_id = self.token_to_id(pad_token)
+        self.mask_token_id = mask_token_id
+        if mask_token_id is None:
+            self.mask_token_id = self.token_to_id(mask_token)
+
+    def vocabulary_size(self):
+        vocabulary_size = super().vocabulary_size()
+
+        # This is to avoid an error when `super().get_vocabulary()` is called
+        # in `__init__()`.
+        if not hasattr(self, "mask_token_id"):
+            return vocabulary_size
+
+        if self.mask_token_id >= vocabulary_size:
+            return self.mask_token_id + 1
+        return vocabulary_size
+
+    def get_vocabulary(self):
+        vocabulary = self._original_vocabulary
+        if self.mask_token_id >= len(vocabulary):
+            vocabulary = vocabulary + [None] * (
+                self.mask_token_id - len(vocabulary) + 1
+            )
+        vocabulary[self.mask_token_id] = "[MASK]"
+        return vocabulary
+
+    def id_to_token(self, id):
+        if id == self.mask_token_id:
+            return "[MASK]"
+        return super().id_to_token(id)
+
+    def token_to_id(self, token):
+        if token == "[MASK]":
+            return self.mask_token_id
+        return int(self._sentence_piece.string_to_id(token).numpy())
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "mask_token_id": self._mask_token_id,
+            }
+        )
+        return config
+
     @classproperty
     def presets(cls):
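A short, hypothetical behaviour check for the overrides above, continuing the earlier sketch (the concrete id `9` is an assumption tied to that toy proto):

```python
# `tokenizer` is the instance built earlier with mask_token_id=9.
vocab = tokenizer.get_vocabulary()
assert tokenizer.vocabulary_size() == len(vocab)   # size grows to cover the mask id
assert vocab[tokenizer.mask_token_id] == "[MASK]"  # appended past the proto vocab
assert tokenizer.id_to_token(tokenizer.mask_token_id) == "[MASK]"
assert tokenizer.token_to_id("[MASK]") == tokenizer.mask_token_id

# `get_config()` serializes the user-supplied value (None when relying on presets).
assert tokenizer.get_config()["mask_token_id"] == 9
```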
@@ -159,7 +159,7 @@ def vocabulary_size(self) -> int:
         return int(self._sentence_piece.vocab_size().numpy())
 
     def get_vocabulary(self) -> List[str]:
-        """Get the size of the tokenizer vocabulary."""
+        """Get the tokenizer vocabulary."""
         return tensor_to_string_list(
             self._sentence_piece.id_to_string(tf.range(self.vocabulary_size()))
         )

Review comment on the base class's `get_vocabulary()`:

Reviewer: Is there any downside to making the super class impl use […]
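The comment above is cut off in the page. Purely as a hedged illustration of one direction it might be pointing (an assumption, not necessarily the reviewer's actual proposal), the base class could build its vocabulary through `id_to_token()`, so that subclass overrides such as the `[MASK]` special case are picked up without caching a copy of the vocabulary:

```python
from typing import List


class SentencePieceTokenizerSketch:
    """Not the real base class: a minimal sketch only."""

    def vocabulary_size(self) -> int:
        raise NotImplementedError  # stand-in for the real SentencePiece-backed impl

    def id_to_token(self, id) -> str:
        raise NotImplementedError  # stand-in; a subclass may special-case "[MASK]"

    def get_vocabulary(self) -> List[str]:
        """Get the tokenizer vocabulary."""
        # Routing through `id_to_token` means subclass overrides of that
        # method (and of `vocabulary_size`) are reflected here automatically.
        return [self.id_to_token(i) for i in range(self.vocabulary_size())]
```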
Reviewer: I think we need to remove this from all the presets now, right? It is breaking tests.