Remove RobertaTokenizerFast
PhilipMay committed Aug 4, 2020
1 parent 8c61e3b commit aec7d2d
Showing 2 changed files with 5 additions and 7 deletions.
7 changes: 4 additions & 3 deletions farm/modeling/tokenization.py
@@ -27,7 +27,7 @@
 from transformers.tokenization_bert import BertTokenizer, BertTokenizerFast, load_vocab
 from transformers.tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
 from transformers.tokenization_electra import ElectraTokenizer, ElectraTokenizerFast
-from transformers.tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
+from transformers.tokenization_roberta import RobertaTokenizer
 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer
 from transformers.tokenization_xlnet import XLNetTokenizer
@@ -59,6 +59,7 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False):
         :type tokenizer_class: str
         :param use_fast: (Optional, False by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
                          use the Python one (False).
+                         TODO: Say which models support fast tokenizers.
         :type use_fast: bool
         :param kwargs:
         :return: Tokenizer
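
Editor's note on the use_fast flag: per the imports above and the tests below, Bert, DistilBert and Electra keep their fast (Rust-backed) tokenizers after this commit, while Roberta no longer does. A minimal sketch, assuming a FARM checkout at this commit and the public "bert-base-german-cased" checkpoint:

    # Sketch only; assumes FARM at this commit and the public
    # "bert-base-german-cased" checkpoint.
    from farm.modeling.tokenization import Tokenizer

    # use_fast=True returns the Rust-backed tokenizer where one is supported.
    fast_tokenizer = Tokenizer.load("bert-base-german-cased", lower_case=False, use_fast=True)
    print(type(fast_tokenizer).__name__)  # BertTokenizerFast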
@@ -102,7 +103,7 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False):
             ret = XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif tokenizer_class == "RobertaTokenizer":
             if use_fast:
-                ret = RobertaTokenizerFast.from_pretrained(pretrained_model_name_or_path, **kwargs)
+                raise ValueError('RobertaTokenizerFast is not supported!')
             else:
                 ret = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif tokenizer_class == "DistilBertTokenizer":
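
With this hunk, asking FARM for a fast Roberta tokenizer raises instead of returning a RobertaTokenizerFast. A minimal sketch of the new behavior, assuming the public "roberta-base" checkpoint:

    # Sketch only; assumes FARM at this commit and the public "roberta-base" checkpoint.
    from farm.modeling.tokenization import Tokenizer

    # The slow (pure-Python) Roberta tokenizer still loads as before.
    slow_tokenizer = Tokenizer.load("roberta-base", use_fast=False)

    # The fast variant is now rejected.
    try:
        Tokenizer.load("roberta-base", use_fast=True)
    except ValueError as err:
        print(err)  # RobertaTokenizerFast is not supported!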
@@ -309,7 +310,7 @@ def _words_to_tokens(words, word_offsets, tokenizer):
         elif len(tokens) == 0:
             tokens_word = tokenizer.tokenize(w)
         else:
-            if (type(tokenizer) == RobertaTokenizer) or (type(tokenizer) == RobertaTokenizerFast):
+            if type(tokenizer) == RobertaTokenizer:
                 tokens_word = tokenizer.tokenize(w, add_prefix_space=True)
             else:
                 tokens_word = tokenizer.tokenize(w)
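
Background on the surviving branch: Roberta's byte-level BPE is whitespace-sensitive, so a word tokenized in isolation encodes as if it started the text; add_prefix_space=True restores its mid-sentence form. An illustrative sketch, assuming the "roberta-base" vocabulary (exact subword splits depend on the vocab):

    # Illustrative sketch; assumes the public "roberta-base" vocabulary.
    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # Without the flag, the word is encoded as if it began the text.
    print(tokenizer.tokenize("world"))                         # e.g. ['world']
    # With add_prefix_space=True it matches its mid-sentence encoding
    # ("Ġ" marks a leading space in the byte-level BPE vocab).
    print(tokenizer.tokenize("world", add_prefix_space=True))  # e.g. ['Ġworld']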
5 changes: 1 addition & 4 deletions test/test_tokenization.py
@@ -2,7 +2,7 @@
 import pytest
 import re
 from transformers import BertTokenizer, BertTokenizerFast, RobertaTokenizer, XLNetTokenizer
-from transformers import ElectraTokenizerFast, RobertaTokenizerFast
+from transformers import ElectraTokenizerFast
 
 from farm.modeling.tokenization import Tokenizer, tokenize_with_metadata, truncate_sequences

@@ -116,7 +116,6 @@ def test_truncate_sequences(caplog):
 
 @pytest.mark.parametrize("model_name", ["bert-base-german-cased",
                                         "google/electra-small-discriminator",
-                                        "distilroberta-base",
                                         ])
 def test_fast_tokenizer_with_examples(caplog, model_name):
     fast_tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=True)
@@ -132,7 +131,6 @@ def test_fast_tokenizer_with_examples(caplog, model_name):
 
 @pytest.mark.parametrize("model_name", ["bert-base-german-cased",
                                         "google/electra-small-discriminator",
-                                        "distilroberta-base",
                                         ])
 def test_fast_tokenizer_with_metadata_with_examples(caplog, model_name):
     fast_tokenizer = Tokenizer.load(model_name, lower_case=False, use_fast=True)
@@ -259,7 +257,6 @@ def test_fast_bert_custom_vocab(caplog):
 @pytest.mark.parametrize("model_name, tokenizer_type", [
     ("bert-base-german-cased", BertTokenizerFast),
     ("google/electra-small-discriminator", ElectraTokenizerFast),
-    ("distilroberta-base", RobertaTokenizerFast),
 ])
 def test_fast_tokenizer_type(caplog, model_name, tokenizer_type):
     caplog.set_level(logging.CRITICAL)
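
The removed "distilroberta-base" cases could instead be covered by asserting the new failure mode. A hypothetical companion test (not part of this commit; the test name is invented):

    # Hypothetical test, not part of this commit.
    import logging

    import pytest

    from farm.modeling.tokenization import Tokenizer


    def test_fast_roberta_tokenizer_raises(caplog):
        caplog.set_level(logging.CRITICAL)
        # Fast Roberta tokenizers are rejected since this commit.
        with pytest.raises(ValueError):
            Tokenizer.load("distilroberta-base", lower_case=False, use_fast=True)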