diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 666acf4f1d..5bcdb603ed 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -564,52 +564,63 @@ def find_learning_rate(
         train_data = self.corpus.train

-        batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
-
         scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

         model_state = self.model.state_dict()
-        model_device = next(self.model.parameters()).device
         self.model.train()

-        for itr, batch in enumerate(batch_loader):
-            loss = self.model.forward_loss(batch)
-
-            optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
-            optimizer.step()
-            scheduler.step(1)
-            learning_rate = scheduler.get_lr()[0]
-
-            loss_item = loss.item()
-            if itr == 0:
-                best_loss = loss_item
-            else:
-                if smoothing_factor > 0:
-                    moving_avg_loss = (
-                        smoothing_factor * moving_avg_loss
-                        + (1 - smoothing_factor) * loss_item
-                    )
-                    loss_item = moving_avg_loss / (1 - smoothing_factor ** (itr + 1))
-                if loss_item < best_loss:
-                    best_loss = loss
+        step = 0
+        while step < iterations:
+            batch_loader = DataLoader(
+                train_data, batch_size=mini_batch_size, shuffle=True
+            )
+            for batch in batch_loader:
+                step += 1
+
+                # forward pass
+                loss = self.model.forward_loss(batch)
+
+                # update optimizer and scheduler
+                optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
+                optimizer.step()
+                scheduler.step(step)
+
+                learning_rate = scheduler.get_lr()[0]
+
+                loss_item = loss.item()
+                if step == 1:
+                    best_loss = loss_item
+                else:
+                    if smoothing_factor > 0:
+                        moving_avg_loss = (
+                            smoothing_factor * moving_avg_loss
+                            + (1 - smoothing_factor) * loss_item
+                        )
+                        loss_item = moving_avg_loss / (
+                            1 - smoothing_factor ** (step + 1)
+                        )
+                    if loss_item < best_loss:
+                        best_loss = loss_item

-            if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
-                log_line(log)
-                log.info("loss diverged - stopping early!")
-                break
+                if step > iterations:
+                    break

-            if itr > iterations:
-                break
+                if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
+                    log_line(log)
+                    log.info("loss diverged - stopping early!")
+                    # also terminate the outer while loop
+                    step = iterations
+                    break

-            with open(str(learning_rate_tsv), "a") as f:
-                f.write(
-                    f"{itr}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n"
-                )
+                with open(str(learning_rate_tsv), "a") as f:
+                    f.write(
+                        f"{step}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n"
+                    )

-        self.model.load_state_dict(model_state)
-        self.model.to(model_device)
+        self.model.load_state_dict(model_state)
+        self.model.to(flair.device)

         log_line(log)
         log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
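
For reference, below is a minimal, self-contained sketch of the two mechanics the rewritten loop relies on: the exponential learning-rate sweep and the bias-corrected moving average of the recorded loss. It is illustrative only, not code from flair. In particular, the constants, the stand-in loss function, and the assumption that ExpAnnealLR yields `lr = start_lr * (end_lr / start_lr) ** (step / iterations)` after `scheduler.step(step)` are all assumptions made for the sketch.

```python
import math

# Assumed values for illustration; in find_learning_rate these come from the
# optimizer's base learning rate, end_learning_rate, and iterations.
start_lr, end_lr, iterations = 1e-7, 10.0, 100
smoothing_factor = 0.98

moving_avg_loss = 0.0  # assumed zero-initialised, as the bias correction implies
best_loss = None
for step in range(1, iterations + 1):
    # Exponential sweep: interpolate from start_lr to end_lr in log space.
    lr = start_lr * (end_lr / start_lr) ** (step / iterations)

    # Stand-in for self.model.forward_loss(batch).item(); a real run would do
    # a forward/backward pass and an optimizer step at this learning rate.
    loss_item = 1.0 + 0.1 * math.sin(step / 10.0) + 0.01 * lr

    if step == 1:
        best_loss = loss_item
    else:
        # Exponential moving average with the same bias correction as the
        # trainer loop, so early steps are not dragged toward the zero init.
        moving_avg_loss = (
            smoothing_factor * moving_avg_loss
            + (1 - smoothing_factor) * loss_item
        )
        loss_item = moving_avg_loss / (1 - smoothing_factor ** (step + 1))
        if loss_item < best_loss:
            best_loss = loss_item

    # Divergence check mirroring the stop_early condition in the trainer loop.
    if loss_item > 4 * best_loss:
        print(f"loss diverged at step {step} (lr {lr:.3g})")
        break

    print(f"{step}\t{lr:.6g}\t{loss_item:.4f}")
```

Plotting the learning-rate column against the smoothed-loss column reproduces the curve written to `learning_rate_tsv`; a working learning rate is typically picked somewhat below the point where the smoothed loss starts to climb.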