Skip to content

Commit

Permalink
GH-2146: collapse saving parameters into one
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Mar 13, 2021
1 parent bf79c80 commit 4d81634
Showing 1 changed file with 77 additions and 80 deletions.
157 changes: 77 additions & 80 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@

class ModelTrainer:
def __init__(
self,
model: flair.nn.Model,
corpus: Corpus,
optimizer: torch.optim.Optimizer = SGD,
epoch: int = 0,
use_tensorboard: bool = False,
self,
model: flair.nn.Model,
corpus: Corpus,
optimizer: torch.optim.Optimizer = SGD,
epoch: int = 0,
use_tensorboard: bool = False,
):
"""
Initialize a model trainer
Expand All @@ -61,40 +61,39 @@ def __init__(
self.use_tensorboard: bool = use_tensorboard

def train(
self,
base_path: Union[Path, str],
learning_rate: float = 0.1,
mini_batch_size: int = 32,
mini_batch_chunk_size: int = None,
max_epochs: int = 100,
scheduler = AnnealOnPlateau,
cycle_momentum: bool = False,
anneal_factor: float = 0.5,
patience: int = 3,
initial_extra_patience = 0,
min_learning_rate: float = 0.0001,
train_with_dev: bool = False,
train_with_test: bool = False,
monitor_train: bool = False,
monitor_test: bool = False,
embeddings_storage_mode: str = "cpu",
checkpoint: bool = False,
save_final_model: bool = True,
anneal_with_restarts: bool = False,
anneal_with_prestarts: bool = False,
batch_growth_annealing: bool = False,
shuffle: bool = True,
param_selection_mode: bool = False,
write_weights: bool = False,
num_workers: int = 6,
sampler=None,
use_amp: bool = False,
amp_opt_level: str = "O1",
eval_on_train_fraction=0.0,
eval_on_train_shuffle=False,
save_model_at_each_epoch=False,
save_model_epoch_step: int = None,
**kwargs,
self,
base_path: Union[Path, str],
learning_rate: float = 0.1,
mini_batch_size: int = 32,
mini_batch_chunk_size: int = None,
max_epochs: int = 100,
scheduler=AnnealOnPlateau,
cycle_momentum: bool = False,
anneal_factor: float = 0.5,
patience: int = 3,
initial_extra_patience=0,
min_learning_rate: float = 0.0001,
train_with_dev: bool = False,
train_with_test: bool = False,
monitor_train: bool = False,
monitor_test: bool = False,
embeddings_storage_mode: str = "cpu",
checkpoint: bool = False,
save_final_model: bool = True,
anneal_with_restarts: bool = False,
anneal_with_prestarts: bool = False,
batch_growth_annealing: bool = False,
shuffle: bool = True,
param_selection_mode: bool = False,
write_weights: bool = False,
num_workers: int = 6,
sampler=None,
use_amp: bool = False,
amp_opt_level: str = "O1",
eval_on_train_fraction=0.0,
eval_on_train_shuffle=False,
save_model_each_k_epochs: int = 0,
**kwargs,
) -> dict:
"""
Trains any class that implements the flair.nn.Model interface.
Expand Down Expand Up @@ -127,7 +126,8 @@ def train(
if 'dev' the size is determined from dev set size
:param eval_on_train_shuffle: if True the train data fraction is determined on the start of training
and kept fixed during training, otherwise it's sampled at beginning of each epoch
:param save_model_at_each_epoch: If True, at each epoch the thus far trained model will be saved
:param save_model_each_k_epochs: Every k epochs, the current model state is written out. If set to 5, a model is
saved every 5 epochs. Default is 0, which means no such periodic per-epoch model saving.
:param save_model_epoch_step: (no longer a parameter of train(); superseded by save_model_each_k_epochs) Every save_model_epoch_step'th epoch, the model trained thus far was saved
:param kwargs: Other arguments for the Optimizer
:return:
Expand Down Expand Up @@ -236,17 +236,18 @@ def train(

# minimize training loss if training with dev data, else maximize dev score
anneal_mode = "min" if train_with_dev else "max"

if scheduler == OneCycleLR:
dataset_size = len(self.corpus.train)
if train_with_dev:
dataset_size += len(self.corpus.dev)
lr_scheduler = OneCycleLR(optimizer,
max_lr=learning_rate,
steps_per_epoch=dataset_size//mini_batch_size + 1,
epochs=max_epochs-self.epoch, # if we load a checkpoint, we have already trained for self.epoch
pct_start=0.0,
cycle_momentum=cycle_momentum)
max_lr=learning_rate,
steps_per_epoch=dataset_size // mini_batch_size + 1,
epochs=max_epochs - self.epoch,
# if we loaded a checkpoint, we have already trained for self.epoch epochs
pct_start=0.0,
cycle_momentum=cycle_momentum)
else:
lr_scheduler = scheduler(
optimizer,
Expand All @@ -256,7 +257,7 @@ def train(
mode=anneal_mode,
verbose=True,
)

if (isinstance(lr_scheduler, OneCycleLR) and batch_growth_annealing):
raise ValueError("Batch growth with OneCycle policy is not implemented.")

Expand All @@ -280,10 +281,6 @@ def train(
sampler.set_dataset(train_data)
shuffle = False

if not isinstance(save_model_epoch_step, int) or save_model_epoch_step < 1:
log.warning(f'save_model_epoch_step should be positive integer, not {save_model_epoch_step}. It will be set to None')
save_model_epoch_step = None

dev_score_history = []
dev_loss_history = []
train_loss_history = []
Expand Down Expand Up @@ -321,9 +318,9 @@ def train(

# reload last best model if annealing with restarts is enabled
if (
(anneal_with_restarts or anneal_with_prestarts)
and learning_rate != previous_learning_rate
and (base_path / "best-model.pt").exists()
(anneal_with_restarts or anneal_with_prestarts)
and learning_rate != previous_learning_rate
and (base_path / "best-model.pt").exists()
):
if anneal_with_restarts:
log.info("resetting to best model")
Expand All @@ -348,7 +345,7 @@ def train(
batch_loader = DataLoader(
train_data,
batch_size=mini_batch_size,
shuffle=shuffle if self.epoch > 1 else False, # never shuffle the first epoch
shuffle=shuffle if self.epoch > 1 else False, # never shuffle the first epoch
num_workers=num_workers,
sampler=sampler,
)
Expand Down Expand Up @@ -376,7 +373,7 @@ def train(
batch_steps = [batch]
if len(batch) > micro_batch_size:
batch_steps = [
batch[x : x + micro_batch_size]
batch[x: x + micro_batch_size]
for x in range(0, len(batch), micro_batch_size)
]

Expand All @@ -396,15 +393,15 @@ def train(
# do the optimizer step
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
optimizer.step()

# do the scheduler step if one-cycle
if isinstance(lr_scheduler, OneCycleLR):
lr_scheduler.step()
# get new learning rate
for group in optimizer.param_groups:
learning_rate = group["lr"]
if "momentum" in group:
momentum = group["momentum"]
momentum = group["momentum"]

seen_batches += 1
train_loss += loss.item()
Expand Down Expand Up @@ -590,11 +587,11 @@ def train(

# if we use dev data, remember best model based on dev evaluation score
if (
(not train_with_dev or anneal_with_restarts or anneal_with_prestarts)
and not param_selection_mode
and not isinstance(lr_scheduler, OneCycleLR)
and current_score == lr_scheduler.best
and bad_epochs == 0
(not train_with_dev or anneal_with_restarts or anneal_with_prestarts)
and not param_selection_mode
and not isinstance(lr_scheduler, OneCycleLR)
and current_score == lr_scheduler.best
and bad_epochs == 0
):
print("saving best model")
self.model.save(base_path / "best-model.pt")
Expand All @@ -604,8 +601,8 @@ def train(
self.model.load_state_dict(last_epoch_model_state_dict)
self.model.save(base_path / "pre-best-model.pt")
self.model.load_state_dict(current_state_dict)
if save_model_at_each_epoch or save_model_epoch_step is not None and not self.epoch % save_model_epoch_step:

if save_model_each_k_epochs > 0 and not self.epoch % save_model_each_k_epochs:
print("saving model of current epoch")
model_name = "model_epoch_" + str(self.epoch) + ".pt"
self.model.save(base_path / model_name)
Expand Down Expand Up @@ -658,7 +655,7 @@ def load_checkpoint(cls, checkpoint: Union[Path, str], corpus: Corpus):
return model

def final_test(
self, base_path: Union[Path, str], eval_mini_batch_size: int, num_workers: int = 8
self, base_path: Union[Path, str], eval_mini_batch_size: int, num_workers: int = 8
):
if type(base_path) is str:
base_path = Path(base_path)
Expand Down Expand Up @@ -705,16 +702,16 @@ def final_test(
return final_score

def find_learning_rate(
self,
base_path: Union[Path, str],
file_name: str = "learning_rate.tsv",
start_learning_rate: float = 1e-7,
end_learning_rate: float = 10,
iterations: int = 100,
mini_batch_size: int = 32,
stop_early: bool = True,
smoothing_factor: float = 0.98,
**kwargs,
self,
base_path: Union[Path, str],
file_name: str = "learning_rate.tsv",
start_learning_rate: float = 1e-7,
end_learning_rate: float = 10,
iterations: int = 100,
mini_batch_size: int = 32,
stop_early: bool = True,
smoothing_factor: float = 0.98,
**kwargs,
) -> Path:
best_loss = None
moving_avg_loss = 0
Expand Down Expand Up @@ -765,11 +762,11 @@ def find_learning_rate(
else:
if smoothing_factor > 0:
moving_avg_loss = (
smoothing_factor * moving_avg_loss
+ (1 - smoothing_factor) * loss_item
smoothing_factor * moving_avg_loss
+ (1 - smoothing_factor) * loss_item
)
loss_item = moving_avg_loss / (
1 - smoothing_factor ** (step + 1)
1 - smoothing_factor ** (step + 1)
)
if loss_item < best_loss:
best_loss = loss
Expand Down

0 comments on commit 4d81634

Please sign in to comment.