
max_len_single_sentence & max_len_sentences_pair as attributes so they can be modified
thomwolf committed Aug 23, 2019
1 parent ab7bd5e commit 3bcbebd
Showing 8 changed files with 26 additions and 40 deletions.
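
The net effect of the change: max_len_single_sentence and max_len_sentences_pair are now plain instance attributes set in each tokenizer's __init__ (max_len minus the number of special tokens the model inserts, or just max_len when there are none), instead of read-only @property methods, so user code can override them. A minimal usage sketch of what this enables, assuming the standard from_pretrained loading path and the bert-base-uncased shortcut name; the overridden values below are illustrative only:

from pytorch_transformers import BertTokenizer

# After this commit the defaults are set in __init__:
# for BERT, max_len_single_sentence = max_len - 2 and max_len_sentences_pair = max_len - 3.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(tokenizer.max_len_single_sentence, tokenizer.max_len_sentences_pair)  # e.g. 510 509 for max_len 512

# Previously these were read-only @property methods, so assignment raised AttributeError;
# as attributes they can now be adjusted, e.g. to reserve room for extra special tokens
# added by the user (hypothetical values).
tokenizer.max_len_single_sentence = tokenizer.max_len - 4
tokenizer.max_len_sentences_pair = tokenizer.max_len - 6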
11 changes: 3 additions & 8 deletions pytorch_transformers/tokenization_bert.py
@@ -125,6 +125,9 @@ def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
+ self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
+ self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
+
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@@ -139,14 +142,6 @@ def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never
tokenize_chinese_chars=tokenize_chinese_chars)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

- @property
- def max_len_single_sentence(self):
- return self.max_len - 2 # take into account special tokens
-
- @property
- def max_len_sentences_pair(self):
- return self.max_len - 3 # take into account special tokens
-
@property
def vocab_size(self):
return len(self.vocab)
2 changes: 2 additions & 0 deletions pytorch_transformers/tokenization_gpt2.py
@@ -108,6 +108,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>",
bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+ self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
+ self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens

self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
3 changes: 3 additions & 0 deletions pytorch_transformers/tokenization_openai.py
@@ -87,6 +87,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)

+ self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
+ self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
+
try:
import ftfy
from spacy.lang.en import English
11 changes: 3 additions & 8 deletions pytorch_transformers/tokenization_roberta.py
@@ -77,6 +77,9 @@ def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", e
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
mask_token=mask_token, **kwargs)

+ self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
+ self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
+
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
@@ -160,14 +163,6 @@ def convert_tokens_to_string(self, tokens):
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text

- @property
- def max_len_single_sentence(self):
- return self.max_len - 2 # take into account special tokens
-
- @property
- def max_len_sentences_pair(self):
- return self.max_len - 4 # take into account special tokens
-
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
4 changes: 4 additions & 0 deletions pytorch_transformers/tokenization_transfo_xl.py
@@ -73,6 +73,10 @@ def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
+
+ self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
+ self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
+
if never_split is None:
never_split = self.all_special_tokens
if special is None:
11 changes: 3 additions & 8 deletions pytorch_transformers/tokenization_utils.py
@@ -67,14 +67,6 @@ class PreTrainedTokenizer(object):
"pad_token", "cls_token", "mask_token",
"additional_special_tokens"]

- @property
- def max_len_single_sentence(self):
- return self.max_len # Default to max_len but can be smaller in specific tokenizers to take into account special tokens
-
- @property
- def max_len_sentences_pair(self):
- return self.max_len # Default to max_len but can be smaller in specific tokenizers to take into account special tokens
-
@property
def bos_token(self):
""" Beginning of sentence token (string). Log an error if used while not having been set. """
@@ -174,6 +166,9 @@ def __init__(self, max_len=None, **kwargs):
self._additional_special_tokens = []

self.max_len = max_len if max_len is not None else int(1e12)
+ self.max_len_single_sentence = self.max_len
+ self.max_len_sentences_pair = self.max_len
+
self.added_tokens_encoder = {}
self.added_tokens_decoder = {}

12 changes: 4 additions & 8 deletions pytorch_transformers/tokenization_xlm.py
@@ -122,6 +122,10 @@ def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
cls_token=cls_token, mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
+
+ self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
+ self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
+
try:
import ftfy
from spacy.lang.en import English
@@ -215,14 +219,6 @@ def convert_tokens_to_string(self, tokens):
out_string = ''.join(tokens).replace('</w>', ' ').strip()
return out_string

- @property
- def max_len_single_sentence(self):
- return self.max_len - 2 # take into account special tokens
-
- @property
- def max_len_sentences_pair(self):
- return self.max_len - 3 # take into account special tokens
-
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
12 changes: 4 additions & 8 deletions pytorch_transformers/tokenization_xlnet.py
@@ -71,6 +71,10 @@ def __init__(self, vocab_file, max_len=None,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, additional_special_tokens=
additional_special_tokens, **kwargs)
+
+ self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
+ self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
+
try:
import sentencepiece as spm
except ImportError:
@@ -177,14 +181,6 @@ def convert_tokens_to_string(self, tokens):
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
return out_string

- @property
- def max_len_single_sentence(self):
- return self.max_len - 2 # take into account special tokens
-
- @property
- def max_len_sentences_pair(self):
- return self.max_len - 3 # take into account special tokens
-
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence pair for sequence classification tasks.

