Skip to content

Commit

Permalink
added trainer for language models | save/restore function for char di…
Browse files Browse the repository at this point in the history
…ctionaries | updated stored language models
  • Loading branch information
aakbik authored and tabergma committed Jul 31, 2018
1 parent 3ae403c commit 3c7ebae
Show file tree
Hide file tree
Showing 6 changed files with 408 additions and 182 deletions.
171 changes: 32 additions & 139 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from segtok.tokenizer import word_tokenizer



class Dictionary:
"""
This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings.
Expand Down Expand Up @@ -63,6 +64,37 @@ def __len__(self) -> int:
def get_item_for_index(self, idx):
return self.idx2item[idx].decode('UTF-8')

def save(self, savefile):
import pickle
with open(savefile, 'wb') as f:
mappings = {
'idx2item': self.idx2item,
'item2idx': self.item2idx
}
pickle.dump(mappings, f)

@classmethod
def load_from_file(cls, filename: str):
import pickle
dictionary: Dictionary = Dictionary()
with open(filename, 'rb') as f:
mappings = pickle.load(f, encoding='latin1')
idx2item = mappings['idx2item']
item2idx = mappings['item2idx']
dictionary.item2idx = item2idx
dictionary.idx2item = idx2item
return dictionary

@classmethod
def load(cls, name: str):
from flair.file_utils import cached_path
if name == 'chars' or name == 'common-chars':
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models/common_characters'
char_dict = cached_path(base_path, cache_dir='datasets')
return Dictionary.load_from_file(char_dict)

return Dictionary.load_from_file(name)


class Token:
"""
Expand Down Expand Up @@ -401,145 +433,6 @@ def __str__(self) -> str:
return 'TaggedCorpus: %d train + %d dev + %d test sentences' % (len(self.train), len(self.dev), len(self.test))


class CorpusLM(object):
def __init__(self, path, dictionary: Dictionary, forward: bool = True, character_level: bool = True):
self.dictionary: Dictionary = dictionary
self.train_path = os.path.join(path, 'train')
self.train = None
self.forward = forward
self.split_on_char = character_level

self.train_files = sorted([f for f in listdir(self.train_path) if isfile(join(self.train_path, f))])
self.current_train_file = None

if forward:
self.valid = self.charsplit(os.path.join(path, 'valid.txt'), expand_vocab=False, forward=True,
split_on_char=self.split_on_char)
self.test = self.charsplit(os.path.join(path, 'test.txt'), expand_vocab=False, forward=True,
split_on_char=self.split_on_char)
else:
self.valid = self.charsplit(os.path.join(path, 'valid.txt'), expand_vocab=False, forward=False,
split_on_char=self.split_on_char)
self.test = self.charsplit(os.path.join(path, 'test.txt'), expand_vocab=False, forward=False,
split_on_char=self.split_on_char)

def get_next_train_slice(self) -> str:

if self.current_train_file == None:
self.current_train_file = self.train_files[0]

elif len(self.train_files) != 1:

index = self.train_files.index(self.current_train_file) + 1
if index > len(self.train_files): index = 0

self.current_train_file = self.train_files[index]

self.train = self.charsplit(os.path.join(self.train_path, self.current_train_file), expand_vocab=False,
forward=self.forward, split_on_char=self.split_on_char)

return self.current_train_file

def get_random_train_slice(self) -> str:
train_files = [f for f in listdir(self.train_path) if isfile(join(self.train_path, f))]
current_train_file = random.choice(train_files)
self.train = self.charsplit(os.path.join(self.train_path, current_train_file), expand_vocab=False,
forward=self.forward, split_on_char=self.split_on_char)
return current_train_file

def charsplit(self, path: str, expand_vocab=False, forward=True, split_on_char=True) -> torch.LongTensor:

"""Tokenizes a text file on characted basis."""
assert os.path.exists(path)

#
with open(path, 'r', encoding="utf-8") as f:
tokens = 0
for line in f:

if split_on_char:
chars = list(line)
else:
chars = line.split()

# print(chars)
tokens += len(chars)

# Add chars to the dictionary
if expand_vocab:
for char in chars:
self.dictionary.add_item(char)

if forward:
# charsplit file content
with open(path, 'r', encoding="utf-8") as f:
ids = torch.LongTensor(tokens)
token = 0
for line in f:
line = self.random_casechange(line)

if split_on_char:
chars = list(line)
else:
chars = line.split()

for char in chars:
if token >= tokens: break
ids[token] = self.dictionary.get_idx_for_item(char)
token += 1
else:
# charsplit file content
with open(path, 'r', encoding="utf-8") as f:
ids = torch.LongTensor(tokens)
token = tokens - 1
for line in f:
line = self.random_casechange(line)

if split_on_char:
chars = list(line)
else:
chars = line.split()

for char in chars:
if token >= tokens: break
ids[token] = self.dictionary.get_idx_for_item(char)
token -= 1

return ids

def random_casechange(self, line: str) -> str:
no = randint(0, 99)
if no is 0:
line = line.lower()
if no is 1:
line = line.upper()
return line

def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r') as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)

# Tokenize file content
with open(path, 'r') as f:
ids = torch.LongTensor(tokens)
token = 0
for line in f:
words = line.split() + ['<eos>']
for word in words:
ids[token] = self.dictionary.word2idx[word]
token += 1

return ids


