
Commit 34e3144

Improve the speed of adding tokens from added_tokens.json (huggingface#10780)

* use bisect to add one token to unique_no_split_tokens

* fix style
cchen-dialpad authored and Iwontbecreative committed Jul 15, 2021
1 parent f2298e9 commit 34e3144
Showing 1 changed file with 22 additions and 2 deletions.
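
The core idea of the change: when a single token is added, unique_no_split_tokens is already sorted, so a binary-search insert (bisect) replaces the previous rebuild-and-re-sort of the whole list. Below is a minimal sketch of the two strategies on a plain sorted list, outside the library; slow_add and fast_add are illustrative names, not part of the commit.

import bisect
from typing import List

def slow_add(token_list: List[str], new_token: str) -> List[str]:
    # Previous behaviour: rebuild a set and re-sort the entire list for every
    # added token, roughly O(n log n) per insertion.
    return sorted(set(token_list).union({new_token}))

def fast_add(token_list: List[str], new_token: str) -> None:
    # Single-token path: binary-search the insertion point (O(log n)) and do
    # one in-place list insert, skipping the token if it is already present.
    idx = bisect.bisect_left(token_list, new_token)
    if idx < len(token_list) and token_list[idx] == new_token:
        return
    token_list.insert(idx, new_token)

tokens = ["<a>", "<c>", "<d>"]
fast_add(tokens, "<b>")
fast_add(tokens, "<c>")  # duplicate, list is unchanged
assert tokens == ["<a>", "<b>", "<c>", "<d>"]
assert slow_add(tokens, "<e>")[-1] == "<e>"

The list insert itself is still O(n), but avoiding a full re-sort per token is what speeds up loading a large added_tokens.json one entry at a time.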
src/transformers/tokenization_utils.py (24 changes: 22 additions & 2 deletions)
@@ -16,6 +16,7 @@
 Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
 tokenization_utils_fast.py
 """
+import bisect
 import itertools
 import re
 import unicodedata
@@ -99,6 +100,19 @@ def _is_start_of_word(text):
     return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
 
 
+def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str):
+    """
+    Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.
+    """
+    insertion_idx = bisect.bisect_left(token_list, new_token)
+    # Checks if new_token is already in the ordered token_list
+    if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
+        # new_token is in token_list, don't add
+        return
+    else:
+        token_list.insert(insertion_idx, new_token)
+
+
 @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
 class PreTrainedTokenizer(PreTrainedTokenizerBase):
     """
@@ -199,10 +213,16 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
 
         # Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert)
         if special_tokens:
-            self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))
+            if len(new_tokens) == 1:
+                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, new_tokens[0])
+            else:
+                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))
         else:
             # Or on the newly added tokens
-            self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
+            if len(tokens_to_add) == 1:
+                _insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
+            else:
+                self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
 
         return len(tokens_to_add)

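A short usage sketch of how the two branches in _add_tokens are reached; use_fast=False selects the pure-Python tokenizer that this file implements, bert-base-uncased is only an example checkpoint, and the added token strings are placeholders.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)

# One genuinely new token per call (the case this commit targets, e.g. when
# entries from added_tokens.json are registered one at a time) hits the new
# _insert_one_token_to_ordered_list fast path.
tokenizer.add_tokens(["<new_tok_1>"])

# Several new tokens in a single call still fall back to the sorted set-union.
tokenizer.add_tokens(["<new_tok_2>", "<new_tok_3>"])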
