GH-217: Add possibility to resume training for language model.
tabergma committed Nov 28, 2018
1 parent dea5dfd commit 553e1d7
Showing 2 changed files with 77 additions and 19 deletions.
flair/models/language_model.py (55 additions, 15 deletions)

@@ -5,6 +5,9 @@
 import math
 from torch.autograd import Variable
 from typing import List
+
+from torch.optim import Optimizer
+
 from flair.data import Dictionary
@@ -18,8 +21,7 @@ def __init__(self,
                  nlayers: int,
                  embedding_size: int = 100,
                  nout=None,
-                 dropout=0.5,
-                 best_score=None):
+                 dropout=0.5):
 
         super(LanguageModel, self).__init__()
@@ -52,8 +54,6 @@ def __init__(self,
 
         self.init_weights()
 
-        self.best_score = best_score
-
         # auto-spawn on GPU if available
         if torch.cuda.is_available():
             self.cuda()
@@ -128,23 +128,46 @@ def load_language_model(cls, model_file: Path):
         else:
             state = torch.load(str(model_file))
 
-        best_score = state['best_score'] if 'best_score' in state else None
-
         model = LanguageModel(state['dictionary'],
                               state['is_forward_lm'],
                               state['hidden_size'],
                               state['nlayers'],
                               state['embedding_size'],
                               state['nout'],
-                              state['dropout'],
-                              best_score)
+                              state['dropout'])
         model.load_state_dict(state['state_dict'])
         model.eval()
         if torch.cuda.is_available():
             model.cuda()
 
         return model
 
-    def save(self, file: Path):
+    @classmethod
+    def load_checkpoint(cls, model_file: Path):
+        if not torch.cuda.is_available():
+            state = torch.load(str(model_file), map_location='cpu')
+        else:
+            state = torch.load(str(model_file))
+
+        epoch = state['epoch'] if 'epoch' in state else None
+        loss = state['loss'] if 'loss' in state else None
+        optimizer_state_dict = state['optimizer_state_dict'] if 'optimizer_state_dict' in state else None
+
+        model = LanguageModel(state['dictionary'],
+                              state['is_forward_lm'],
+                              state['hidden_size'],
+                              state['nlayers'],
+                              state['embedding_size'],
+                              state['nout'],
+                              state['dropout'])
+        model.load_state_dict(state['state_dict'])
+        model.eval()
+        if torch.cuda.is_available():
+            model.cuda()
+
+        return {'model': model, 'epoch': epoch, 'loss': loss, 'optimizer_state_dict': optimizer_state_dict}
+
+    def save_checkpoint(self, file: Path, optimizer: Optimizer, epoch: int, loss: float):
         model_state = {
             'state_dict': self.state_dict(),
             'dictionary': self.dictionary,
@@ -154,8 +177,25 @@ def save(self, file: Path):
             'embedding_size': self.embedding_size,
             'nout': self.nout,
             'dropout': self.dropout,
-            'best_score': self.best_score
+            'optimizer_state_dict': optimizer.state_dict(),
+            'epoch': epoch,
+            'loss': loss
         }
 
         torch.save(model_state, str(file), pickle_protocol=4)
+
+    def save(self, file: Path):
+        model_state = {
+            'state_dict': self.state_dict(),
+            'dictionary': self.dictionary,
+            'is_forward_lm': self.is_forward_lm,
+            'hidden_size': self.hidden_size,
+            'nlayers': self.nlayers,
+            'embedding_size': self.embedding_size,
+            'nout': self.nout,
+            'dropout': self.dropout
+        }
+
+        torch.save(model_state, str(file), pickle_protocol=4)
 
     def generate_text(self, number_of_characters=1000) -> str:
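Together, save_checkpoint and load_checkpoint round-trip the optimizer state and training counters alongside the model weights, with load_checkpoint returning a plain dict rather than a bare model. A minimal sketch of that round trip, assuming the bundled 'chars' dictionary; the network sizes, file name, and epoch/loss values are illustrative, not part of this commit:

    from pathlib import Path

    from torch.optim import SGD

    from flair.data import Dictionary
    from flair.models.language_model import LanguageModel

    # Small model and optimizer, just to have something to checkpoint.
    dictionary = Dictionary.load('chars')
    model = LanguageModel(dictionary, is_forward_lm=True, hidden_size=128, nlayers=1)
    optimizer = SGD(model.parameters(), lr=20.0)

    # Weights plus optimizer state, epoch counter and current loss in one file.
    model.save_checkpoint(Path('checkpoint.pt'), optimizer, epoch=3, loss=2.17)

    # The returned dict mirrors what save_checkpoint wrote.
    checkpoint = LanguageModel.load_checkpoint(Path('checkpoint.pt'))
    model, epoch = checkpoint['model'], checkpoint['epoch']

Keeping the optimizer state out of the plain save() path means published models stay small; only checkpoints carry the extra per-parameter buffers that stateful optimizers such as Adam accumulate.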
flair/trainers/language_model_trainer.py (22 additions, 4 deletions)

@@ -153,14 +153,24 @@ def tokenize(self, path: Path):
 
 
 class LanguageModelTrainer:
-    def __init__(self, model: LanguageModel, corpus: TextCorpus, optimizer: Optimizer = SGD, test_mode: bool = False):
+
+    def __init__(self,
+                 model: LanguageModel,
+                 corpus: TextCorpus,
+                 optimizer: Optimizer = SGD,
+                 test_mode: bool = False,
+                 epoch: int = 0,
+                 loss: float = 10000,
+                 optimizer_state: dict = None
+                 ):
         self.model: LanguageModel = model
         self.optimizer: Optimizer = optimizer
         self.corpus: TextCorpus = corpus
         self.test_mode: bool = test_mode
 
         self.loss_function = torch.nn.CrossEntropyLoss()
         self.log_interval = 100
+        self.epoch = epoch
+        self.loss = loss
+        self.optimizer_state = optimizer_state
 
     def train(self,
               base_path: Path,
@@ -171,6 +181,7 @@ def train(self,
               patience: int = 10,
               clip=0.25,
               max_epochs: int = 1000,
+              checkpoint: bool = False,
               **kwargs):
 
         number_of_splits: int = len(self.corpus.train_files)
@@ -186,9 +197,10 @@
 
         try:
 
-            epoch = 0
-            best_val_loss = self.model.best_score if self.model.best_score is not None else 100000000
+            epoch = self.epoch
+            best_val_loss = self.loss
             optimizer = self.optimizer(self.model.parameters(), lr=learning_rate, **kwargs)
+            if self.optimizer_state is not None:
+                optimizer.load_state_dict(self.optimizer_state)
+
             if isinstance(optimizer, (AdamW, SGDW)):
                 scheduler: ReduceLRWDOnPlateau = ReduceLRWDOnPlateau(optimizer, verbose=True,
                                                                      factor=anneal_factor,
@@ -202,6 +214,8 @@
 
                 # after pass over all splits, increment epoch count
                 if (split - 1) % number_of_splits == 0:
+                    if checkpoint:
+                        self.model.save_checkpoint(base_path / 'checkpoint.pt', optimizer, epoch, best_val_loss)
                     epoch += 1
 
                 log.info('Split %d' % split + '\t - ({:%H:%M:%S})'.format(datetime.datetime.now()))
@@ -363,4 +377,8 @@ def _repackage_hidden(h):
         """Wraps hidden states in new Variables, to detach them from their history."""
         return tuple(Variable(v) for v in h)
 
+    @staticmethod
+    def load_from_checkpoint(checkpoint_file: Path, corpus: TextCorpus, optimizer: Optimizer = SGD):
+        checkpoint = LanguageModel.load_checkpoint(checkpoint_file)
+        return LanguageModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'],
+                                    loss=checkpoint['loss'], optimizer_state=checkpoint['optimizer_state_dict'])
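On the trainer side, checkpoint=True makes train() drop a checkpoint.pt into base_path after each full pass over the splits, and load_from_checkpoint rebuilds a trainer that resumes from it. A sketch of the intended flow; the corpus layout, paths, and the sequence_length/mini_batch_size arguments are assumptions drawn from the surrounding trainer code rather than from this diff:

    from pathlib import Path

    from torch.optim import SGD

    from flair.data import Dictionary
    from flair.models.language_model import LanguageModel
    from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus

    dictionary = Dictionary.load('chars')
    corpus = TextCorpus(Path('corpus'), dictionary, True, character_level=True)

    # First run: write base_path / 'checkpoint.pt' once per epoch.
    model = LanguageModel(dictionary, is_forward_lm=True, hidden_size=128, nlayers=1)
    trainer = LanguageModelTrainer(model, corpus, optimizer=SGD)
    trainer.train(Path('resources/lm'), sequence_length=250, mini_batch_size=100,
                  max_epochs=10, checkpoint=True)

    # Later run: rebuild the trainer from the checkpoint; epoch, best loss and
    # optimizer state are restored, so training picks up where it stopped.
    trainer = LanguageModelTrainer.load_from_checkpoint(Path('resources/lm/checkpoint.pt'), corpus)
    trainer.train(Path('resources/lm'), sequence_length=250, mini_batch_size=100,
                  max_epochs=20)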
