[doc] Update the order of zero_grad and backward (#6478)
* Fix zero_grad in docs

* Fix zero_grad in docs
akihironitta authored Mar 12, 2021
1 parent 518c7e4 commit 680e83a
Showing 2 changed files with 21 additions and 11 deletions.
26 changes: 17 additions & 9 deletions docs/source/common/lightning_module.rst
@@ -178,12 +178,14 @@ Under the hood, Lightning does the following (pseudocode):

     loss = training_step(batch)
     losses.append(loss.detach())

+    # clear gradients
+    optimizer.zero_grad()
+
     # backward
     loss.backward()

-    # apply and clear grads
+    # update parameters
     optimizer.step()
-    optimizer.zero_grad()

 Training epoch-level metrics
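
For context, the order the updated pseudocode adopts (clear gradients, then backward, then step) maps onto a plain PyTorch loop along the following lines; the model, optimizer, and data are illustrative placeholders, not taken from the Lightning docs:

    import torch
    from torch import nn

    # illustrative model, optimizer, and data (not from the Lightning docs)
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    batches = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]

    losses = []
    for x, y in batches:
        loss = nn.functional.mse_loss(model(x), y)
        losses.append(loss.detach())

        # clear gradients
        optimizer.zero_grad()

        # backward
        loss.backward()

        # update parameters
        optimizer.step()

Clearing gradients before ``backward()`` means each ``optimizer.step()`` uses only the current batch's gradients, which is the behaviour the corrected pseudocode describes.
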
@@ -212,12 +214,14 @@ Here's the pseudocode of what it does under the hood:

     # forward
     out = training_step(val_batch)

+    # clear gradients
+    optimizer.zero_grad()
+
     # backward
     loss.backward()

-    # apply and clear grads
+    # update parameters
     optimizer.step()
-    optimizer.zero_grad()

     epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
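
The epoch-level reduction in the last context line can be exercised on its own; the ``outs`` values below are made up purely for illustration:

    import torch

    # made-up stand-in for the `outs` list collected over one epoch
    outs = [
        {"train_loss": torch.tensor(0.9)},
        {"train_loss": torch.tensor(0.7)},
        {"train_loss": torch.tensor(0.5)},
    ]

    # same reduction as in the pseudocode above
    epoch_metric = torch.mean(torch.stack([x["train_loss"] for x in outs]))
    print(epoch_metric)  # tensor(0.7000)
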
@@ -247,12 +251,14 @@ The matching pseudocode is:

     # forward
     out = training_step(val_batch)

+    # clear gradients
+    optimizer.zero_grad()
+
     # backward
     loss.backward()

-    # apply and clear grads
+    # update parameters
     optimizer.step()
-    optimizer.zero_grad()

     training_epoch_end(outs)
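
A minimal sketch of a module that produces such an ``outs`` list and reduces it in ``training_epoch_end``, written against the 1.2-era LightningModule API (the architecture and metric name are illustrative, and the exact structure handed to ``training_epoch_end`` can vary between versions):

    import torch
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        """Illustrative module; the layer sizes and metric name are made up."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            return {"loss": loss}

        def training_epoch_end(self, outputs):
            # `outputs` gathers what training_step returned for each batch
            epoch_loss = torch.stack([o["loss"] for o in outputs]).mean()
            self.log("train_epoch_loss", epoch_loss)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)
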
@@ -946,9 +952,9 @@ When set to ``False``, Lightning does not automate the optimization process. Thi

     opt = self.optimizers(use_pl_optimizer=True)

     loss = ...
+    opt.zero_grad()
     self.manual_backward(loss)
     opt.step()
-    opt.zero_grad()
This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.
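
As a rough sketch of the manual-optimization pattern above, with ``zero_grad`` moved ahead of ``manual_backward`` as in this commit; the model, the loss, and the property used to opt out of automatic optimization are illustrative and version-dependent:

    import torch
    import pytorch_lightning as pl


    class ManualOptModel(pl.LightningModule):
        """Illustrative sketch of manual optimization with the corrected order."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 1)

        @property
        def automatic_optimization(self) -> bool:
            # one 1.2-era way to opt out of automatic optimization
            return False

        def training_step(self, batch, batch_idx):
            x, y = batch
            opt = self.optimizers(use_pl_optimizer=True)

            loss = torch.nn.functional.mse_loss(self.layer(x), y)

            # clear gradients first, then backward, then step
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)
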

Expand Down Expand Up @@ -1048,11 +1054,13 @@ This is the pseudocode to describe how all the hooks are called during a call to
loss = out.loss
on_before_zero_grad()
optimizer_zero_grad()
backward()
on_after_backward()
optimizer_step()
on_before_zero_grad()
optimizer_zero_grad()
on_train_batch_end(out)
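
To observe part of this hook order in practice, the relevant hooks can be overridden in a LightningModule; the fragment below is an illustrative sketch and omits ``training_step`` and the rest of a working module:

    import pytorch_lightning as pl


    class HookOrderModule(pl.LightningModule):
        """Fragment showing two of the hooks above; not a complete module."""

        def on_before_zero_grad(self, optimizer):
            # runs just before optimizer_zero_grad(), i.e. ahead of backward()
            print("on_before_zero_grad")

        def on_after_backward(self):
            # runs once gradients have been computed, before optimizer_step()
            print("on_after_backward")
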
6 changes: 4 additions & 2 deletions docs/source/common/trainer.rst
@@ -75,12 +75,14 @@ Here's the pseudocode for what the trainer does under the hood (showing the trai

     # train step
     loss = training_step(batch)

+    # clear gradients
+    optimizer.zero_grad()
+
     # backward
     loss.backward()

-    # apply and clear grads
+    # update parameters
     optimizer.step()
-    optimizer.zero_grad()

     losses.append(loss)
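
The loop sketched above is what ``Trainer.fit`` drives; a minimal, illustrative invocation (reusing the hypothetical ``LitModel`` from the earlier sketch, with made-up data) could look like this:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    # made-up data; LitModel is the hypothetical module sketched earlier
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    train_loader = DataLoader(dataset, batch_size=8)

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(LitModel(), train_loader)
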
