Skip to content

Commit

Permalink
Merge pull request #2146 from troeshust96/enhancement/save_model_each…
Browse files Browse the repository at this point in the history
…_k_epochs

GH-2145: Added new param to save model each k epochs
  • Loading branch information
alanakbik authored Mar 13, 2021
2 parents 6503e6d + c44c406 commit 00d107d
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def train(
eval_on_train_fraction=0.0,
eval_on_train_shuffle=False,
save_model_at_each_epoch=False,
save_model_epoch_step: int = None,
**kwargs,
) -> dict:
"""
Expand Down Expand Up @@ -127,6 +128,7 @@ def train(
:param eval_on_train_shuffle: if True the train data fraction is determined on the start of training
and kept fixed during training, otherwise it's sampled at beginning of each epoch
:param save_model_at_each_epoch: If True, at each epoch the thus far trained model will be saved
:param save_model_epoch_step: Each save_model_epoch_step'th epoch the thus far trained model will be saved
:param kwargs: Other arguments for the Optimizer
:return:
"""
Expand Down Expand Up @@ -278,6 +280,10 @@ def train(
sampler.set_dataset(train_data)
shuffle = False

if not isinstance(save_model_epoch_step, int) or save_model_epoch_step < 1:
log.warning(f'save_model_epoch_step should be positive integer, not {save_model_epoch_step}. It will be set to None')
save_model_epoch_step = None

dev_score_history = []
dev_loss_history = []
train_loss_history = []
Expand Down Expand Up @@ -599,7 +605,7 @@ def train(
self.model.save(base_path / "pre-best-model.pt")
self.model.load_state_dict(current_state_dict)

if save_model_at_each_epoch:
if save_model_at_each_epoch or save_model_epoch_step is not None and not self.epoch % save_model_epoch_step:
print("saving model of current epoch")
model_name = "model_epoch_" + str(self.epoch) + ".pt"
self.model.save(base_path / model_name)
Expand Down

0 comments on commit 00d107d

Please sign in to comment.