Commit: GH-217: Add tests.

tabergma committed Nov 28, 2018
1 parent c2d70ab commit faaebfc
Showing 5 changed files with 94 additions and 14 deletions.
4 changes: 2 additions & 2 deletions flair/models/sequence_tagger_model.py
@@ -164,7 +164,7 @@ def save(self, model_file: Path):

torch.save(model_state, str(model_file), pickle_protocol=4)

-def save_checkpoint(self, model_file: Path, optimizer: Optimizer, scheduler_state: dict, epoch: int, loss: float):
+def save_checkpoint(self, model_file: Path, optimizer_state: dict, scheduler_state: dict, epoch: int, loss: float):
model_state = {
'state_dict': self.state_dict(),
'embeddings': self.embeddings,
@@ -176,7 +176,7 @@ def save_checkpoint(self, model_file: Path, optimizer: Optimizer, scheduler_stat
'rnn_layers': self.rnn_layers,
'use_word_dropout': self.use_word_dropout,
'use_locked_dropout': self.use_locked_dropout,
-'optimizer_state_dict': optimizer.state_dict(),
+'optimizer_state_dict': optimizer_state,
'scheduler_state_dict': scheduler_state,
'epoch': epoch,
'loss': loss
4 changes: 2 additions & 2 deletions flair/models/text_classification_model.py
@@ -79,7 +79,7 @@ def save(self, model_file: Path):
}
torch.save(model_state, str(model_file), pickle_protocol=4)

-def save_checkpoint(self, model_file: Path, optimizer: Optimizer, scheduler_state: dict, epoch: int, loss: float):
+def save_checkpoint(self, model_file: Path, optimizer_state: dict, scheduler_state: dict, epoch: int, loss: float):
"""
Saves the current model to the provided file.
:param model_file: the model file
@@ -89,7 +89,7 @@ def save_checkpoint(self, model_file: Path, optimizer: Optimizer, scheduler_stat
'document_embeddings': self.document_embeddings,
'label_dictionary': self.label_dictionary,
'multi_label': self.multi_label,
-'optimizer_state_dict': optimizer.state_dict(),
+'optimizer_state_dict': optimizer_state,
'scheduler_state_dict': scheduler_state,
'epoch': epoch,
'loss': loss
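Both save_checkpoint methods now receive the optimizer state as a plain dict rather than the live Optimizer, so the caller serializes it. Below is a minimal sketch of the new calling convention, mirroring the call added to ModelTrainer.train() further down; the helper name and checkpoint path are illustrative, not part of this commit.

from pathlib import Path
from flair.models import SequenceTagger
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau

def save_epoch_checkpoint(tagger: SequenceTagger, optimizer: Optimizer,
                          scheduler: ReduceLROnPlateau, epoch: int,
                          current_loss: float, base_path: Path) -> None:
    # Pass plain state dicts; save_checkpoint no longer accepts the Optimizer object.
    tagger.save_checkpoint(base_path / 'checkpoint.pt',
                           optimizer.state_dict(), scheduler.state_dict(),
                           epoch + 1,  # the epoch training should resume from
                           current_loss)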
10 changes: 7 additions & 3 deletions flair/trainers/language_model_trainer.py
@@ -160,7 +160,8 @@ def __init__(self,
optimizer: Optimizer = SGD,
test_mode: bool = False,
epoch: int = 0,
-loss: float = 1
+loss: float = 1,
+optimizer_state: dict = None
):
self.model: LanguageModel = model
self.optimzer: Optimizer = optimizer
@@ -171,6 +172,7 @@ def __init__(self,
self.log_interval = 100
self.epoch = epoch
self.loss = loss
+self.optimizer_state = optimizer_state

def train(self,
base_path: Path,
@@ -200,6 +202,8 @@ def train(self,
epoch = self.epoch
best_val_loss = self.loss
optimizer = self.optimzer(self.model.parameters(), lr=learning_rate, **kwargs)
+if self.optimizer_state is not None:
+    optimizer.load_state_dict(self.optimizer_state)

if isinstance(optimizer, (AdamW, SGDW)):
scheduler: ReduceLRWDOnPlateau = ReduceLRWDOnPlateau(optimizer, verbose=True,
@@ -380,5 +384,5 @@ def _repackage_hidden(h):
@staticmethod
def load_from_checkpoint(checkpoint_file: Path, corpus: TextCorpus, optimizer: Optimizer = SGD):
checkpoint = LanguageModel.load_checkpoint(checkpoint_file)
-optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-return LanguageModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'], loss=checkpoint['loss'])
+return LanguageModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'],
+                            loss=checkpoint['loss'], optimizer_state=checkpoint['optimizer_state_dict'])
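The trainer now keeps the saved optimizer state dict and applies it once train() has constructed the optimizer, so resuming only requires reloading the checkpoint. A rough sketch of a resume call, assuming a checkpoint written with checkpoint=True as in the test added below; the path and hyperparameters are illustrative.

from pathlib import Path
from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus

def resume_language_model_training(base_path: Path, corpus: TextCorpus) -> None:
    # The checkpoint holds the model, epoch, loss and optimizer state dict;
    # the optimizer state is restored inside train() after the optimizer is instantiated.
    trainer = LanguageModelTrainer.load_from_checkpoint(base_path / 'checkpoint.pt', corpus)
    trainer.train(base_path, sequence_length=10, mini_batch_size=10, max_epochs=2)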
19 changes: 12 additions & 7 deletions flair/trainers/trainer.py
@@ -24,6 +24,7 @@ def __init__(self,
optimizer: Optimizer = SGD,
epoch:int = 0,
loss: float = 1.0,
+optimizer_state: dict = None,
scheduler_state: dict = None
):
self.model: flair.nn.Model = model
@@ -32,6 +33,7 @@ def __init__(self,
self.epoch: int = epoch
self.loss: float = loss
self.scheduler_state: dict = scheduler_state
+self.optimizer_state: dict = optimizer_state

def find_learning_rate(self,
base_path: Path,
@@ -135,6 +137,8 @@ def train(self,
weight_extractor = WeightExtractor(base_path)

optimizer = self.optimizer(self.model.parameters(), lr=learning_rate, **kwargs)
+if self.optimizer_state is not None:
+    optimizer.load_state_dict(self.optimizer_state)

# annealing scheduler
anneal_mode = 'min' if train_with_dev else 'max'
@@ -146,7 +150,6 @@ def train(self,
scheduler = ReduceLROnPlateau(optimizer, factor=anneal_factor,
patience=patience, mode=anneal_mode,
verbose=True)

if self.scheduler_state is not None:
scheduler.load_state_dict(self.scheduler_state)

@@ -266,7 +269,9 @@ def train(self,

# if checkpoint is enable, save model at each epoch
if checkpoint and not param_selection_mode:
-self.model.save_checkpoint(base_path / 'checkpoint.pt', optimizer, scheduler, epoch, current_loss)
+self.model.save_checkpoint(base_path / 'checkpoint.pt',
+                           optimizer.state_dict(), scheduler.state_dict(),
+                           epoch + 1, current_loss)

# if we use dev data, remember best model based on dev evaluation score
if not train_with_dev and not param_selection_mode and current_score == scheduler.best:
@@ -475,12 +480,12 @@ def load_from_checkpoint(checkpoint_file: Path, model_type: str, corpus: Corpus,

if model_type == 'SequenceTagger':
checkpoint = SequenceTagger.load_checkpoint(checkpoint_file)
-optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
return ModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'],
-                    loss=checkpoint['loss'], scheduler_state=checkpoint['scheduler_state_dict'])
+                    loss=checkpoint['loss'], optimizer_state=checkpoint['optimizer_state_dict'],
+                    scheduler_state=checkpoint['scheduler_state_dict'])

if model_type == 'TextClassifier':
-checkpoint = SequenceTagger.load_checkpoint(checkpoint_file)
-optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
checkpoint = TextClassifier.load_checkpoint(checkpoint_file)
return ModelTrainer(checkpoint['model'], corpus, optimizer, epoch=checkpoint['epoch'],
-                    loss=checkpoint['loss'], scheduler_state=checkpoint['scheduler_state_dict'])
+                    loss=checkpoint['loss'], optimizer_state=checkpoint['optimizer_state_dict'],
+                    scheduler_state=checkpoint['scheduler_state_dict'])
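For ModelTrainer, the model_type string selects which class reads the checkpoint, and the optimizer and scheduler state dicts now travel through the constructor rather than being applied in load_from_checkpoint. A small sketch of resuming a tagger run, in the spirit of the tests that follow; the path and helper name are illustrative.

from pathlib import Path
from flair.trainers.trainer import ModelTrainer

def resume_tagger_training(base_path: Path, corpus) -> ModelTrainer:
    # 'SequenceTagger' routes the load through SequenceTagger.load_checkpoint;
    # the restored optimizer/scheduler state is applied when train() rebuilds them.
    trainer = ModelTrainer.load_from_checkpoint(base_path / 'checkpoint.pt',
                                                'SequenceTagger', corpus)
    trainer.train(base_path, max_epochs=2, checkpoint=True)
    return trainer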
71 changes: 71 additions & 0 deletions tests/test_model_integration.py
@@ -445,3 +445,74 @@ def test_train_load_use_tagger_multicorpus(results_base_path, tasks_base_path):

# clean up results directory
shutil.rmtree(results_base_path)


@pytest.mark.integration
def test_train_resume_text_classification_training(results_base_path, tasks_base_path):
corpus = NLPTaskDataFetcher.load_corpus(NLPTask.IMDB, base_path=tasks_base_path)
label_dict = corpus.make_label_dictionary()

embeddings: TokenEmbeddings = CharLMEmbeddings('news-forward-fast', use_cache=False)
document_embeddings: DocumentLSTMEmbeddings = DocumentLSTMEmbeddings([embeddings], 128, 1, False)

model = TextClassifier(document_embeddings, label_dict, False)

trainer = ModelTrainer(model, corpus)
trainer.train(results_base_path, max_epochs=2, test_mode=True, checkpoint=True)

trainer = ModelTrainer.load_from_checkpoint(results_base_path / 'checkpoint.pt', 'TextClassifier', corpus)
trainer.train(results_base_path, max_epochs=2, test_mode=True, checkpoint=True)

# clean up results directory
shutil.rmtree(results_base_path)


@pytest.mark.integration
def test_train_resume_sequence_tagging_training(results_base_path, tasks_base_path):
corpus = NLPTaskDataFetcher.load_corpora([NLPTask.FASHION, NLPTask.GERMEVAL], base_path=tasks_base_path)
tag_dictionary = corpus.make_tag_dictionary('ner')

embeddings = WordEmbeddings('glove')

model: SequenceTagger = SequenceTagger(hidden_size=256,
embeddings=embeddings,
tag_dictionary=tag_dictionary,
tag_type='ner',
use_crf=False)

trainer = ModelTrainer(model, corpus)
trainer.train(results_base_path, max_epochs=2, test_mode=True, checkpoint=True)

trainer = ModelTrainer.load_from_checkpoint(results_base_path / 'checkpoint.pt', 'SequenceTagger', corpus)
trainer.train(results_base_path, max_epochs=2, test_mode=True, checkpoint=True)

# clean up results directory
shutil.rmtree(results_base_path)


@pytest.mark.integration
def test_train_resume_language_model_training(resources_path, results_base_path, tasks_base_path):
# get default dictionary
dictionary: Dictionary = Dictionary.load('chars')

# init forward LM with 128 hidden states and 1 layer
language_model: LanguageModel = LanguageModel(dictionary, is_forward_lm=True, hidden_size=128, nlayers=1)

# get the example corpus and process at character level in forward direction
corpus: TextCorpus = TextCorpus(resources_path / 'corpora/lorem_ipsum',
dictionary,
language_model.is_forward_lm,
character_level=True)

# train the language model
trainer: LanguageModelTrainer = LanguageModelTrainer(language_model, corpus, test_mode=True)
trainer.train(results_base_path, sequence_length=10, mini_batch_size=10, max_epochs=2, checkpoint=True)

trainer = LanguageModelTrainer.load_from_checkpoint(results_base_path / 'checkpoint.pt', corpus)
trainer.train(results_base_path, sequence_length=10, mini_batch_size=10, max_epochs=2)

# clean up results directory
shutil.rmtree(results_base_path)


