Enable option for subword regularization in XLMRobertaTokenizer #11149

Merged: 15 commits, Apr 23, 2021
22 changes: 20 additions & 2 deletions src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py
@@ -94,6 +94,20 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
modeling. This is the token which the model will try to predict.
additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["<s>NOTUSED", "</s>NOTUSED"]`):
Additional special tokens used by the tokenizer.
sp_model_kwargs (:obj:`dict`, `optional`, defaults to :obj:`None`):
Will be passed to the ``SentencePieceProcessor.__init__()`` method. The `Python wrapper for SentencePiece
<https://github.com/google/sentencepiece/tree/master/python>`__ can be used, among other things, to set:

- ``enable_sampling``: Enable subword regularization.
- ``nbest_size``: Sampling parameters for unigram. Invalid for BPE-Dropout.

- ``nbest_size = {0,1}``: No sampling is performed.
- ``nbest_size > 1``: Samples from the nbest_size results.
- ``nbest_size < 0``: Assuming that nbest_size is infinite, samples from all hypotheses (lattice)
  using the forward-filtering-and-backward-sampling algorithm.

- ``alpha``: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.

Attributes:
sp_model (:obj:`SentencePieceProcessor`):
@@ -115,11 +129,14 @@ def __init__(
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
sp_model_kwargs=None,
**kwargs
):
# Mask token behaves like a normal word, i.e. it includes the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

super().__init__(
bos_token=bos_token,
eos_token=eos_token,
@@ -128,10 +145,11 @@ def __init__(
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
sp_model_kwargs=sp_model_kwargs,
**kwargs,
)

self.sp_model = spm.SentencePieceProcessor()
self.sp_model = spm.SentencePieceProcessor(**sp_model_kwargs)
self.sp_model.Load(str(vocab_file))
self.vocab_file = vocab_file

@@ -252,7 +270,7 @@ def get_vocab(self):
return vocab

def _tokenize(self, text):
return self.sp_model.EncodeAsPieces(text)
return self.sp_model.encode(text, out_type=str)

def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
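Taken together, the changes above let callers forward SentencePiece sampling options through ``sp_model_kwargs``. Below is a minimal usage sketch (not part of the diff): it assumes the ``xlm-roberta-base`` checkpoint is available, uses the slow tokenizer (the fast tokenizer does not support subword regularization), and the sampled pieces will differ between runs by design.

```python
from transformers import XLMRobertaTokenizer

# Slow (SentencePiece-backed) tokenizer with subword sampling enabled.
# The kwargs are forwarded verbatim to SentencePieceProcessor.__init__().
tokenizer = XLMRobertaTokenizer.from_pretrained(
    "xlm-roberta-base",
    sp_model_kwargs={"enable_sampling": True, "nbest_size": -1, "alpha": 0.1},
)

# Repeated calls may segment the same sentence differently.
for _ in range(3):
    print(tokenizer.tokenize("This is a test for subword regularization."))
```
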
24 changes: 24 additions & 0 deletions tests/test_tokenization_xlm_roberta.py
@@ -14,6 +14,7 @@
# limitations under the License.


import itertools
import os
import unittest

@@ -118,6 +119,29 @@ def test_full_tokenizer(self):
],
)

def test_subword_regularization_tokenizer(self):
# Subword regularization is only available for the slow tokenizer.
tokenizer = XLMRobertaTokenizer(
SAMPLE_VOCAB, keep_accents=True, sp_model_kwargs={"enable_sampling": True, "alpha": 0.1, "nbest_size": -1}
)

# Subword regularization augments training data with subword sampling.
# This has a random component. We test if the tokenizer generates different
# results when subword regularization is enabled.
tokens_list = []
for _ in range(5):
tokens_list.append(tokenizer.tokenize("This is a test for subword regularization."))

# all unique pairs of tokenizations from tokens_list
combinations = itertools.combinations(tokens_list, 2)

all_equal = True
for combination in combinations:
if combination[0] != combination[1]:
all_equal = False

self.assertFalse(all_equal)

@cached_property
def big_tokenizer(self):
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
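The behaviour the new test asserts comes straight from SentencePiece: when sampling is enabled on the processor, ``encode`` draws a segmentation rather than returning the single best one, which is also why ``_tokenize`` now calls ``encode(text, out_type=str)``. The sketch below reproduces the check against the sentencepiece package directly; it assumes a reasonably recent release that accepts these constructor keywords (roughly 0.1.91 or newer) and uses a placeholder model path.

```python
import sentencepiece as spm

# "spiece.model" is a placeholder; any trained unigram SentencePiece model works.
sp = spm.SentencePieceProcessor(
    model_file="spiece.model",
    enable_sampling=True,  # turn on subword sampling
    nbest_size=-1,         # sample from the whole lattice
    alpha=0.1,             # smoothing parameter for unigram sampling
)

text = "This is a test for subword regularization."

# With sampling on, two encodings of the same text can differ,
# which is what the pairwise comparison in the test checks for.
print(sp.encode(text, out_type=str))
print(sp.encode(text, out_type=str))
```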