diff --git a/flair/data.py b/flair/data.py
index 6e8b704b9b..952350bb12 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -1,10 +1,5 @@
 from typing import List, Dict
 import torch
-import random
-from random import randint
-import os
-from os import listdir
-from os.path import isfile, join
 from collections import Counter
 from collections import defaultdict

diff --git a/flair/embeddings.py b/flair/embeddings.py
index 7595c4fafe..812bcfe24a 100644
--- a/flair/embeddings.py
+++ b/flair/embeddings.py
@@ -2,7 +2,7 @@
 import pickle
 import re
 from abc import abstractmethod
-from typing import List
+from typing import List, Union

 import gensim
 import numpy as np
@@ -26,7 +26,7 @@ def embedding_length(self) -> int:
     def embedding_type(self) -> str:
         pass

-    def embed(self, sentences: List[Sentence]) -> List[Sentence]:
+    def embed(self, sentences: Union[Sentence, List[Sentence]]) -> List[Sentence]:
         """Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings
         are non-static."""
@@ -208,7 +208,11 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                 else:
                     word_embedding = np.zeros(self.embedding_length, dtype='float')

-                word_embedding = torch.autograd.Variable(torch.FloatTensor(word_embedding))
+                # if torch.cuda.is_available():
+                #     word_embedding = torch.cuda.FloatTensor(word_embedding)
+                # else:
+                word_embedding = torch.FloatTensor(word_embedding)
+
                 token.set_embedding(self.name, word_embedding)

         return sentences
@@ -224,20 +228,11 @@ def __init__(self, path_to_char_dict: str = None):
         self.name = 'Char'
         self.static_embeddings = False

-        # get list of common characters if none provided
+        # use list of common characters if none provided
         if path_to_char_dict is None:
-            base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/common_characters'
-            char_dict = cached_path(base_path, cache_dir='datasets')
-
-            # load dictionary
-            self.char_dictionary: Dictionary = Dictionary()
-            with open(char_dict, 'rb') as f:
-                mappings = pickle.load(f, encoding='latin1')
-                idx2item = mappings['idx2item']
-                item2idx = mappings['item2idx']
-                self.char_dictionary.item2idx = item2idx
-                self.char_dictionary.idx2item = idx2item
-            # print(self.char_dictionary.item2idx)
+            self.char_dictionary: Dictionary = Dictionary.load('common-chars')
+        else:
+            self.char_dictionary: Dictionary = Dictionary.load_from_file(path_to_char_dict)

         self.char_embedding_dim: int = 25
         self.hidden_size_char: int = 25
@@ -260,7 +255,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
             # translate words in sentence into ints using dictionary
             for token in sentence.tokens:
                 token: Token = token
-                # print(token)
                 char_indices = [self.char_dictionary.get_idx_for_item(char) for char in token.text]
                 tokens_char_indices.append(char_indices)

@@ -278,7 +272,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
             for i, c in enumerate(tokens_sorted_by_length):
                 tokens_mask[i, :chars2_length[i]] = c

-            tokens_mask = torch.autograd.Variable(torch.LongTensor(tokens_mask))
+            tokens_mask = torch.LongTensor(tokens_mask)

             # chars for rnn processing
             chars = tokens_mask
@@ -293,8 +287,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
             outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
             outputs = outputs.transpose(0, 1)
-            chars_embeds_temp = torch.autograd.Variable(
-                torch.FloatTensor(torch.zeros((outputs.size(0), outputs.size(2)))))
+            chars_embeds_temp = torch.FloatTensor(torch.zeros((outputs.size(0), outputs.size(2))))
             if torch.cuda.is_available():
                 chars_embeds_temp = chars_embeds_temp.cuda()
             for i, index in enumerate(output_lengths):
@@ -304,7 +297,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
                 character_embeddings[d[i]] = chars_embeds_temp[i]

             for token_number, token in enumerate(sentence.tokens):
-                token.set_embedding(self.name, character_embeddings[token_number].cpu())
+                token.set_embedding(self.name, character_embeddings[token_number])


 class CharLMEmbeddings(TokenEmbeddings):
@@ -359,12 +352,8 @@ def __init__(self, model, detach: bool = True):
         self.name = model
         self.static_embeddings = detach

-        import flair.models
-        self.lm: flair.models.LanguageModel = flair.models.LanguageModel.load_language_model(model)
-        if torch.cuda.is_available():
-            self.lm = self.lm.cuda()
-        self.lm.eval()
-
+        from flair.models import LanguageModel
+        self.lm = LanguageModel.load_language_model(model)
         self.detach = detach

         self.is_forward_lm: bool = self.lm.is_forward_lm
@@ -378,7 +367,7 @@ def __init__(self, model, detach: bool = True):
         dummy_sentence: Sentence = Sentence()
         dummy_sentence.add_token(Token('hello'))
-        embedded_dummy = self.embed([dummy_sentence])
+        embedded_dummy = self.embed(dummy_sentence)
         self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())

     @property
@@ -406,8 +395,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                     '\n' + sentence.to_plain_string()[::-1] + ' ' + (
                             (longest_character_sequence_in_batch - len(sentence.to_plain_string())) * ' '))

-        # print(sentences_padded)
-
         # get states from LM
         all_hidden_states_in_lm = self.lm.get_representation(sentences_padded, self.detach)
@@ -426,23 +413,20 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
                 else:
                     offset = offset_backward

-                embedding = all_hidden_states_in_lm[offset, i, :].data.cpu()
-                # if not torch.cuda.is_available():
-                #     embedding = embedding.cpu()
+                embedding = all_hidden_states_in_lm[offset, i, :]

                 offset_forward += 1
                 offset_backward -= 1
                 offset_backward -= len(token.text)

-                token.set_embedding(self.name, torch.autograd.Variable(embedding))
+                token.set_embedding(self.name, embedding.cpu())

                 self.__embedding_length = len(embedding)

         return sentences


 class DocumentMeanEmbeddings(DocumentEmbeddings):
-
     def __init__(self, word_embeddings: List[TokenEmbeddings], reproject_words: bool = True):
         """The constructor takes a list of embeddings to be combined."""
         super().__init__()
@@ -493,7 +477,6 @@ def embed(self, paragraphs: List[Sentence]):

             mean_embedding = torch.mean(word_embeddings, 0)

-            # mean_embedding /= len(paragraph.tokens)
             paragraph.set_embedding(self.name, mean_embedding)

     def _add_embeddings_internal(self, sentences: List[Sentence]):
@@ -501,7 +484,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
         pass


 class DocumentLSTMEmbeddings(DocumentEmbeddings):
-
     def __init__(self, word_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
                  reproject_words: bool = True, bidirectional: bool = True):
         """The constructor takes a list of embeddings to be combined.
@@ -514,7 +496,6 @@ def __init__(self, word_embeddings: List[TokenEmbeddings], hidden_states=128, nu
         """
         super().__init__()

-        # self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings)
         self.embeddings: List[TokenEmbeddings] = word_embeddings

         self.reproject_words = reproject_words
@@ -649,6 +630,3 @@ def embed(self, sentences: List[Sentence]):

     def _add_embeddings_internal(self, sentences: List[Sentence]):
         pass
-
-
-

diff --git a/flair/models/language_model.py b/flair/models/language_model.py
index e5506b2dbb..ec1017d4f0 100644
--- a/flair/models/language_model.py
+++ b/flair/models/language_model.py
@@ -112,15 +112,17 @@ def initialize(self, matrix):
     @classmethod
     def load_language_model(cls, model_file):
         state = torch.load(model_file)
-        model = LanguageModel(state['dictionary'],
-                              state['is_forward_lm'],
-                              state['hidden_size'],
-                              state['nlayers'],
-                              state['embedding_size'],
-                              state['nout'],
-                              state['dropout'])
+        model: LanguageModel = LanguageModel(state['dictionary'],
+                                             state['is_forward_lm'],
+                                             state['hidden_size'],
+                                             state['nlayers'],
+                                             state['embedding_size'],
+                                             state['nout'],
+                                             state['dropout'])
         model.load_state_dict(state['state_dict'])
         model.eval()
+        if torch.cuda.is_available():
+            model.cuda()
         return model

     def save(self, file):

diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py
index b1caa00027..6d25fd8d03 100644
--- a/flair/models/sequence_tagger_model.py
+++ b/flair/models/sequence_tagger_model.py
@@ -65,7 +65,7 @@ def __init__(self,
         self.hidden_word = None

         # self.dropout = nn.Dropout(0.5)
-        self.dropout = LockedDropout(0.5)
+        self.dropout: nn.Module = LockedDropout(0.5)

         rnn_input_dim: int = self.embeddings.embedding_length
@@ -88,7 +88,7 @@ def __init__(self,
                                dropout=0.5,
                                bidirectional=True)

-        self.relu = nn.ReLU()
+        self.nonlinearity = nn.Tanh()

         # final linear map to tag space
         if self.use_rnn:
@@ -149,16 +149,14 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:
         longest_token_sequence_in_batch: int = len(sentences[0])

         self.embeddings.embed(sentences)
-        sent = sentences[0]
-        # print(sent)
-        # print(sent.tokens[0].get_embedding()[0:7])

         all_sentence_tensors = []
         lengths: List[int] = []
         tag_list: List = []

-        # go through each sentence in batch
-        for i, sentence in enumerate(sentences):
+        padding = torch.FloatTensor(np.zeros(self.embeddings.embedding_length, dtype='float')).unsqueeze(0)
+
+        for sentence in sentences:

             # get the tags in this sentence
             tag_idx: List[int] = []
@@ -167,58 +165,50 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:

             word_embeddings = []

-            for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
-                token: Token = token
-
+            for token in sentence:
                 # get the tag
                 tag_idx.append(self.tag_dictionary.get_idx_for_item(token.get_tag(self.tag_type)))
-
+
                 # get the word embeddings
                 word_embeddings.append(token.get_embedding().unsqueeze(0))

-            # PADDING: pad shorter sentences out
+            # pad shorter sentences out
             for add in range(longest_token_sequence_in_batch - len(sentence.tokens)):
-                word_embeddings.append(
-                    torch.autograd.Variable(
-                        torch.FloatTensor(np.zeros(self.embeddings.embedding_length, dtype='float')).unsqueeze(0)))
+                word_embeddings.append(padding)

             word_embeddings_tensor = torch.cat(word_embeddings, 0)

-            sentence_states = word_embeddings_tensor
-
             if torch.cuda.is_available():
                 tag_list.append(torch.cuda.LongTensor(tag_idx))
             else:
                 tag_list.append(torch.LongTensor(tag_idx))

-            # ADD TO SENTENCE LIST: add the representation
-            all_sentence_tensors.append(sentence_states.unsqueeze(1))
+            all_sentence_tensors.append(word_embeddings_tensor.unsqueeze(1))

-        # --------------------------------------------------------------------
-        # GET REPRESENTATION FOR ENTIRE BATCH
-        # --------------------------------------------------------------------
+        # padded tensor for entire batch
         sentence_tensor = torch.cat(all_sentence_tensors, 1)
-
         if torch.cuda.is_available():
             sentence_tensor = sentence_tensor.cuda()

         # --------------------------------------------------------------------
         # FF PART
         # --------------------------------------------------------------------
-        tagger_states = self.dropout(sentence_tensor)
+        sentence_tensor = self.dropout(sentence_tensor)

         if self.relearn_embeddings:
-            tagger_states = self.embedding2nn(tagger_states)
+            sentence_tensor = self.embedding2nn(sentence_tensor)

         if self.use_rnn:
-            packed = torch.nn.utils.rnn.pack_padded_sequence(tagger_states, lengths)
+            packed = torch.nn.utils.rnn.pack_padded_sequence(sentence_tensor, lengths)

             rnn_output, hidden = self.rnn(packed)

-            tagger_states, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(rnn_output)
+            sentence_tensor, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(rnn_output)
+
+            sentence_tensor = self.dropout(sentence_tensor)

-            tagger_states = self.dropout(tagger_states)
+        # sentence_tensor = self.nonlinearity(sentence_tensor)

-        features = self.linear(tagger_states)
+        features = self.linear(sentence_tensor)

         predictions_list = []
         for sentence_no, length in enumerate(lengths):
@@ -230,7 +220,6 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:
         return predictions_list, tag_list

     def _score_sentence(self, feats, tags):
-        # print(tags)
         # tags is ground_truth, a list of ints, length is len(sentence)
         # feats is a 2D tensor, len(sentence) * tagset_size
         r = torch.LongTensor(range(feats.size()[0]))

diff --git a/train.py b/train.py
index e3bc40977e..69579f73ce 100644
--- a/train.py
+++ b/train.py
@@ -20,10 +20,10 @@
 # initialize embeddings
 embedding_types: List[TokenEmbeddings] = [

-    # WordEmbeddings('glove'),
+    WordEmbeddings('glove'),

     # comment in this line to use character embeddings
-    CharacterEmbeddings(),
+    # CharacterEmbeddings(),

     # comment in these lines to use contextual string embeddings
     #
@@ -48,7 +48,7 @@
 # initialize trainer
 from flair.trainers.sequence_tagger_trainer import SequenceTaggerTrainer

-trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
+trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)

 trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=20, save_model=False,
               train_with_dev=False, anneal_mode=False)
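Note on the forward() refactor in sequence_tagger_model.py: a single zero padding row is now created once and reused for every short sentence, instead of allocating a fresh torch.autograd.Variable per padded position. Below is a minimal standalone sketch of that padding-and-stacking pattern in plain PyTorch; the variable names, shapes, and toy data are illustrative only and are not taken from flair.

import torch

# two toy "sentences" of different lengths; each row stands in for a token embedding
embedding_length = 4
batch = [torch.rand(3, embedding_length),
         torch.rand(5, embedding_length)]

longest = max(t.size(0) for t in batch)
padding = torch.zeros(embedding_length).unsqueeze(0)   # one shared zero row, created once

padded = []
for tokens in batch:
    # append the shared padding row until every sentence reaches the longest length
    rows = [tokens] + [padding] * (longest - tokens.size(0))
    padded.append(torch.cat(rows, 0).unsqueeze(1))      # (longest, 1, embedding_length)

# concatenate along the batch dimension, as the refactored forward() does
sentence_tensor = torch.cat(padded, 1)                  # (longest, batch_size, embedding_length)
print(sentence_tensor.shape)                            # torch.Size([5, 2, 4])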