Commit 33be769
Merge pull request #1119 from zalandoresearch/GH-986-find-learning-rate
GH-986: fix batch step in learning rate finder
Alan Akbik authored Sep 17, 2019
2 parents 568b690 + 666eb07 commit 33be769
Showing 1 changed file with 48 additions and 37 deletions.
85 changes: 48 additions & 37 deletions flair/trainers/trainer.py
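The change below replaces the single pass over one DataLoader with an outer while-loop that keeps drawing fresh batches until `iterations` steps have run, advances the scheduler with the running `step` instead of calling `scheduler.step(1)` on every batch, and restores the model to `flair.device` rather than a cached `model_device`. `ExpAnnealLR` itself is not part of this diff; the sketch that follows is only an approximation of the exponential-annealing schedule such a scheduler presumably implements (class name, signature, and formula are assumptions, not code from this commit):

# Hedged sketch, not from this commit: an exponentially growing learning-rate
# schedule of the kind the finder steps through. If the schedule is driven by
# scheduler.step(1) on every batch it stays stuck at its first step; passing
# the running step index lets the rate anneal from the base LR toward end_lr.
import torch


class ExpAnnealLRSketch(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, end_lr: float, iterations: int, last_epoch: int = -1):
        self.end_lr = end_lr
        self.iterations = iterations
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Geometric interpolation between each base LR and end_lr.
        progress = self.last_epoch / self.iterations
        return [base_lr * (self.end_lr / base_lr) ** progress for base_lr in self.base_lrs]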
@@ -564,52 +564,63 @@ def find_learning_rate(

         train_data = self.corpus.train

-        batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
-
         scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)

         model_state = self.model.state_dict()
-        model_device = next(self.model.parameters()).device
         self.model.train()

-        for itr, batch in enumerate(batch_loader):
-            loss = self.model.forward_loss(batch)
-
-            optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
-            optimizer.step()
-            scheduler.step(1)
-            learning_rate = scheduler.get_lr()[0]
-
-            loss_item = loss.item()
-            if itr == 0:
-                best_loss = loss_item
-            else:
-                if smoothing_factor > 0:
-                    moving_avg_loss = (
-                        smoothing_factor * moving_avg_loss
-                        + (1 - smoothing_factor) * loss_item
-                    )
-                    loss_item = moving_avg_loss / (1 - smoothing_factor ** (itr + 1))
-                if loss_item < best_loss:
-                    best_loss = loss
+        step = 0
+        while step < iterations:
+            batch_loader = DataLoader(
+                train_data, batch_size=mini_batch_size, shuffle=True
+            )
+            for batch in batch_loader:
+                step += 1
+
+                # forward pass
+                loss = self.model.forward_loss(batch)
+
+                # update optimizer and scheduler
+                optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
+                optimizer.step()
+                scheduler.step(step)
+
+                print(scheduler.get_lr())
+                learning_rate = scheduler.get_lr()[0]
+
+                loss_item = loss.item()
+                if step == 1:
+                    best_loss = loss_item
+                else:
+                    if smoothing_factor > 0:
+                        moving_avg_loss = (
+                            smoothing_factor * moving_avg_loss
+                            + (1 - smoothing_factor) * loss_item
+                        )
+                        loss_item = moving_avg_loss / (
+                            1 - smoothing_factor ** (step + 1)
+                        )
+                    if loss_item < best_loss:
+                        best_loss = loss

-            if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
-                log_line(log)
-                log.info("loss diverged - stopping early!")
-                break
+                if step > iterations:
+                    break

-            if itr > iterations:
-                break
+                if stop_early and (loss_item > 4 * best_loss or torch.isnan(loss)):
+                    log_line(log)
+                    log.info("loss diverged - stopping early!")
+                    step = iterations
+                    break

-            with open(str(learning_rate_tsv), "a") as f:
-                f.write(
-                    f"{itr}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n"
-                )
+                with open(str(learning_rate_tsv), "a") as f:
+                    f.write(
+                        f"{step}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n"
+                    )

-        self.model.load_state_dict(model_state)
-        self.model.to(model_device)
+        self.model.load_state_dict(model_state)
+        self.model.to(flair.device)

         log_line(log)
         log.info(f"learning rate finder finished - plot {learning_rate_tsv}")

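For context, a usage sketch of the learning-rate finder this method implements. The corpus, model, and trainer setup below are assumptions for illustration; only the parameters visible in the diff (end_learning_rate, iterations, mini_batch_size) and the TSV output are taken from this commit:

# Hedged usage sketch, not taken from this commit.
from flair.datasets import UD_ENGLISH        # assumed example corpus
from flair.models import SequenceTagger      # assumed example model
from flair.trainers import ModelTrainer

corpus = UD_ENGLISH()
tagger = SequenceTagger.load("pos")          # assumed pre-trained tagger
trainer = ModelTrainer(tagger, corpus)

# With this fix the finder can keep drawing batches until `iterations` steps
# have run, even when that is more than one pass over corpus.train.
lr_tsv = trainer.find_learning_rate(
    "resources/lr_finder",                   # assumed base_path for the TSV output
    iterations=200,
    mini_batch_size=32,
    end_learning_rate=10,
)
# Assumed to return the path of the TSV it writes (step, timestamp, lr, loss).
print(f"learning rate data written to {lr_tsv}")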