This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

added on_backward trainer callback #5249

Merged (5 commits) on Jun 11, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `on_backward` training callback which allows for control over backpropagation and gradient manipulation.

### Fixed

- Fixed broken link in `allennlp.fairness.fairness_metrics.Separation` docs
1 change: 1 addition & 0 deletions allennlp/training/callbacks/__init__.py
@@ -4,3 +4,4 @@
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from allennlp.training.callbacks.track_epoch import TrackEpochCallback
from allennlp.training.callbacks.wandb import WandBCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback, OnBackwardException
40 changes: 40 additions & 0 deletions allennlp/training/callbacks/backward.py
@@ -0,0 +1,40 @@
from typing import Dict, TYPE_CHECKING
import torch

from allennlp.training.callbacks.callback import TrainerCallback

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


@TrainerCallback.register("mixed_precision_backward")
class MixedPrecisionBackwardCallback(TrainerCallback):
    """
    Performs backpropagation for mixed precision training.
    """

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs
    ) -> bool:
        if backward_called:
            raise OnBackwardException()
        trainer._scaler.scale(batch_outputs["loss"]).backward()  # type: ignore
        return True


class OnBackwardException(Exception):
    """
    The exception type raised if an `on_backward` callback
    attempts to call `backward` when `backward_called` is `True`.
    """

    def __init__(self, message="") -> None:
        super().__init__(
            "Backpropagation has already been performed "
            "and the computation graph has been erased, so "
            "calling `loss.backward` is not permitted. " + message
        )
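
As background for why `OnBackwardException` exists (an illustration, not code from this PR): once `backward` has run, PyTorch frees the computation graph, so a second `backward` on the same loss fails outright. A minimal, self-contained sketch:

```python
import torch

x = torch.ones(3, requires_grad=True)
loss = (x * 2).sum()

loss.backward()  # first backward succeeds and frees the graph
try:
    loss.backward()  # second backward on the same graph raises RuntimeError
except RuntimeError as err:
    print(f"second backward failed: {err}")
```

Raising `OnBackwardException` up front turns that late `RuntimeError` into an immediate, clearly named error when two callbacks both try to handle backpropagation.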
17 changes: 16 additions & 1 deletion allennlp/training/callbacks/callback.py
@@ -1,4 +1,5 @@
from typing import List, Dict, Any, Optional, TYPE_CHECKING
import torch

from allennlp.common import Registrable
from allennlp.data import TensorDict
@@ -12,7 +13,7 @@ class TrainerCallback(Registrable):
    """
    A general callback object that handles multiple events.

    This class has `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
    This class has `on_backward`, `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
    each callback type. Each one receives the state of the wrapper object as `self`.
    This enables easier state sharing between related callbacks.

@@ -33,6 +34,20 @@ def on_start(
        """
        self.trainer = trainer

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        """
        This callback hook performs backpropagation and allows for gradient manipulation.
        `backward_called` indicates whether `loss.backward` has been called prior to this callback.
        `on_backward` should return `True` if and only if `loss.backward` is called in its body.
        """
        return False

    def on_batch(
        self,
        trainer: "GradientDescentTrainer",
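
To illustrate the contract documented above (check `backward_called`, call `backward` yourself, return `True`), here is a hedged sketch of a custom callback that clips gradient norms right after backpropagation. It is not part of this PR; the registered name `clip_grad_backward` and the `max_norm` value are made up for the example, and it assumes `TrainerCallback` and `OnBackwardException` are importable from `allennlp.training.callbacks`, as the updated `__init__.py` above suggests.

```python
from typing import Dict

import torch

from allennlp.training.callbacks import TrainerCallback, OnBackwardException


@TrainerCallback.register("clip_grad_backward")  # hypothetical name, not part of this PR
class ClipGradBackwardCallback(TrainerCallback):
    """
    Calls `loss.backward()` itself, then clips gradient norms in place.
    """

    def on_backward(
        self,
        trainer,
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        if backward_called:
            # Another callback already ran backprop; the graph is gone.
            raise OnBackwardException()
        batch_outputs["loss"].backward()
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), max_norm=1.0)
        return True  # signals that `backward` has been handled
```

Like the test callbacks further down, it would be passed to the trainer via `callbacks=[ClipGradBackwardCallback(serialization_dir=...)]`.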
18 changes: 13 additions & 5 deletions allennlp/training/gradient_descent_trainer.py
@@ -19,6 +19,7 @@
from allennlp.models.model import Model
from allennlp.training.callbacks import ConsoleLoggerCallback
from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback
from allennlp.training.callbacks.backward import MixedPrecisionBackwardCallback
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
@@ -148,7 +149,7 @@ class GradientDescentTrainer(Trainer):
        parameters. This is necessary because we want the saved model to perform as well as the validated
        model if we load it later. But this may cause problems if you restart the training from checkpoint.

    callbacks : `List[Lazy[TrainerCallback]]`, optional (default = `None`)
    callbacks : `List[TrainerCallback]`, optional (default = `None`)
        A list of callbacks that can be called at certain events: e.g. each batch, epoch, and at the start
        and end of training, etc.

@@ -469,10 +470,17 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                        batch_reg_loss = reg_loss.item()
                        train_reg_loss += batch_reg_loss  # type: ignore

                if self._scaler is not None:
                    self._scaler.scale(loss).backward()
                else:
                    loss.backward()
                backward_called = False
                for callback in self._callbacks:
                    backward_called |= callback.on_backward(self, batch_outputs, backward_called)
                if not backward_called:
                    if self._scaler is not None:
                        MixedPrecisionBackwardCallback(self._serialization_dir).on_backward(
                            self, batch_outputs, backward_called
                        )
                    else:
                        loss.backward()

            if len(batch_group_outputs) <= 0:
                continue

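
For contrast with the callbacks above, a hedged sketch (not from this PR) of a callback that hooks `on_backward` without taking over backpropagation: it returns `False` (the base-class default), so `backward_called` stays `False` and the trainer falls back to the standard `loss.backward()` / scaler path shown in the diff above. The registered name is made up for the example.

```python
from typing import Dict

import torch

from allennlp.training.callbacks import TrainerCallback


@TrainerCallback.register("log_loss_before_backward")  # hypothetical name
class LogLossBeforeBackwardCallback(TrainerCallback):
    """
    Observes the loss just before backpropagation but leaves the actual
    `backward` call to the trainer or to another callback.
    """

    def on_backward(
        self,
        trainer,
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        # `.item()` only reads the scalar value; it does not consume the autograd graph.
        print(f"loss before backward: {batch_outputs['loss'].item():.4f}")
        return False  # did not call backward, so the default path still runs
```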
74 changes: 74 additions & 0 deletions tests/training/trainer_test.py
@@ -30,6 +30,7 @@
    TensorBoardCallback,
    ConfidenceChecksCallback,
    ConsoleLoggerCallback,
    OnBackwardException,
)
from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
@@ -127,6 +128,26 @@ def setup_method(self):
        self.validation_data_loader.index_with(self.vocab)


class ZeroGradientsBackwardCallback(TrainerCallback):
    """
    Zeros all gradients after backpropagation.
    """

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        if backward_called:
            raise OnBackwardException()
        batch_outputs["loss"].backward()
        for param in trainer.model.parameters():
            param.grad.data.zero_()
        return True


class TestTrainer(TrainerTestBase):
    def test_trainer_can_run(self):
        trainer = GradientDescentTrainer(
@@ -168,6 +189,59 @@ def test_trainer_can_run(self):
        assert isinstance(metrics["peak_worker_0_memory_MB"], float)
        assert metrics["peak_worker_0_memory_MB"] > 0

    def test_train_zero_gradients(self):
        weights = {}
        for name, param in self.model.named_parameters():
            weights[name] = param.data.clone()

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=2,
            validation_data_loader=self.validation_data_loader,
            callbacks=[ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR)],
        )
        trainer.train()

        # weights should be the same
        for name, param in self.model.named_parameters():
            assert torch.equal(weights[name], param.data)

    def test_two_backward_callbacks(self):
        class SecondBackwardCallback(TrainerCallback):
            """
            Changes all gradients to 1 after backpropagation.
            """

            def on_backward(
                self,
                trainer: "GradientDescentTrainer",
                batch_outputs: Dict[str, torch.Tensor],
                backward_called: bool,
                **kwargs,
            ) -> bool:
                if backward_called:
                    raise OnBackwardException()
                batch_outputs["loss"].backward()
                for param in trainer.model.parameters():
                    param.grad = torch.ones_like(param.grad, device=param.grad.device)
                return True

        with pytest.raises(OnBackwardException):
            trainer = GradientDescentTrainer(
                self.model,
                self.optimizer,
                self.data_loader,
                num_epochs=2,
                validation_data_loader=self.validation_data_loader,
                callbacks=[
                    ZeroGradientsBackwardCallback(serialization_dir=self.TEST_DIR),
                    SecondBackwardCallback(serialization_dir=self.TEST_DIR),
                ],
            )
            trainer.train()

    def test_trainer_can_run_exponential_moving_average(self):
        moving_average = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999)
        trainer = GradientDescentTrainer(