diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index ec257bf444f5c..f6deb9adf58d3 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -178,12 +178,14 @@ Under the hood, Lightning does the following (pseudocode):
         loss = training_step(batch)
         losses.append(loss.detach())
 
+        # clear gradients
+        optimizer.zero_grad()
+
         # backward
         loss.backward()
 
-        # apply and clear grads
+        # update parameters
         optimizer.step()
-        optimizer.zero_grad()
 
 
 Training epoch-level metrics
@@ -212,12 +214,14 @@ Here's the pseudocode of what it does under the hood:
         # forward
         out = training_step(val_batch)
 
+        # clear gradients
+        optimizer.zero_grad()
+
         # backward
         loss.backward()
 
-        # apply and clear grads
+        # update parameters
         optimizer.step()
-        optimizer.zero_grad()
 
     epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
 
@@ -247,12 +251,14 @@ The matching pseudocode is:
         # forward
         out = training_step(val_batch)
 
+        # clear gradients
+        optimizer.zero_grad()
+
         # backward
         loss.backward()
 
-        # apply and clear grads
+        # update parameters
         optimizer.step()
-        optimizer.zero_grad()
 
     training_epoch_end(outs)
 
@@ -946,9 +952,9 @@ When set to ``False``, Lightning does not automate the optimization process. Thi
         opt = self.optimizers(use_pl_optimizer=True)
 
         loss = ...
+        opt.zero_grad()
         self.manual_backward(loss)
         opt.step()
-        opt.zero_grad()
 
 This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.
 
@@ -1048,11 +1054,13 @@ This is the pseudocode to describe how all the hooks are called during a call to
 
         loss = out.loss
 
+        on_before_zero_grad()
+        optimizer_zero_grad()
+
         backward()
         on_after_backward()
+
         optimizer_step()
-        on_before_zero_grad()
-        optimizer_zero_grad()
 
         on_train_batch_end(out)
 
diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst
index 6edf896ada01c..5ecd90569f9fa 100644
--- a/docs/source/common/trainer.rst
+++ b/docs/source/common/trainer.rst
@@ -75,12 +75,14 @@ Here's the pseudocode for what the trainer does under the hood (showing the trai
         # train step
         loss = training_step(batch)
 
+        # clear gradients
+        optimizer.zero_grad()
+
         # backward
         loss.backward()
 
-        # apply and clear grads
+        # update parameters
         optimizer.step()
-        optimizer.zero_grad()
 
         losses.append(loss)
 
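
For context, a minimal self-contained sketch of the update order these docs now describe (clear gradients, then backward, then step), written in plain PyTorch; the model, data, and optimizer below are invented purely for illustration and are not part of the patch:

.. code-block:: python

    import torch
    from torch import nn

    # toy model, data, and optimizer; purely illustrative, not part of the patch
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(64, 10), torch.randn(64, 1)

    for _ in range(5):
        # clear gradients
        optimizer.zero_grad()

        # forward
        loss = nn.functional.mse_loss(model(x), y)

        # backward
        loss.backward()

        # update parameters
        optimizer.step()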