GH-1494: fix handling of unknown tokens and RoBERTa offsets
alanakbik committed Apr 15, 2020
1 parent 323d60b commit 3b78789
Showing 1 changed file with 38 additions and 7 deletions.
flair/embeddings.py (45 changes: 38 additions & 7 deletions)
@@ -2231,7 +2231,7 @@ def __init__(
         self.model = AutoModel.from_pretrained(model, config=config)

         # model name
-        self.name = str(model)
+        self.name = 'transformer-word-' + str(model)

         # when initializing, embeddings are in eval mode by default
         self.model.eval()
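The new prefix presumably distinguishes these embeddings from other embedding classes wrapping the same model name; in flair, each Token stores its embeddings keyed by this name string. A rough sketch of the effect, assuming the class is flair's TransformerWordEmbeddings (as the prefix suggests) and using bert-base-uncased only as an example:

from flair.embeddings import TransformerWordEmbeddings

embedding = TransformerWordEmbeddings('bert-base-uncased')

# with this commit, the name carries the class-specific prefix
print(embedding.name)   # 'transformer-word-bert-base-uncased'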
@@ -2251,11 +2251,11 @@ def __init__(

         # most models have an intial BOS token, except for XLNet, T5 and GPT2
         self.begin_offset = 1
-        if isinstance(self.tokenizer, XLNetTokenizer):
+        if type(self.tokenizer) == XLNetTokenizer:
             self.begin_offset = 0
-        if isinstance(self.tokenizer, T5Tokenizer):
+        if type(self.tokenizer) == T5Tokenizer:
             self.begin_offset = 0
-        if isinstance(self.tokenizer, GPT2Tokenizer):
+        if type(self.tokenizer) == GPT2Tokenizer:
             self.begin_offset = 0

     def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
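The "RoBERTa offsets" part of the fix appears to hinge on this switch from isinstance() to an exact type() comparison: in the transformers releases of that time (~2.x), RobertaTokenizer subclasses GPT2Tokenizer, so the isinstance() check also matched RoBERTa and zeroed begin_offset even though RoBERTa does prepend a BOS token (<s>). A minimal sketch of the difference, with roberta-base used only as an example:

from transformers import GPT2Tokenizer, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# the subclass relationship makes the old check match RoBERTa as well ...
print(isinstance(tokenizer, GPT2Tokenizer))   # True
# ... while the exact type comparison introduced here does not
print(type(tokenizer) == GPT2Tokenizer)       # False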
@@ -2280,13 +2280,22 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]):

         for sentence in sentences:

-            # subtokenize sentence
-            subtokenized_sentence = self.tokenizer.encode(sentence.to_tokenized_string(), add_special_tokens=True)
+            tokenized_string = sentence.to_tokenized_string()
+
+            # method 1: subtokenize sentence
+            subtokenized_sentence = self.tokenizer.encode(tokenized_string, add_special_tokens=True)
+
+            # method 2:
+            # ids = self.tokenizer.encode(tokenized_string, add_special_tokens=False)
+            # subtokenized_sentence = self.tokenizer.build_inputs_with_special_tokens(ids)
+
             subtokenized_sentences.append(torch.tensor(subtokenized_sentence, dtype=torch.long))
             subtokens = self.tokenizer.convert_ids_to_tokens(subtokenized_sentence)
+            # print(subtokens)

             word_iterator = iter(sentence)
             token = next(word_iterator)
+            token_text = token.text.lower()

             token_subtoken_lengths = []
             reconstructed_token = ''
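For reference, the commented-out "method 2" splits subtokenization into two steps via build_inputs_with_special_tokens. A small sketch of the equivalence it relies on, assuming a transformers ~2.x tokenizer and using bert-base-uncased only as an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
text = 'I love Berlin'

# method 1: let encode() add the special tokens directly
with_special = tokenizer.encode(text, add_special_tokens=True)

# method 2: encode without special tokens, then wrap them around the ids
ids = tokenizer.encode(text, add_special_tokens=False)
rebuilt = tokenizer.build_inputs_with_special_tokens(ids)

assert with_special == rebuilt   # both yield [CLS] ... [SEP] for BERT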
@@ -2306,13 +2315,34 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
                 # append subtoken to reconstruct token
                 reconstructed_token = reconstructed_token + subtoken

+                # print(reconstructed_token)
+
                 # check if reconstructed token is special begin token ([CLS] or similar)
                 if reconstructed_token in self.special_tokens and subtoken_id == 0:
                     reconstructed_token = ''
                     subtoken_count = 0

+                # special handling for UNK subtokens
+                if self.tokenizer.unk_token and self.tokenizer.unk_token in reconstructed_token:
+                    pieces = self.tokenizer.convert_ids_to_tokens(
+                        self.tokenizer.encode(token.text, add_special_tokens=False))
+                    token_text = ''
+                    for piece in pieces:
+                        # remove special markup
+                        piece = re.sub('^Ġ', '', piece)  # RoBERTa models
+                        piece = re.sub('^##', '', piece)  # BERT models
+                        piece = re.sub('^▁', '', piece)  # XLNet models
+                        piece = re.sub('</w>$', '', piece)  # XLM models
+                        token_text += piece
+                    token_text = token_text.lower()
+
                 # check if reconstructed token is the same as current token
-                if reconstructed_token.lower() == token.text.lower():
+                if reconstructed_token.lower() == token_text:

+                    # print(token)
+                    # print(reconstructed_token)
+                    # print(subtoken_count)
+                    # print()
+
                     # if so, add subtoken count
                     token_subtoken_lengths.append(subtoken_count)
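The new UNK branch covers tokens whose characters the subword vocabulary cannot represent: their subtokens come back as the unk_token, so the reconstructed string can never equal the original token text and the alignment loop would drift. A sketch of the failure case being guarded against, assuming a transformers ~2.x BERT tokenizer and an emoji as an illustrative out-of-vocabulary token:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

token_text = '🤔'                        # assumed to be out of vocabulary
pieces = tokenizer.tokenize(token_text)  # ['[UNK]']

# joining the subtokens gives '[UNK]', which never matches the token text;
# comparing against the re-tokenized, markup-stripped pieces instead keeps
# the token/subtoken alignment in sync
print(''.join(pieces))                   # '[UNK]'
print(''.join(pieces) == token_text)     # False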
@@ -2324,6 +2354,7 @@ def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
                     # break from loop if all tokens are accounted for
                     if len(token_subtoken_lengths) < len(sentence):
                         token = next(word_iterator)
+                        token_text = token.text.lower()
                     else:
                         break

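Taken together, the changes keep token_text (lower-cased and UNK-aware) in step with the matcher as it advances through the sentence. A rough end-to-end usage sketch of the patched embeddings; the class name, model and sentence are assumptions, not part of the commit:

from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

embeddings = TransformerWordEmbeddings('bert-base-cased')

sentence = Sentence('I love Berlin 🤔')   # the emoji exercises the UNK path
embeddings.embed(sentence)

for token in sentence:
    print(token.text, token.embedding.shape)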
