Commit

GH-1492: added new BERT embeddings implementation

Kishaloy Halder committed Mar 26, 2020
1 parent f72eb96 commit 5a081a2
Showing 1 changed file with 75 additions and 159 deletions.
234 changes: 75 additions & 159 deletions flair/embeddings.py
@@ -2198,13 +2198,15 @@ def __setstate__(self, d):
self.word_embeddings[key] = self.word_embeddings[key].cpu()


class BertEmbeddings(TokenEmbeddings):
class BertEmbeddingsNew(DocumentEmbeddings, TokenEmbeddings):
def __init__(
self,
bert_model_or_path: str = "bert-base-uncased",
layers: str = "-1,-2,-3,-4",
pooling_operation: str = "first",
use_scalar_mix: bool = False,
fine_tune: bool = False,
document_only = False
):
"""
Bidirectional transformer embeddings of words, as proposed in Devlin et al., 2018.
@@ -2213,199 +2215,113 @@ def __init__(
:param layers: string indicating which layers to take for embedding
:param pooling_operation: how to get from token piece embeddings to token embedding. Either pool them and take
the average ('mean') or use the first word piece embedding as the token embedding ('first')
:param document_only: if True, set only the document (sentence) embedding
"""
super().__init__()

if "distilbert" in bert_model_or_path:
try:
from transformers import DistilBertTokenizer, DistilBertModel
except ImportError:
log.warning("-" * 100)
log.warning(
"ATTENTION! To use DistilBert, please first install a recent version of transformers!"
)
log.warning("-" * 100)
pass

self.tokenizer = DistilBertTokenizer.from_pretrained(bert_model_or_path)
self.model = DistilBertModel.from_pretrained(
pretrained_model_name_or_path=bert_model_or_path,
output_hidden_states=True,
)
self.model = DistilBertModel.from_pretrained(bert_model_or_path)
elif "albert" in bert_model_or_path:
self.tokenizer = AlbertTokenizer.from_pretrained(bert_model_or_path)
self.model = AlbertModel.from_pretrained(
pretrained_model_name_or_path=bert_model_or_path,
output_hidden_states=True,
)
self.model = AlbertModel.from_pretrained(bert_model_or_path)

Inline comment from @stefan-it (Member), Mar 26, 2020:

DistilBERTModel and AlbertModel are missing the output_hidden_states parameters :)
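
A minimal sketch of the fix this comment points to (assuming the transformers 2.x API, where from_pretrained forwards output_hidden_states to the model config; illustrative only, not part of this commit):

from transformers import AlbertModel, DistilBertModel

# request the hidden states of all encoder layers, mirroring the BertModel branch below
distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased", output_hidden_states=True)
albert = AlbertModel.from_pretrained("albert-base-v2", output_hidden_states=True)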

else:
self.tokenizer = BertTokenizer.from_pretrained(bert_model_or_path)
self.model = BertModel.from_pretrained(
pretrained_model_name_or_path=bert_model_or_path,
output_hidden_states=True,
)
self.layer_indexes = [int(x) for x in layers.split(",")]
self.pooling_operation = pooling_operation
self.use_scalar_mix = use_scalar_mix
self.name = str(bert_model_or_path)
self.static_embeddings = True
self.model = BertModel.from_pretrained(bert_model_or_path,
output_hidden_states=True)

class BertInputFeatures(object):
"""Private helper class for holding BERT-formatted features"""

def __init__(
self,
unique_id,
tokens,
input_ids,
input_mask,
input_type_ids,
token_subtoken_count,
):
self.unique_id = unique_id
self.tokens = tokens
self.input_ids = input_ids
self.input_mask = input_mask
self.input_type_ids = input_type_ids
self.token_subtoken_count = token_subtoken_count
# when initializing, embeddings are in eval mode by default
self.model.eval()

def _convert_sentences_to_features(
self, sentences, max_sequence_length: int
) -> [BertInputFeatures]:
self.layer_indexes = [int(x) for x in layers.split(",")]

max_sequence_length = max_sequence_length + 2
self.pooling_operation = pooling_operation

features: List[BertEmbeddings.BertInputFeatures] = []
for (sentence_index, sentence) in enumerate(sentences):
self.use_scalar_mix = use_scalar_mix

bert_tokenization: List[str] = []
token_subtoken_count: Dict[int, int] = {}
self.name = str(bert_model_or_path)

for token in sentence:
subtokens = self.tokenizer.tokenize(token.text)
bert_tokenization.extend(subtokens)
token_subtoken_count[token.idx] = len(subtokens)

if len(bert_tokenization) > max_sequence_length - 2:
bert_tokenization = bert_tokenization[0 : (max_sequence_length - 2)]

tokens = []
input_type_ids = []
tokens.append("[CLS]")
input_type_ids.append(0)
for token in bert_tokenization:
tokens.append(token)
input_type_ids.append(0)
tokens.append("[SEP]")
input_type_ids.append(0)

input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)

# Zero-pad up to the sequence length.
while len(input_ids) < max_sequence_length:
input_ids.append(0)
input_mask.append(0)
input_type_ids.append(0)

features.append(
BertEmbeddings.BertInputFeatures(
unique_id=sentence_index,
tokens=tokens,
input_ids=input_ids,
input_mask=input_mask,
input_type_ids=input_type_ids,
token_subtoken_count=token_subtoken_count,
)
)
self.fine_tune = fine_tune
self.static_embeddings = not self.fine_tune

return features
self.document_only = document_only

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
"""Add embeddings to all words in a list of sentences. If embeddings are already added,
updates only if embeddings are non-static."""

# prepare id maps for BERT model
tokenized = []
for sentence in sentences:
tokenized.append(torch.tensor(self.tokenizer.encode(sentence.to_tokenized_string(), add_special_tokens=True), dtype=torch.long))

# first, find longest sentence in batch
longest_sentence_in_batch: int = len(
max(
[
self.tokenizer.tokenize(sentence.to_tokenized_string())
for sentence in sentences
],
key=len,
)
)
longest_sequence_in_batch: int = len(max(tokenized, key=len))

# prepare id maps for BERT model
features = self._convert_sentences_to_features(
sentences, longest_sentence_in_batch
)
all_input_ids = torch.LongTensor([f.input_ids for f in features]).to(
flair.device
# initialize batch tensors
input_ids = torch.zeros(
[len(sentences), longest_sequence_in_batch],
dtype=torch.long,
device=flair.device,
)
all_input_masks = torch.LongTensor([f.input_mask for f in features]).to(
flair.device
mask = torch.zeros(
[len(sentences), longest_sequence_in_batch],
dtype=torch.long,
device=flair.device,
)

# put encoded batch through BERT model to get all hidden states of all encoder layers
self.model.to(flair.device)
self.model.eval()
all_encoder_layers = self.model(all_input_ids, attention_mask=all_input_masks)[
-1
]
for s_id, sentence in enumerate(tokenized):
sequence_length = len(sentence)
input_ids[s_id][:sequence_length] = sentence
mask[s_id][:sequence_length] = torch.ones(sequence_length)


with torch.no_grad():
# put encoded batch through BERT model to get all hidden states of all encoder layers
all_encoder_layers = self.model(input_ids, attention_mask=mask)[-1]

for sentence_index, sentence in enumerate(sentences):
# gradients are enabled only if fine-tuning is enabled and the embeddings are in training mode
gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()

feature = features[sentence_index]
with gradient_context:
for sentence_idx, (subtoken_ids, sentence) in enumerate(zip(tokenized, sentences)):
subtokens = self.tokenizer.convert_ids_to_tokens(subtoken_ids)

token_subtoken_embeddings = []
token_idx = 0
for subtoken_idx, subtoken in enumerate(subtokens):
if subtoken == '[SEP]':
continue

# get aggregated embeddings for each BERT-subtoken in sentence
subtoken_embeddings = []
for token_index, _ in enumerate(feature.tokens):
all_layers = []
for layer_index in self.layer_indexes:
if self.use_scalar_mix:
layer_output = all_encoder_layers[int(layer_index)][
sentence_index
]
# check if this subtoken starts a new word; internal subtokens are prefixed with ##
if not subtoken.startswith('##') and len(token_subtoken_embeddings) > 0:
if token_idx == 0:
sentence.set_embedding(self.name, token_subtoken_embeddings[0])
else:
layer_output = all_encoder_layers[int(layer_index)][
sentence_index
]
all_layers.append(layer_output[token_index])
token: Token = sentence[token_idx - 1]
if self.pooling_operation == "first":
token_embedding = token_subtoken_embeddings[0]
else:
token_embedding = torch.mean(torch.stack(token_subtoken_embeddings, dim=0), dim=0)
token.set_embedding(self.name, token_embedding)
# clear subtoken embedding and go to next token
token_subtoken_embeddings = []
token_idx += 1

if token_idx >= 1 and self.document_only:
break

# get embedding of subtoken
all_layers = [all_encoder_layers[int(layer_index)][sentence_idx, subtoken_idx]
for layer_index in self.layer_indexes]

if self.use_scalar_mix:
sm = ScalarMix(mixture_size=len(all_layers))
sm_embeddings = sm(all_layers)
all_layers = [sm_embeddings]

subtoken_embeddings.append(torch.cat(all_layers))

# get the current sentence object
token_idx = 0
for token in sentence:
# add concatenated embedding to sentence
token_idx += 1

if self.pooling_operation == "first":
# use first subword embedding if pooling operation is 'first'
token.set_embedding(self.name, subtoken_embeddings[token_idx])
embedding = sm(all_layers)
else:
# otherwise, do a mean over all subwords in token
embeddings = subtoken_embeddings[
token_idx : token_idx
+ feature.token_subtoken_count[token.idx]
]
embeddings = [
embedding.unsqueeze(0) for embedding in embeddings
]
mean = torch.mean(torch.cat(embeddings, dim=0), dim=0)
token.set_embedding(self.name, mean)

token_idx += feature.token_subtoken_count[token.idx] - 1
embedding = torch.cat(all_layers, dim=0)

# add to list of embeddings
token_subtoken_embeddings.append(embedding)

return sentences
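
For context, a minimal usage sketch of the class added in this commit (assuming it is importable from flair.embeddings under the name used in this diff):

from flair.data import Sentence
from flair.embeddings import BertEmbeddingsNew

# per-token embeddings are concatenations of the selected layers
embeddings = BertEmbeddingsNew("bert-base-uncased", layers="-1,-2,-3,-4", pooling_operation="first")

sentence = Sentence("Berlin is a beautiful city .")
embeddings.embed(sentence)

for token in sentence:
    print(token.text, token.embedding.shape)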