def iob2(tags):
"""
Check that tags have a valid IOB format.
Expand Down
14 changes: 7 additions & 7 deletions flair/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import torch

import flair
from .data import Dictionary, Token, Sentence, TaggedCorpus
from .file_utils import cached_path

Expand Down Expand Up @@ -329,37 +328,38 @@ def __init__(self, model, detach: bool = True):

# news-english-forward
if model.lower() == 'news-forward':
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-forward.pt'
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-forward-v0.2rc.pt'
model = cached_path(base_path, cache_dir='embeddings')

# news-english-backward
if model.lower() == 'news-backward':
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-backward.pt'
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-backward-v0.2rc.pt'
model = cached_path(base_path, cache_dir='embeddings')

# mix-english-forward
if model.lower() == 'mix-forward':
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-forward.pt'
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-forward-v0.2rc.pt'
model = cached_path(base_path, cache_dir='embeddings')

# mix-english-backward
if model.lower() == 'mix-backward':
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-backward.pt'
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-english-backward-v0.2rc.pt'
model = cached_path(base_path, cache_dir='embeddings')

# mix-english-forward
if model.lower() == 'german-forward':
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-forward.pt'
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-forward-v0.2rc.pt'
model = cached_path(base_path, cache_dir='embeddings')

# mix-english-backward
if model.lower() == 'german-backward':
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-backward.pt'
base_path = 'https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-mix-german-backward-v0.2rc.pt'
model = cached_path(base_path, cache_dir='embeddings')

self.name = model
self.static_embeddings = detach

import flair.models
self.lm: flair.models.LanguageModel = flair.models.LanguageModel.load_language_model(model)
if torch.cuda.is_available():
self.lm = self.lm.cuda()
Expand Down
68 changes: 36 additions & 32 deletions flair/models/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,45 @@
class LanguageModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""

def __init__(self, rnn_type, ntoken, ninp, nhid, nout, nlayers, dropout=0.5):
def __init__(self,
dictionary: Dictionary,
is_forward_lm: bool,
hidden_size: int,
nlayers: int,
embedding_size: int = 100,
nout=None,
dropout=0.5):

super(LanguageModel, self).__init__()

self.dictionary = Dictionary()
self.is_forward_lm: bool = True
self.dictionary = dictionary
self.is_forward_lm: bool = is_forward_lm

self.dropout = dropout
self.hidden_size = hidden_size
self.embedding_size = embedding_size
self.nlayers = nlayers

self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp)
self.encoder = nn.Embedding(len(dictionary), embedding_size)

if nlayers == 1:
self.rnn = nn.LSTM(ninp, nhid, nlayers)
self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers)
else:
self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)

self.decoder = nn.Linear(nhid, ntoken)

self.init_weights()

self.rnn_type = rnn_type
self.nhid = nhid
self.ninp = ninp
self.nlayers = nlayers
self.rnn = nn.LSTM(embedding_size, hidden_size, nlayers, dropout=dropout)

self.hidden = None

self.nout = nout
if nout is not None:
self.proj = nn.Linear(nhid, nout)
self.proj = nn.Linear(hidden_size, nout)
self.initialize(self.proj.weight)
self.decoder = nn.Linear(nout, ntoken)
self.decoder = nn.Linear(nout, len(dictionary))
else:
self.proj = None
self.decoder = nn.Linear(hidden_size, len(dictionary))

self.init_weights()

def init_weights(self):
initrange = 0.1
Expand Down Expand Up @@ -70,11 +75,8 @@ def forward(self, input, hidden, ordered_sequence_lengths=None):

def init_hidden(self, bsz):
weight = next(self.parameters()).data
if self.rnn_type == 'LSTM':
return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
else:
return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
return (Variable(weight.new(self.nlayers, bsz, self.hidden_size).zero_()),
Variable(weight.new(self.nlayers, bsz, self.hidden_size).zero_()))

def get_representation(self, strings: List[str], detach_from_lm=True):

Expand Down Expand Up @@ -110,24 +112,26 @@ def initialize(self, matrix):
@classmethod
def load_language_model(cls, model_file):
state = torch.load(model_file)
model = LanguageModel(state['rnn_type'], state['ntoken'], state['ninp'], state['nhid'], state['nout'],
state['nlayers'], state['dropout'])
model = LanguageModel(state['dictionary'],
state['is_forward_lm'],
state['hidden_size'],
state['nlayers'],
state['embedding_size'],
state['nout'],
state['dropout'])
model.load_state_dict(state['state_dict'])
model.is_forward_lm = state['is_forward_lm']
model.dictionary = state['char_dictionary_forward']
model.eval()
return model

def save(self, file):
model_state = {
'state_dict': self.state_dict(),
'dictionary': self.dictionary,
'is_forward_lm': self.is_forward_lm,
'char_dictionary_forward': self.dictionary,
'rnn_type': self.rnn_type,
'ntoken': len(self.dictionary),
'ninp': self.ninp,
'nhid': self.nhid,
'nout': self.proj,
'hidden_size': self.hidden_size,
'nlayers': self.nlayers,
'embedding_size': self.embedding_size,
'nout': self.nout,
'dropout': self.dropout
}
torch.save(model_state, file, pickle_protocol=4)
1 change: 0 additions & 1 deletion flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ def neg_log_likelihood(self, sentences: List[Sentence], tag_type: str):
for i in range(len(feats)):
sentence_feats = feats[i]
sentence_tags = tags[i]

forward_score = self._forward_alg(sentence_feats)
# calculate the score of the ground_truth, in CRF
gold_score = self._score_sentence(sentence_feats, sentence_tags)
Expand Down
Loading

0 comments on commit 3c7ebae

Please sign in to comment.