Add the on_before_optimizer_step hook (#8048)
Co-authored-by: Carlos Mocholi <[email protected]>
ddrevicky and carmocca authored Jul 9, 2021
1 parent 31fca16 commit 1b06edf
Showing 15 changed files with 91 additions and 36 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -81,6 +81,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))


- Added the `on_before_optimizer_step` hook ([#8048](https://github.com/PyTorchLightning/pytorch-lightning/pull/8048))


- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))


@@ -244,10 +247,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))


- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The `on_after_backward` hook is now called on accumulating iterations. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The mixed precision loss is no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The mixed precision loss is no longer unscaled before the `on_after_backward` hook. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
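To make the migration notes above concrete, here is a minimal sketch (not part of this commit) of moving gradient inspection out of `on_after_backward` and into the new hook; `LitModel`, its layer sizes, and the finiteness check are illustrative assumptions only.

    import warnings
    import torch
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def train_dataloader(self):
            return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        # previously done in `on_after_backward`; moved here so it runs once per
        # optimizer step and, under native AMP, sees unscaled gradients
        def on_before_optimizer_step(self, optimizer, optimizer_idx):
            grads = [p.grad.detach() for p in self.parameters() if p.grad is not None]
            total_norm = torch.norm(torch.stack([g.norm() for g in grads]))
            if not torch.isfinite(total_norm):
                warnings.warn(f"non-finite gradient norm at step {self.trainer.global_step}")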
7 changes: 7 additions & 0 deletions docs/source/common/lightning_module.rst
@@ -1195,6 +1195,7 @@ for more information.
backward()
on_after_backward()
on_before_optimizer_step()
optimizer_step()
on_train_batch_end()
@@ -1451,6 +1452,12 @@ on_test_model_train
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
:noindex:

on_before_optimizer_step
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
:noindex:

optimizer_step
~~~~~~~~~~~~~~

6 changes: 6 additions & 0 deletions docs/source/extensions/callbacks.rst
@@ -363,6 +363,12 @@ on_after_backward
.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:

on_before_optimizer_step
^^^^^^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
:noindex:

on_before_zero_grad
^^^^^^^^^^^^^^^^^^^

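The same kind of gradient inspection can also live in a standalone callback, using the `(trainer, pl_module, optimizer, opt_idx)` signature added in `callbacks/base.py` below. A sketch, not taken from this PR, with an arbitrary metric name:

    import torch
    from pytorch_lightning.callbacks import Callback


    class GradNormCallback(Callback):
        """Logs the global gradient norm right before each optimizer step."""

        def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
            norms = [p.grad.detach().norm() for p in pl_module.parameters() if p.grad is not None]
            if norms:
                # logging from this hook is allowed (see the FxValidator change further down)
                pl_module.log("grad_norm", torch.norm(torch.stack(norms)))

    # usage: Trainer(callbacks=[GradNormCallback()])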
8 changes: 7 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -302,7 +302,13 @@ def on_before_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModu
pass

def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called after ``loss.backward()`` and before optimizers do anything."""
"""Called after ``loss.backward()`` and before optimizers are stepped."""
pass

def on_before_optimizer_step(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer, opt_idx: int
) -> None:
"""Called before ``optimizer.step()``."""
pass

def on_before_zero_grad(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer) -> None:
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/lambda_function.py
@@ -79,6 +79,7 @@ def __init__(
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_optimizer_step: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
on_predict_start: Optional[Callable] = None,
on_predict_end: Optional[Callable] = None,
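For quick experiments, the new argument can also be wired up through `LambdaCallback`, which now accepts it as a keyword. A small sketch, assuming all you want is a print statement:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LambdaCallback

    # the lambda receives the same arguments as `Callback.on_before_optimizer_step`
    log_steps = LambdaCallback(
        on_before_optimizer_step=lambda trainer, pl_module, optimizer, opt_idx: print(
            f"stepping optimizer {opt_idx} at global step {trainer.global_step}"
        )
    )
    trainer = Trainer(callbacks=[log_steps], max_epochs=1)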
25 changes: 21 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -306,19 +306,36 @@ def on_before_backward(self, loss: torch.Tensor) -> None:

def on_after_backward(self) -> None:
"""
Called in the training loop after loss.backward() and before optimizers do anything.
This is the ideal place to inspect or log gradient information.
Called after ``loss.backward()`` and before optimizers are stepped.
Note:
If using native AMP, the gradients will not be unscaled at this point.
Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
"""

def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""
Called before ``optimizer.step()``.
The hook is only called if gradients do not need to be accumulated.
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
If using native AMP, the loss will be unscaled before calling this hook.
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
for more information on the scaling of gradients.
Args:
optimizer: Current optimizer being used.
optimizer_idx: Index of the current optimizer being used.
Example::
def on_after_backward(self):
def on_before_optimizer_step(self, optimizer, optimizer_idx):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
for k, v in self.named_parameters():
self.logger.experiment.add_histogram(
tag=k, values=v.grad, global_step=self.trainer.global_step
)
"""

def on_post_move_to_device(self) -> None:
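To illustrate the accumulation note in the docstring above, a sketch (reusing the toy `LitModel` from the earlier sketch, which is an assumption, not part of the commit) that counts how often each hook fires when gradients are accumulated:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback


    class HookCounter(Callback):
        def __init__(self):
            self.after_backward = 0
            self.before_optimizer_step = 0

        def on_after_backward(self, trainer, pl_module):
            self.after_backward += 1  # fires on every backward, including accumulation batches

        def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
            self.before_optimizer_step += 1  # fires only when the optimizer actually steps


    counter = HookCounter()
    trainer = Trainer(max_epochs=1, accumulate_grad_batches=4, callbacks=[counter])
    trainer.fit(LitModel())
    # expect counter.after_backward to be roughly 4x counter.before_optimizer_step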
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -90,17 +90,16 @@ def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Seq

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""
always called before the optimizer step.
"""
# apex amp does not support closures.
lambda_closure()
"""Hook to do something before each optimizer step."""
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # APEX amp does not support closures
optimizer.step(**kwargs)
return False

10 changes: 6 additions & 4 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -35,15 +35,17 @@ def __init__(self, precision: int) -> None:

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
# DeepSpeed not support closures.
lambda_closure()
deepspeed_engine = pl_module.trainer.model
"""Hook to do something before each optimizer step."""
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # DeepSpeed does not support closures
deepspeed_engine = model.trainer.model
deepspeed_engine.step()
return False

19 changes: 9 additions & 10 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -47,7 +47,7 @@ def pre_backward(

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
@@ -58,16 +58,15 @@ def pre_optimizer_step(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
# TODO: Add `on_before_optimizer_step`
# self.scaler.unscale_(optimizer)
# pl_module.trainer.call_hook("on_before_optimizer_step")
if pl_module.automatic_optimization:
result = True
if model.automatic_optimization:
result = lambda_closure()
if result is None:
# lambda_closure returning None indicates that backward has been skipped
return False
self.scaler.step(optimizer)
self.scaler.update()
self.scaler.unscale_(optimizer)
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
self.scaler.step(optimizer)
self.scaler.update()
return False

@contextmanager
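For context on why the plugin above now calls `scaler.unscale_` before handing control back, here is the equivalent ordering in plain PyTorch AMP. This is only a sketch: the model, data, and clip value are placeholders, and on a CPU-only machine the scaler degrades to a no-op.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(32, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

    with torch.cuda.amp.autocast(enabled=device == "cuda"):
        loss = model(torch.randn(16, 32, device=device)).sum()

    scaler.scale(loss).backward()                            # gradients are still scaled here ("on_after_backward")
    scaler.unscale_(optimizer)                               # unscale once, before any gradient inspection
    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)  # the kind of work done in `on_before_optimizer_step`
    scaler.step(optimizer)                                   # skips the step if any gradient is inf/nan
    scaler.update()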
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -104,13 +104,14 @@ def post_backward(

def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
return True

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
@@ -327,6 +327,13 @@ def on_after_backward(self):
for callback in self.callbacks:
callback.on_after_backward(self, self.lightning_module)

def on_before_optimizer_step(self, optimizer, optimizer_idx):
"""
Called after on_after_backward() once the gradient is accumulated and before optimizer.step().
"""
for callback in self.callbacks:
callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx)

def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad().
@@ -23,6 +23,7 @@ class FxValidator:
on_configure_sharded_model=None,
on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_optimizer_step=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
on_init_start=None,
on_init_end=None,
13 changes: 11 additions & 2 deletions tests/models/test_hooks.py
@@ -299,6 +299,10 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
using_native_amp = kwargs.get('amp_backend') == 'native'
using_deepspeed = kwargs.get('plugins') == 'deepspeed'
out = []
on_before_optimizer_step = [
dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
dict(name='on_before_optimizer_step', args=(ANY, 0)),
]
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
@@ -308,7 +312,10 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
dict(name='Callback.on_batch_start', args=(trainer, model)),
dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_train_batch_start', args=(ANY, i, 0)),
# TODO: `on_before_optimizer_step`
# these are before the training step because
# they are not part of the `training_step_and_backward` closure, however,
# with native amp, the closure is run first and then the optimizer step.
*(on_before_optimizer_step if not using_native_amp else []),
dict(name='forward', args=(ANY, )),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
Expand All @@ -321,6 +328,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device('cpu'), curre
*([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name='Callback.on_after_backward', args=(trainer, model)),
dict(name='on_after_backward'),
*(on_before_optimizer_step if using_native_amp else []),
dict(
name='optimizer_step',
args=(current_epoch, i, ANY, 0, ANY),
@@ -354,7 +362,8 @@ def _manual_train_batch(trainer, model, batches, device=torch.device('cpu'), **k
dict(name='on_after_backward'),
# `manual_backward` calls the previous 3
dict(name='manual_backward', args=(ANY, )),
# TODO: `on_before_optimizer_step`
dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
dict(name='on_before_optimizer_step', args=(ANY, 0)),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
7 changes: 1 addition & 6 deletions tests/plugins/test_amp_plugins.py
@@ -71,12 +71,7 @@ def test_amp_apex_ddp(

class GradientUnscaleBoringModel(BoringModel):

def on_after_backward(self):
# TODO: replace with `on_before_optimizer_step` so we don't need to check accumulate and unscale manually
if self.trainer.fit_loop.should_accumulate():
return
opt = self.optimizers()
self.trainer.precision_plugin.scaler.unscale_(opt)
def on_before_optimizer_step(self, *_):
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 15.
2 changes: 2 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
@@ -34,6 +34,7 @@ def test_fx_validator(tmpdir):
callbacks_func = [
'on_before_backward',
'on_after_backward',
'on_before_optimizer_step',
'on_batch_end',
'on_batch_start',
'on_before_accelerator_backend_setup',
@@ -124,6 +125,7 @@ def test_fx_validator(tmpdir):
# creating allowed condition
allowed = (
is_stage or "batch" in func_name or "epoch" in func_name or "grad" in func_name or "backward" in func_name
or "optimizer_step" in func_name
)
allowed = (
allowed and "pretrain" not in func_name and "predict" not in func_name
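Because the validator now whitelists the hook with both reductions, metrics can be logged directly from it. A brief sketch, reusing the toy `LitModel` from the first sketch on this page (an assumption, not part of the commit):

    class LoggingModel(LitModel):
        def on_before_optimizer_step(self, optimizer, optimizer_idx):
            grad_sum = sum(p.grad.abs().sum() for p in self.parameters() if p.grad is not None)
            # both `on_step` and `on_epoch` are permitted for this hook
            self.log("grad_abs_sum", grad_sum, on_step=True, on_epoch=True)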