diff --git a/flair/embeddings.py b/flair/embeddings.py
index bc5cec02b6..fc96c9a75d 100644
--- a/flair/embeddings.py
+++ b/flair/embeddings.py
@@ -2198,13 +2198,15 @@ def __setstate__(self, d):
             self.word_embeddings[key] = self.word_embeddings[key].cpu()


-class BertEmbeddings(TokenEmbeddings):
+class BertEmbeddingsNew(DocumentEmbeddings, TokenEmbeddings):
     def __init__(
         self,
         bert_model_or_path: str = "bert-base-uncased",
         layers: str = "-1,-2,-3,-4",
         pooling_operation: str = "first",
         use_scalar_mix: bool = False,
+        fine_tune: bool = False,
+        document_only: bool = False,
     ):
         """
         Bidirectional transformer embeddings of words, as proposed in Devlin et al., 2018.
@@ -2213,199 +2215,113 @@ def __init__(
         :param layers: string indicating which layers to take for embedding
         :param pooling_operation: how to get from token piece embeddings to token embedding. Either pool them and take
         the average ('mean') or use first word piece embedding as token embedding ('first)
+        :param fine_tune: if True, gradients are computed for the transformer so that it can be fine-tuned
+        :param document_only: if True, only set the document (sentence) embedding
         """
         super().__init__()

         if "distilbert" in bert_model_or_path:
-            try:
-                from transformers import DistilBertTokenizer, DistilBertModel
-            except ImportError:
-                log.warning("-" * 100)
-                log.warning(
-                    "ATTENTION! To use DistilBert, please first install a recent version of transformers!"
-                )
-                log.warning("-" * 100)
-                pass
             self.tokenizer = DistilBertTokenizer.from_pretrained(bert_model_or_path)
-            self.model = DistilBertModel.from_pretrained(
-                pretrained_model_name_or_path=bert_model_or_path,
-                output_hidden_states=True,
-            )
+            self.model = DistilBertModel.from_pretrained(bert_model_or_path,
+                                                         output_hidden_states=True)
         elif "albert" in bert_model_or_path:
             self.tokenizer = AlbertTokenizer.from_pretrained(bert_model_or_path)
-            self.model = AlbertModel.from_pretrained(
-                pretrained_model_name_or_path=bert_model_or_path,
-                output_hidden_states=True,
-            )
+            self.model = AlbertModel.from_pretrained(bert_model_or_path,
+                                                     output_hidden_states=True)
         else:
             self.tokenizer = BertTokenizer.from_pretrained(bert_model_or_path)
-            self.model = BertModel.from_pretrained(
-                pretrained_model_name_or_path=bert_model_or_path,
-                output_hidden_states=True,
-            )
-        self.layer_indexes = [int(x) for x in layers.split(",")]
-        self.pooling_operation = pooling_operation
-        self.use_scalar_mix = use_scalar_mix
-        self.name = str(bert_model_or_path)
-        self.static_embeddings = True
-
-    class BertInputFeatures(object):
-        """Private helper class for holding BERT-formatted features"""
-
-        def __init__(
-            self,
-            unique_id,
-            tokens,
-            input_ids,
-            input_mask,
-            input_type_ids,
-            token_subtoken_count,
-        ):
-            self.unique_id = unique_id
-            self.tokens = tokens
-            self.input_ids = input_ids
-            self.input_mask = input_mask
-            self.input_type_ids = input_type_ids
-            self.token_subtoken_count = token_subtoken_count
-
-    def _convert_sentences_to_features(
-        self, sentences, max_sequence_length: int
-    ) -> [BertInputFeatures]:
-
-        max_sequence_length = max_sequence_length + 2
-
-        features: List[BertEmbeddings.BertInputFeatures] = []
-        for (sentence_index, sentence) in enumerate(sentences):
-
-            bert_tokenization: List[str] = []
-            token_subtoken_count: Dict[int, int] = {}
-
-            for token in sentence:
-                subtokens = self.tokenizer.tokenize(token.text)
-                bert_tokenization.extend(subtokens)
-                token_subtoken_count[token.idx] = len(subtokens)
-
-            if len(bert_tokenization) > max_sequence_length - 2:
-                bert_tokenization = bert_tokenization[0 : (max_sequence_length - 2)]
-
-            tokens = []
-            input_type_ids = []
-            tokens.append("[CLS]")
-            input_type_ids.append(0)
-            for token in bert_tokenization:
-                tokens.append(token)
-                input_type_ids.append(0)
-            tokens.append("[SEP]")
-            input_type_ids.append(0)
-
-            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
-            # The mask has 1 for real tokens and 0 for padding tokens. Only real
-            # tokens are attended to.
-            input_mask = [1] * len(input_ids)
-
-            # Zero-pad up to the sequence length.
-            while len(input_ids) < max_sequence_length:
-                input_ids.append(0)
-                input_mask.append(0)
-                input_type_ids.append(0)
-
-            features.append(
-                BertEmbeddings.BertInputFeatures(
-                    unique_id=sentence_index,
-                    tokens=tokens,
-                    input_ids=input_ids,
-                    input_mask=input_mask,
-                    input_type_ids=input_type_ids,
-                    token_subtoken_count=token_subtoken_count,
-                )
-            )
-
-        return features
+            self.model = BertModel.from_pretrained(bert_model_or_path,
+                                                   output_hidden_states=True)
+
+        # when initializing, embeddings are in eval mode by default
+        self.model.eval()
+
+        self.layer_indexes = [int(x) for x in layers.split(",")]
+
+        self.pooling_operation = pooling_operation
+
+        self.use_scalar_mix = use_scalar_mix
+
+        self.name = str(bert_model_or_path)
+
+        self.fine_tune = fine_tune
+        self.static_embeddings = not self.fine_tune
+
+        self.document_only = document_only

     def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
         """Add embeddings to all words in a list of sentences. If embeddings are already added,
         updates only if embeddings are non-static."""

+        # prepare id maps for BERT model
+        tokenized = []
+        for sentence in sentences:
+            tokenized.append(torch.tensor(self.tokenizer.encode(sentence.to_tokenized_string(), add_special_tokens=True), dtype=torch.long))
+
         # first, find longest sentence in batch
-        longest_sentence_in_batch: int = len(
-            max(
-                [
-                    self.tokenizer.tokenize(sentence.to_tokenized_string())
-                    for sentence in sentences
-                ],
-                key=len,
-            )
-        )
+        longest_sequence_in_batch: int = len(max(tokenized, key=len))

-        # prepare id maps for BERT model
-        features = self._convert_sentences_to_features(
-            sentences, longest_sentence_in_batch
-        )
-        all_input_ids = torch.LongTensor([f.input_ids for f in features]).to(
-            flair.device
+        # initialize batch tensors
+        input_ids = torch.zeros(
+            [len(sentences), longest_sequence_in_batch],
+            dtype=torch.long,
+            device=flair.device,
         )
-        all_input_masks = torch.LongTensor([f.input_mask for f in features]).to(
-            flair.device
+        mask = torch.zeros(
+            [len(sentences), longest_sequence_in_batch],
+            dtype=torch.long,
+            device=flair.device,
         )

-        # put encoded batch through BERT model to get all hidden states of all encoder layers
-        self.model.to(flair.device)
-        self.model.eval()
-        all_encoder_layers = self.model(all_input_ids, attention_mask=all_input_masks)[
-            -1
-        ]
+        for s_id, sentence in enumerate(tokenized):
+            sequence_length = len(sentence)
+            input_ids[s_id][:sequence_length] = sentence
+            mask[s_id][:sequence_length] = torch.ones(sequence_length)

-        with torch.no_grad():
+        # put encoded batch through BERT model to get all hidden states of all encoder layers
+        all_encoder_layers = self.model(input_ids, attention_mask=mask)[-1]

-            for sentence_index, sentence in enumerate(sentences):
+        # gradients are enabled if fine-tuning is enabled
+        gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()

-                feature = features[sentence_index]
+        with gradient_context:
+            for sentence_idx, (subtoken_ids, sentence) in enumerate(zip(tokenized, sentences)):
+                subtokens = self.tokenizer.convert_ids_to_tokens(subtoken_ids)

-                # get aggregated embeddings for each BERT-subtoken in sentence
-                subtoken_embeddings = []
-                for token_index, _ in enumerate(feature.tokens):
-                    all_layers = []
-                    for layer_index in self.layer_indexes:
-                        if self.use_scalar_mix:
-                            layer_output = all_encoder_layers[int(layer_index)][
-                                sentence_index
-                            ]
-                        else:
-                            layer_output = all_encoder_layers[int(layer_index)][
-                                sentence_index
-                            ]
-                        all_layers.append(layer_output[token_index])
-
-                    if self.use_scalar_mix:
-                        sm = ScalarMix(mixture_size=len(all_layers))
-                        sm_embeddings = sm(all_layers)
-                        all_layers = [sm_embeddings]
-
-                    subtoken_embeddings.append(torch.cat(all_layers))
-
-                # get the current sentence object
-                token_idx = 0
-                for token in sentence:
-                    # add concatenated embedding to sentence
-                    token_idx += 1
-
-                    if self.pooling_operation == "first":
-                        # use first subword embedding if pooling operation is 'first'
-                        token.set_embedding(self.name, subtoken_embeddings[token_idx])
-                    else:
-                        # otherwise, do a mean over all subwords in token
-                        embeddings = subtoken_embeddings[
-                            token_idx : token_idx
-                            + feature.token_subtoken_count[token.idx]
-                        ]
-                        embeddings = [
-                            embedding.unsqueeze(0) for embedding in embeddings
-                        ]
-                        mean = torch.mean(torch.cat(embeddings, dim=0), dim=0)
-                        token.set_embedding(self.name, mean)
-
-                    token_idx += feature.token_subtoken_count[token.idx] - 1
+                token_subtoken_embeddings = []
+                token_idx = 0
+                for subtoken_idx, subtoken in enumerate(subtokens):
+
+                    # check if this subtoken initiates a new word; word-internal subtokens start with ##
+                    # (the trailing [SEP] also triggers this, so the last token is flushed as well)
+                    if not subtoken.startswith('##') and len(token_subtoken_embeddings) > 0:
+                        if token_idx == 0:
+                            # embeddings gathered so far belong to [CLS]: use as document embedding
+                            sentence.set_embedding(self.name, token_subtoken_embeddings[0])
+                        else:
+                            token: Token = sentence[token_idx - 1]
+                            if self.pooling_operation == "first":
+                                token_embedding = token_subtoken_embeddings[0]
+                            else:
+                                token_embedding = torch.mean(torch.stack(token_subtoken_embeddings, dim=0), dim=0)
+                            token.set_embedding(self.name, token_embedding)
+                        # clear subtoken embeddings and go to next token
+                        token_subtoken_embeddings = []
+                        token_idx += 1
+
+                    if token_idx >= 1 and self.document_only:
+                        break
+
+                    # the final [SEP] marker itself gets no embedding of its own
+                    if subtoken == '[SEP]':
+                        continue
+
+                    # get embedding of subtoken
+                    all_layers = [all_encoder_layers[int(layer_index)][sentence_idx, subtoken_idx]
+                                  for layer_index in self.layer_indexes]
+                    if self.use_scalar_mix:
+                        sm = ScalarMix(mixture_size=len(all_layers))
+                        embedding = sm(all_layers)
+                    else:
+                        embedding = torch.cat(all_layers, dim=0)
+
+                    # add to list of embeddings
+                    token_subtoken_embeddings.append(embedding)

         return sentences
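
For reviewers, a minimal usage sketch of the refactored class. It assumes BertEmbeddingsNew is exported from flair.embeddings and that flair's usual Sentence/embed API is unchanged; the model name, the example sentence, and the explicit .to(flair.device) call (needed because the new code no longer moves the model inside _add_embeddings_internal) are illustrative, not part of the diff.

    import flair
    from flair.data import Sentence
    from flair.embeddings import BertEmbeddingsNew  # assumes the class from this diff is importable

    # token + document embeddings from the last four BERT layers (the defaults)
    embedding = BertEmbeddingsNew("bert-base-uncased", pooling_operation="first", fine_tune=False)
    embedding.to(flair.device)  # the refactored class no longer moves the model to flair.device itself

    sentence = Sentence("Berlin is a city .")
    embedding.embed(sentence)

    # document embedding, taken from the [CLS] subtoken and concatenated over the selected layers
    print(sentence.get_embedding().size())

    # per-token embeddings (first subtoken of each word with pooling_operation="first")
    for token in sentence:
        print(token.text, token.get_embedding().size())

With fine_tune=True the embeddings are no longer static, so they are recomputed on every forward pass and gradients flow back into BERT while the surrounding model is training; with document_only=True only the sentence-level embedding is set.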