Add option to use fast HF tokenizer. (#482)
* Add option to use fast HF tokenizer

* Hand merge tests from PR #205

* test_inferencer_with_fast_bert_tokenizer

* test_fast_bert_tokenizer

* test_fast_bert_tokenizer_strip_accents

* test_fast_electra_tokenizer

* Fix OOM issue of CI

- set num_processes=0 for Inferencer

* Extend test for fast tokenizer

- electra
- roberta

* test_fast_tokenizer for more model types

- electra
- roberta

* Fix tokenize_with_metadata

* Split tokenizer tests

* Fix pytest params bug in test_tok

* Fix fast tokenizer usage

* Add missing newline at EOF

* Add test fast tok. doc_classif.

* Remove RobertaTokenizerFast

* Fix Tokenizer load and save.

* Fix typo

* Improve test test_embeddings_extraction

- add shape assert
- fix embedding assert

* Docstring for fast tokenizers improved

* tokenizer_args docstring

* Extend test_embeddings_extraction to fast tok.

* extend test_ner with fast tok.

* fix sample_to_features_ner for fast tokenizer

* temp fix for is_pretokenized until fixed upstream

* Make use of fast tokenizer possible + fix bug in offset calculation

* Make fast tokenization possible with NER, LM and QA

* Change error messages

* Add tests

* update error messages, comments and truncation arg in tokenizer

Co-authored-by: Malte Pietsch <[email protected]>
Co-authored-by: Bogdan Kostić <[email protected]>
3 people authored Sep 2, 2020
1 parent dd3945d commit 435f3ee
Showing 10 changed files with 420 additions and 93 deletions.
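
Before the diff, here is a minimal usage sketch of the option this commit adds. It is illustrative only and not part of the commit; the import path and model name are assumptions, and use_fast mirrors the new parameter shown in the Tokenizer.load call in farm/infer.py below.

from farm.modeling.tokenization import Tokenizer  # assumed FARM import path

# Load the Rust-backed ("fast") Hugging Face tokenizer instead of the pure-Python one.
tokenizer = Tokenizer.load("bert-base-uncased", use_fast=True)  # model name is a placeholder
print(tokenizer.is_fast)  # True; this flag is the switch used throughout the diff below
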
162 changes: 126 additions & 36 deletions farm/data_handler/input_features.py
@@ -4,6 +4,7 @@


import logging
import re
import collections
from dotmap import DotMap
import numpy as np
@@ -36,18 +37,34 @@ def sample_to_features_text(
:rtype: list
"""

#TODO It might be cleaner to adjust the data structure in sample.tokenized
# Verify if this current quickfix really works for pairs
tokens_a = sample.tokenized["tokens"]
tokens_b = sample.tokenized.get("tokens_b", None)

inputs = tokenizer.encode_plus(
tokens_a,
tokens_b,
add_special_tokens=True,
truncation_strategy='do_not_truncate',
return_token_type_ids=True
)
if tokenizer.is_fast:
text = sample.clear_text["text"]
# Here, we tokenize the sample for the second time to get all relevant ids
# This should change once we get rid of FARM's tokenize_with_metadata()
inputs = tokenizer(text,
return_token_type_ids=True,
max_length=max_seq_len,
return_special_tokens_mask=True)

if (len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)) != len(sample.tokenized["tokens"]):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(inputs['input_ids']) - inputs['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata(). \n"
f"Further processing is likely to be wrong.")
else:
# TODO It might be cleaner to adjust the data structure in sample.tokenized
tokens_a = sample.tokenized["tokens"]
tokens_b = sample.tokenized.get("tokens_b", None)

inputs = tokenizer.encode_plus(
tokens_a,
tokens_b,
add_special_tokens=True,
truncation=False, # truncation_strategy is deprecated
return_token_type_ids=True,
max_length=max_seq_len,
is_pretokenized=False,
)

input_ids, segment_ids = inputs["input_ids"], inputs["token_type_ids"]

@@ -136,13 +153,30 @@ def samples_to_features_ner(
"""

tokens = sample.tokenized["tokens"]
inputs = tokenizer.encode_plus(text=tokens,
text_pair=None,
add_special_tokens=True,
truncation_strategy='do_not_truncate', # We've already truncated our tokens before
return_special_tokens_mask=True,
return_token_type_ids=True
)

if tokenizer.is_fast:
text = sample.clear_text["text"]
# Here, we tokenize the sample for the second time to get all relevant ids
# This should change once we get rid of FARM's tokenize_with_metadata()
inputs = tokenizer(text,
return_token_type_ids=True,
max_length=max_seq_len,
return_special_tokens_mask=True)

if (len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)) != len(sample.tokenized["tokens"]):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(inputs['input_ids']) - inputs['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata().\n"
f"Further processing is likely to be wrong!")
else:
inputs = tokenizer.encode_plus(text=tokens,
text_pair=None,
add_special_tokens=True,
truncation=False,
return_special_tokens_mask=True,
return_token_type_ids=True,
is_pretokenized=False
)

input_ids, segment_ids, special_tokens_mask = inputs["input_ids"], inputs["token_type_ids"], inputs["special_tokens_mask"]

@@ -231,6 +265,14 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T

tokens_b, t2_label = mask_random_words(tokens_b, tokenizer.vocab,
token_groups=sample.tokenized["text_b"]["start_of_word"])

if tokenizer.is_fast:
# Detokenize input as fast tokenizer can't handle tokenized input
tokens_a = " ".join(tokens_a)
tokens_a = re.sub(r"(^|\s)(##)", "", tokens_a)
tokens_b = " ".join(tokens_b)
tokens_b = re.sub(r"(^|\s)(##)", "", tokens_b)

# convert lm labels to ids
t1_label_ids = [-1 if tok == '' else tokenizer.convert_tokens_to_ids(tok) for tok in t1_label]
t2_label_ids = [-1 if tok == '' else tokenizer.convert_tokens_to_ids(tok) for tok in t2_label]
@@ -246,18 +288,39 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T
tokens_b = None
tokens_a, t1_label = mask_random_words(tokens_a, tokenizer.vocab,
token_groups=sample.tokenized["text_a"]["start_of_word"])
if tokenizer.is_fast:
# Detokenize input as fast tokenizer can't handle tokenized input
tokens_a = " ".join(tokens_a)
tokens_a = re.sub(r"(^|\s)(##)", "", tokens_a)

# convert lm labels to ids
lm_label_ids = [-1 if tok == '' else tokenizer.convert_tokens_to_ids(tok) for tok in t1_label]

# encode string tokens to input_ids and add special tokens
inputs = tokenizer.encode_plus(text=tokens_a,
text_pair=tokens_b,
add_special_tokens=True,
truncation_strategy='do_not_truncate',
# We've already truncated our tokens before
return_special_tokens_mask=True,
return_token_type_ids=True
)
if tokenizer.is_fast:
inputs = tokenizer(text=tokens_a,
text_pair=tokens_b,
add_special_tokens=True,
return_special_tokens_mask=True,
return_token_type_ids=True)

seq_b_len = len(sample.tokenized["text_b"]["tokens"]) if "text_b" in sample.tokenized else 0
if (len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)) != \
(len(sample.tokenized["text_a"]["tokens"]) + seq_b_len):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(inputs['input_ids']) - inputs['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata(). \n"
f"Further processing is likely to be wrong.")
else:
# encode string tokens to input_ids and add special tokens
inputs = tokenizer.encode_plus(text=tokens_a,
text_pair=tokens_b,
add_special_tokens=True,
truncation=False,
truncation_strategy='do_not_truncate',
# We've already truncated our tokens before
return_special_tokens_mask=True,
return_token_type_ids=True
)

input_ids, segment_ids, special_tokens_mask = inputs["input_ids"], inputs["token_type_ids"], inputs[
"special_tokens_mask"]
@@ -358,12 +421,35 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks
# (question_len_t + passage_len_t + n_special_tokens). This may be less than max_seq_len but will not be greater
# than max_seq_len since truncation was already performed when the document was chunked into passages
# (c.f. create_samples_squad() )
encoded = tokenizer.encode_plus(text=sample.tokenized["question_tokens"],
text_pair=sample.tokenized["passage_tokens"],
add_special_tokens=True,
truncation_strategy='do_not_truncate',
return_token_type_ids=True,
return_tensors=None)

if tokenizer.is_fast:
# Detokenize input as fast tokenizer can't handle tokenized input
question_tokens = " ".join(question_tokens)
question_tokens = re.sub(r"(^|\s)(##)", "", question_tokens)
passage_tokens = " ".join(passage_tokens)
passage_tokens = re.sub(r"(^|\s)(##)", "", passage_tokens)

encoded = tokenizer(text=question_tokens,
text_pair=passage_tokens,
add_special_tokens=True,
return_special_tokens_mask=True,
return_token_type_ids=True)

if (len(encoded["input_ids"]) - encoded["special_tokens_mask"].count(1)) != \
(len(sample.tokenized["question_tokens"]) + len(sample.tokenized["passage_tokens"])):
logger.error(f"FastTokenizer encoded sample {sample.clear_text['text']} to "
f"{len(encoded['input_ids']) - encoded['special_tokens_mask'].count(1)} tokens, which differs "
f"from number of tokens produced in tokenize_with_metadata(). \n"
f"Further processing is likely to be wrong.")
else:
encoded = tokenizer.encode_plus(text=sample.tokenized["question_tokens"],
text_pair=sample.tokenized["passage_tokens"],
add_special_tokens=True,
truncation=False,
truncation_strategy='do_not_truncate',
return_token_type_ids=True,
return_tensors=None)

input_ids = encoded["input_ids"]
segment_ids = encoded["token_type_ids"]

@@ -467,8 +553,12 @@ def combine_vecs(question_vec, passage_vec, tokenizer, spec_tok_val=-1):
# Join question_label_vec and passage_label_vec and add slots for special tokens
vec = tokenizer.build_inputs_with_special_tokens(token_ids_0=question_vec,
token_ids_1=passage_vec)
spec_toks_mask = tokenizer.get_special_tokens_mask(token_ids_0=question_vec,
token_ids_1=passage_vec)
if tokenizer.is_fast:
spec_toks_mask = tokenizer.get_special_tokens_mask(token_ids_0=vec,
already_has_special_tokens=True)
else:
spec_toks_mask = tokenizer.get_special_tokens_mask(token_ids_0=question_vec,
token_ids_1=passage_vec)

# If a value in vec corresponds to a special token, it will be replaced with spec_tok_val
combined = [v if not special_token else spec_tok_val for v, special_token in zip(vec, spec_toks_mask)]
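
A small self-contained sketch of the two patterns used in input_features.py above: detokenizing WordPiece tokens before handing them to the fast tokenizer, and checking the non-special-token count against tokenize_with_metadata(). The token strings and ids are made up for illustration.

import re

# Detokenize WordPiece tokens: the regex drops "##" continuation markers together with the space before them.
tokens = ["Micro", "##soft", "head", "##quarters"]
text = re.sub(r"(^|\s)(##)", "", " ".join(tokens))
print(text)  # Microsoft headquarters

# Consistency check: the number of non-special tokens returned by the fast tokenizer
# must equal the token count produced earlier by tokenize_with_metadata().
inputs = {"input_ids": [101, 8003, 11363, 102], "special_tokens_mask": [1, 0, 0, 1]}
n_real = len(inputs["input_ids"]) - inputs["special_tokens_mask"].count(1)  # 2
expected = 2  # stands in for len(sample.tokenized["tokens"]) in the real code
assert n_real == expected, "Further processing is likely to be wrong."
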
6 changes: 5 additions & 1 deletion farm/data_handler/processor.py
@@ -231,7 +231,11 @@ def save(self, save_dir):
config = self.generate_config()
# save tokenizer incl. attributes
config["tokenizer"] = self.tokenizer.__class__.__name__
self.tokenizer.save_pretrained(save_dir)

# Because the fast tokenizers expect a str and not a Path,
# always convert Path to str here.
self.tokenizer.save_pretrained(str(save_dir))

# save processor
config["processor"] = self.__class__.__name__
output_config_file = Path(save_dir) / "processor_config.json"
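
The processor.py change boils down to one conversion when saving. A minimal sketch under the same assumptions as above (import path and directory name are placeholders):

from pathlib import Path
from farm.modeling.tokenization import Tokenizer  # assumed FARM import path

tokenizer = Tokenizer.load("bert-base-uncased", use_fast=True)
save_dir = Path("saved_models/my_processor")
tokenizer.save_pretrained(str(save_dir))  # fast tokenizers expect a str, not a Path
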
20 changes: 19 additions & 1 deletion farm/infer.py
@@ -167,6 +167,9 @@ def load(
s3e_stats=None,
num_processes=None,
disable_tqdm=False,
tokenizer_class=None,
use_fast=False,
tokenizer_args=None,
dummy_ph=False,
benchmarking=False,

@@ -212,6 +215,15 @@ def load(
:type num_processes: int
:param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
:type disable_tqdm: bool
:param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
:type tokenizer_class: str
:param use_fast: (Optional, False by default) Indicates whether FARM should try to load the fast version of the tokenizer (True) or
use the Python one (False).
:type use_fast: bool
:param tokenizer_args: (Optional) Will be passed to the Tokenizer ``__init__`` method.
See https://huggingface.co/transformers/main_classes/tokenizer.html and detailed tokenizer documentation
on `Hugging Face Transformers <https://huggingface.co/transformers/>`_.
:type tokenizer_args: dict
:param dummy_ph: If True, methods of the prediction head will be replaced
with a dummy method. This is used to isolate lm run time from ph run time.
:type dummy_ph: bool
@@ -223,6 +235,8 @@
:return: An instance of the Inferencer.
"""
if tokenizer_args is None:
tokenizer_args = {}

device, n_gpu = initialize_device_settings(use_cuda=gpu, local_rank=-1, use_amp=None)
name = os.path.basename(model_name_or_path)
@@ -250,7 +264,11 @@

model = AdaptiveModel.convert_from_transformers(model_name_or_path, device, task_type)
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = Tokenizer.load(model_name_or_path)
tokenizer = Tokenizer.load(model_name_or_path,
tokenizer_class=tokenizer_class,
use_fast=use_fast,
**tokenizer_args,
)

# TODO infer task_type automatically from config (if possible)
if task_type == "question_answering":
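
For the inference entry point, the new arguments are simply forwarded to Tokenizer.load. A hedged usage sketch follows; the model name and the tokenizer argument are placeholders chosen for illustration, while use_fast, tokenizer_args, task_type and num_processes correspond to parameters visible in the diff and commit message above.

from farm.infer import Inferencer

inferencer = Inferencer.load(
    "deepset/bert-base-cased-squad2",          # placeholder model name
    task_type="question_answering",
    use_fast=True,                             # opt in to the fast HF tokenizer
    tokenizer_args={"strip_accents": False},   # forwarded to the tokenizer __init__
    num_processes=0,                           # as in the CI memory fix noted in the commit message
)
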
