GH-17: GPU optimizations for LM embeddings
aakbik authored and tabergma committed Jul 31, 2018
1 parent 3c7ebae commit ee30266
Showing 5 changed files with 50 additions and 86 deletions.
5 changes: 0 additions & 5 deletions flair/data.py
@@ -1,10 +1,5 @@
from typing import List, Dict
import torch
import random
from random import randint
import os
from os import listdir
from os.path import isfile, join
from collections import Counter
from collections import defaultdict

60 changes: 19 additions & 41 deletions flair/embeddings.py
@@ -2,7 +2,7 @@
import pickle
import re
from abc import abstractmethod
from typing import List
from typing import List, Union

import gensim
import numpy as np
@@ -26,7 +26,7 @@ def embedding_length(self) -> int:
def embedding_type(self) -> str:
pass

def embed(self, sentences: List[Sentence]) -> List[Sentence]:
def embed(self, sentences: Union[Sentence, List[Sentence]]) -> List[Sentence]:
"""Add embeddings to all words in a list of sentences. If embeddings are already added, updates only if embeddings
are non-static."""
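
The widened signature lets callers pass a single Sentence directly (as the CharLMEmbeddings constructor now does with its dummy sentence further down). The body of embed() is collapsed in this view, so the dispatch below is a sketch of the assumed behaviour, not the committed code:

from abc import abstractmethod
from typing import List, Union

from flair.data import Sentence


class Embeddings:
    """Sketch of the abstract base class; only the input normalization is of interest here."""

    def embed(self, sentences: Union[Sentence, List[Sentence]]) -> List[Sentence]:
        # assumption: a single Sentence is wrapped into a list before delegating;
        # the real method also skips recomputation for static embeddings (see docstring above)
        if type(sentences) is Sentence:
            sentences = [sentences]
        self._add_embeddings_internal(sentences)
        return sentences

    @abstractmethod
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        pass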

@@ -208,7 +208,11 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
else:
word_embedding = np.zeros(self.embedding_length, dtype='float')

word_embedding = torch.autograd.Variable(torch.FloatTensor(word_embedding))
# if torch.cuda.is_available():
# word_embedding = torch.cuda.FloatTensor(word_embedding)
# else:
word_embedding = torch.FloatTensor(word_embedding)

token.set_embedding(self.name, word_embedding)

return sentences
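
The hunk above drops the deprecated torch.autograd.Variable wrapper and leaves the per-token GPU copy commented out, so static word embeddings stay on the CPU and only batch tensors are moved later. A minimal sketch of the same tensor construction with an explicit device; the vector length and the device handling are illustrative assumptions, not part of this commit:

import numpy as np
import torch

embedding_length = 100                                       # stand-in for self.embedding_length
word_embedding = np.zeros(embedding_length, dtype='float')   # stand-in for a looked-up vector

# PyTorch >= 0.4 style: build the tensor once, optionally directly on the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
word_embedding = torch.tensor(word_embedding, dtype=torch.float, device=device)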
@@ -224,20 +228,11 @@ def __init__(self, path_to_char_dict: str = None):
self.name = 'Char'
self.static_embeddings = False

# get list of common characters if none provided
# use list of common characters if none provided
if path_to_char_dict is None:
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/common_characters'
char_dict = cached_path(base_path, cache_dir='datasets')

# load dictionary
self.char_dictionary: Dictionary = Dictionary()
with open(char_dict, 'rb') as f:
mappings = pickle.load(f, encoding='latin1')
idx2item = mappings['idx2item']
item2idx = mappings['item2idx']
self.char_dictionary.item2idx = item2idx
self.char_dictionary.idx2item = idx2item
# print(self.char_dictionary.item2idx)
self.char_dictionary: Dictionary = Dictionary.load('common-chars')
else:
self.char_dictionary: Dictionary = Dictionary.load_from_file(path_to_char_dict)

self.char_embedding_dim: int = 25
self.hidden_size_char: int = 25
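
With this change the default character dictionary is obtained through Dictionary.load('common-chars') instead of downloading and unpickling the mapping by hand. A minimal usage sketch, assuming the constructor and the Sentence/Token API shown elsewhere in this diff; the custom dictionary path is hypothetical:

from flair.data import Sentence
from flair.embeddings import CharacterEmbeddings

# default: use the bundled common-character dictionary
char_embeddings = CharacterEmbeddings()

# or point the constructor at a custom dictionary file (hypothetical path)
# char_embeddings = CharacterEmbeddings(path_to_char_dict='resources/my_chars.dict')

sentence = Sentence('GPU optimizations for LM embeddings')
char_embeddings.embed([sentence])

for token in sentence.tokens:
    print(token.text, token.get_embedding().size())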
@@ -260,7 +255,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
# translate words in sentence into ints using dictionary
for token in sentence.tokens:
token: Token = token
# print(token)
char_indices = [self.char_dictionary.get_idx_for_item(char) for char in token.text]
tokens_char_indices.append(char_indices)

@@ -278,7 +272,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
for i, c in enumerate(tokens_sorted_by_length):
tokens_mask[i, :chars2_length[i]] = c

tokens_mask = torch.autograd.Variable(torch.LongTensor(tokens_mask))
tokens_mask = torch.LongTensor(tokens_mask)

# chars for rnn processing
chars = tokens_mask
@@ -293,8 +287,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):

outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
outputs = outputs.transpose(0, 1)
chars_embeds_temp = torch.autograd.Variable(
torch.FloatTensor(torch.zeros((outputs.size(0), outputs.size(2)))))
chars_embeds_temp = torch.FloatTensor(torch.zeros((outputs.size(0), outputs.size(2))))
if torch.cuda.is_available():
chars_embeds_temp = chars_embeds_temp.cuda()
for i, index in enumerate(output_lengths):
@@ -304,7 +297,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
character_embeddings[d[i]] = chars_embeds_temp[i]

for token_number, token in enumerate(sentence.tokens):
token.set_embedding(self.name, character_embeddings[token_number].cpu())
token.set_embedding(self.name, character_embeddings[token_number])
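
The collapsed parts of this method sort the per-token character sequences by length, pack them for the bidirectional character LSTM, and restore the original token order afterwards. A simplified, self-contained sketch of that pad/pack/unpack pattern; vocabulary size, dimensions and the way the final state is selected are simplifications rather than the exact class code:

import torch

# toy character-index lists for three tokens, as produced from char_dictionary above
tokens_char_indices = [[4, 2, 9, 1], [7, 3], [5, 8, 6]]

lengths = [len(c) for c in tokens_char_indices]
order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
sorted_indices = [tokens_char_indices[i] for i in order]
sorted_lengths = [lengths[i] for i in order]

padded = torch.zeros(len(sorted_indices), max(sorted_lengths), dtype=torch.long)
for i, char_indices in enumerate(sorted_indices):
    padded[i, :sorted_lengths[i]] = torch.tensor(char_indices)

char_embedding = torch.nn.Embedding(num_embeddings=50, embedding_dim=25)
char_rnn = torch.nn.LSTM(25, 25, bidirectional=True)

embedded = char_embedding(padded).transpose(0, 1)            # (max_len, n_tokens, 25)
packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, sorted_lengths)
lstm_out, _ = char_rnn(packed)
outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
outputs = outputs.transpose(0, 1)                            # (n_tokens, max_len, 50)

# take the state at each token's last real character, then undo the length sort
last_states = torch.stack([outputs[i, l - 1, :] for i, l in enumerate(output_lengths)])
character_embeddings = torch.empty_like(last_states)
for sorted_pos, original_pos in enumerate(order):
    character_embeddings[original_pos] = last_states[sorted_pos]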


class CharLMEmbeddings(TokenEmbeddings):
@@ -359,12 +352,8 @@ def __init__(self, model, detach: bool = True):
self.name = model
self.static_embeddings = detach

import flair.models
self.lm: flair.models.LanguageModel = flair.models.LanguageModel.load_language_model(model)
if torch.cuda.is_available():
self.lm = self.lm.cuda()
self.lm.eval()

from flair.models import LanguageModel
self.lm = LanguageModel.load_language_model(model)
self.detach = detach

self.is_forward_lm: bool = self.lm.is_forward_lm
@@ -378,7 +367,7 @@ def __init__(self, model, detach: bool = True):

dummy_sentence: Sentence = Sentence()
dummy_sentence.add_token(Token('hello'))
embedded_dummy = self.embed([dummy_sentence])
embedded_dummy = self.embed(dummy_sentence)
self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())

@property
@@ -406,8 +395,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
'\n' + sentence.to_plain_string()[::-1] + ' ' + (
(longest_character_sequence_in_batch - len(sentence.to_plain_string())) * ' '))

# print(sentences_padded)

# get states from LM
all_hidden_states_in_lm = self.lm.get_representation(sentences_padded, self.detach)

@@ -426,23 +413,20 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
else:
offset = offset_backward

embedding = all_hidden_states_in_lm[offset, i, :].data.cpu()
# if not torch.cuda.is_available():
# embedding = embedding.cpu()
embedding = all_hidden_states_in_lm[offset, i, :]

offset_forward += 1

offset_backward -= 1
offset_backward -= len(token.text)

token.set_embedding(self.name, torch.autograd.Variable(embedding))
token.set_embedding(self.name, embedding.cpu())
self.__embedding_length = len(embedding)

return sentences
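
After this hunk the LM hidden states stay on the GPU and only the slice that belongs to each token is moved to the CPU when it is attached. The offsets walk the padded character sequence: a forward LM takes the state right after a token's last character, a backward LM the mirrored position. A small sketch of that indexing for one sentence, with a random stand-in for the (chars, batch, hidden) tensor returned by get_representation; the hidden size is illustrative:

import torch

text = 'hello world'
extra_offset = 1                                              # leading '\n' prepended to every sentence
hidden = torch.randn(len(text) + extra_offset + 2, 1, 2048)   # stand-in for the LM hidden states

offset_forward = extra_offset
offset_backward = len(text) + extra_offset

for word in text.split():
    offset_forward += len(word)
    offset = offset_forward                        # forward LM: state after the word's last char
    embedding = hidden[offset, 0, :].cpu()         # move only this slice off the GPU

    offset_forward += 1                            # step over the following whitespace
    offset_backward -= 1
    offset_backward -= len(word)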


class DocumentMeanEmbeddings(DocumentEmbeddings):

def __init__(self, word_embeddings: List[TokenEmbeddings], reproject_words: bool = True):
"""The constructor takes a list of embeddings to be combined."""
super().__init__()
@@ -493,15 +477,13 @@ def embed(self, paragraphs: List[Sentence]):

mean_embedding = torch.mean(word_embeddings, 0)

# mean_embedding /= len(paragraph.tokens)
paragraph.set_embedding(self.name, mean_embedding)

def _add_embeddings_internal(self, sentences: List[Sentence]):
pass


class DocumentLSTMEmbeddings(DocumentEmbeddings):

def __init__(self, word_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
reproject_words: bool = True, bidirectional: bool = True):
"""The constructor takes a list of embeddings to be combined.
@@ -514,7 +496,6 @@ def __init__(self, word_embeddings: List[TokenEmbeddings], hidden_states=128, num_layers=1,
"""
super().__init__()

# self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings)
self.embeddings: List[TokenEmbeddings] = word_embeddings

self.reproject_words = reproject_words
@@ -649,6 +630,3 @@ def embed(self, sentences: List[Sentence]):

def _add_embeddings_internal(self, sentences: List[Sentence]):
pass



16 changes: 9 additions & 7 deletions flair/models/language_model.py
@@ -112,15 +112,17 @@ def initialize(self, matrix):
@classmethod
def load_language_model(cls, model_file):
state = torch.load(model_file)
model = LanguageModel(state['dictionary'],
state['is_forward_lm'],
state['hidden_size'],
state['nlayers'],
state['embedding_size'],
state['nout'],
state['dropout'])
model: LanguageModel = LanguageModel(state['dictionary'],
state['is_forward_lm'],
state['hidden_size'],
state['nlayers'],
state['embedding_size'],
state['nout'],
state['dropout'])
model.load_state_dict(state['state_dict'])
model.eval()
if torch.cuda.is_available():
model.cuda()
return model
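
Since eval() and the optional cuda() call now live inside load_language_model, callers such as CharLMEmbeddings no longer need that boilerplate. A minimal loading sketch; the checkpoint path is hypothetical:

from flair.models import LanguageModel

# hypothetical path to a trained character-LM checkpoint
lm = LanguageModel.load_language_model('resources/LMs/news-forward/best-lm.pt')

# the classmethod already switched the model to eval mode and moved it to the GPU if available
hidden_states = lm.get_representation(['\nhello world  '], True)  # second argument: detach, as used above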

def save(self, file):
49 changes: 19 additions & 30 deletions flair/models/sequence_tagger_model.py
@@ -65,7 +65,7 @@ def __init__(self,
self.hidden_word = None

# self.dropout = nn.Dropout(0.5)
self.dropout = LockedDropout(0.5)
self.dropout: nn.Module = LockedDropout(0.5)

rnn_input_dim: int = self.embeddings.embedding_length

@@ -88,7 +88,7 @@ def __init__(self,
dropout=0.5,
bidirectional=True)

self.relu = nn.ReLU()
self.nonlinearity = nn.Tanh()

# final linear map to tag space
if self.use_rnn:
@@ -149,16 +149,14 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:
longest_token_sequence_in_batch: int = len(sentences[0])

self.embeddings.embed(sentences)
sent = sentences[0]
# print(sent)
# print(sent.tokens[0].get_embedding()[0:7])

all_sentence_tensors = []
lengths: List[int] = []
tag_list: List = []

# go through each sentence in batch
for i, sentence in enumerate(sentences):
padding = torch.FloatTensor(np.zeros(self.embeddings.embedding_length, dtype='float')).unsqueeze(0)

for sentence in sentences:

# get the tags in this sentence
tag_idx: List[int] = []
@@ -167,58 +165,50 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:

word_embeddings = []

for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
token: Token = token

for token in sentence:
# get the tag
tag_idx.append(self.tag_dictionary.get_idx_for_item(token.get_tag(self.tag_type)))

# get the word embeddings
word_embeddings.append(token.get_embedding().unsqueeze(0))

# PADDING: pad shorter sentences out
# pad shorter sentences out
for add in range(longest_token_sequence_in_batch - len(sentence.tokens)):
word_embeddings.append(
torch.autograd.Variable(
torch.FloatTensor(np.zeros(self.embeddings.embedding_length, dtype='float')).unsqueeze(0)))
word_embeddings.append(padding)

word_embeddings_tensor = torch.cat(word_embeddings, 0)

sentence_states = word_embeddings_tensor

if torch.cuda.is_available():
tag_list.append(torch.cuda.LongTensor(tag_idx))
else:
tag_list.append(torch.LongTensor(tag_idx))

# ADD TO SENTENCE LIST: add the representation
all_sentence_tensors.append(sentence_states.unsqueeze(1))
all_sentence_tensors.append(word_embeddings_tensor.unsqueeze(1))

# --------------------------------------------------------------------
# GET REPRESENTATION FOR ENTIRE BATCH
# --------------------------------------------------------------------
# padded tensor for entire batch
sentence_tensor = torch.cat(all_sentence_tensors, 1)

if torch.cuda.is_available():
sentence_tensor = sentence_tensor.cuda()

# --------------------------------------------------------------------
# FF PART
# --------------------------------------------------------------------
tagger_states = self.dropout(sentence_tensor)
sentence_tensor = self.dropout(sentence_tensor)

if self.relearn_embeddings:
tagger_states = self.embedding2nn(tagger_states)
sentence_tensor = self.embedding2nn(sentence_tensor)

if self.use_rnn:
packed = torch.nn.utils.rnn.pack_padded_sequence(tagger_states, lengths)
packed = torch.nn.utils.rnn.pack_padded_sequence(sentence_tensor, lengths)

rnn_output, hidden = self.rnn(packed)

tagger_states, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(rnn_output)
sentence_tensor, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(rnn_output)

sentence_tensor = self.dropout(sentence_tensor)

tagger_states = self.dropout(tagger_states)
# sentence_tensor = self.nonlinearity(sentence_tensor)

features = self.linear(tagger_states)
features = self.linear(sentence_tensor)

predictions_list = []
for sentence_no, length in enumerate(lengths):
@@ -230,7 +220,6 @@ def forward(self, sentences: List[Sentence]) -> Tuple[List, List]:
return predictions_list, tag_list
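
The reworked forward() is the core GPU optimization for the tagger: the zero padding row is allocated once per batch, per-sentence tensors are assembled on the CPU, and the padded batch is moved to the GPU with a single .cuda() call before dropout, reprojection and the RNN. A standalone sketch of that batching pattern with made-up dimensions:

import torch

embedding_length = 4
sentence_lengths = [5, 3, 2]                          # tokens per sentence, longest first

longest = max(sentence_lengths)
padding = torch.zeros(embedding_length).unsqueeze(0)  # allocated once, reused for every gap

all_sentence_tensors = []
for length in sentence_lengths:
    token_vectors = [torch.randn(1, embedding_length) for _ in range(length)]  # stand-ins for token embeddings
    token_vectors += [padding] * (longest - length)
    all_sentence_tensors.append(torch.cat(token_vectors, 0).unsqueeze(1))

sentence_tensor = torch.cat(all_sentence_tensors, 1)  # (longest, batch_size, embedding_length)

if torch.cuda.is_available():
    sentence_tensor = sentence_tensor.cuda()          # one transfer for the whole batch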

def _score_sentence(self, feats, tags):
# print(tags)
# tags is ground_truth, a list of ints, length is len(sentence)
# feats is a 2D tensor, len(sentence) * tagset_size
r = torch.LongTensor(range(feats.size()[0]))
6 changes: 3 additions & 3 deletions train.py
@@ -20,10 +20,10 @@
# initialize embeddings
embedding_types: List[TokenEmbeddings] = [

# WordEmbeddings('glove'),
WordEmbeddings('glove'),

# comment in this line to use character embeddings
CharacterEmbeddings(),
# CharacterEmbeddings(),

# comment in these lines to use contextual string embeddings
#
@@ -48,7 +48,7 @@
# initialize trainer
from flair.trainers.sequence_tagger_trainer import SequenceTaggerTrainer

trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=True)
trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)

trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=20, save_model=False,
train_with_dev=False, anneal_mode=False)
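
With test_mode switched back to False the script trains on the full training split. Corpus loading and the tagger construction are collapsed in this diff, so the sketch below follows the surrounding flair code of this era and its names (NLPTaskDataFetcher, make_tag_dictionary, the SequenceTagger arguments) should be read as assumptions:

from typing import List

from flair.data import TaggedCorpus
from flair.data_fetcher import NLPTaskDataFetcher, NLPTask
from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings
from flair.models import SequenceTagger
from flair.trainers.sequence_tagger_trainer import SequenceTaggerTrainer

corpus: TaggedCorpus = NLPTaskDataFetcher.fetch_data(NLPTask.CONLL_03)
tag_type = 'ner'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)

embedding_types: List[TokenEmbeddings] = [WordEmbeddings('glove')]
embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

tagger: SequenceTagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
                                        tag_dictionary=tag_dictionary, tag_type=tag_type,
                                        use_crf=True)

trainer: SequenceTaggerTrainer = SequenceTaggerTrainer(tagger, corpus, test_mode=False)
trainer.train('resources/taggers/example-ner', mini_batch_size=32, max_epochs=20,
              save_model=False, train_with_dev=False, anneal_mode=False)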
