Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable option for subword regularization in XLMRobertaTokenizer #11149

Merged
merged 15 commits into from
Apr 23, 2021
Merged
26 changes: 24 additions & 2 deletions src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,18 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
modeling. This is the token which the model will try to predict.
additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["<s>NOTUSED", "</s>NOTUSED"]`):
Additional special tokens used by the tokenizer.
enable_sampling (:obj:`bool`, `optional`, defaults to :obj:`False`):
Enable subword regularization.
nbest_size (:obj:`int`, `optional`, defaults to -1):
Sampling parameters for unigram. Invalid for BPE-Dropout.
nbest_size = {0,1}: No sampling is performed.
nbest_size > 1: samples from the nbest_size results.
nbest_size < 0: assuming that nbest_size is infinite and samples
from the all hypothesis (lattice) using
forward-filtering-and-backward-sampling algorithm.
alpha (:obj:`float`, `optional`, defaults to 0.1):
Soothing parameter for unigram sampling, and dropout probability of
merge operations for BPE-dropout.

Attributes:
sp_model (:obj:`SentencePieceProcessor`):
Expand All @@ -115,6 +127,9 @@ def __init__(
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
enable_sampling=False,
nbest_size=-1,
alpha=0.1,
**kwargs
):
# Mask token behave like a normal word, i.e. include the space before it
Expand All @@ -128,10 +143,17 @@ def __init__(
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
enable_sampling=enable_sampling,
nbest_size=nbest_size,
alpha=alpha,
**kwargs,
)

self.sp_model = spm.SentencePieceProcessor()
self.sp_model = spm.SentencePieceProcessor(
enable_sampling=enable_sampling,
nbest_size=nbest_size,
alpha=alpha,
)
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file

Expand Down Expand Up @@ -252,7 +274,7 @@ def get_vocab(self):
return vocab

def _tokenize(self, text):
return self.sp_model.EncodeAsPieces(text)
return self.sp_model.encode(text, out_type=str)

def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
Expand Down