[hotfix] ddp + manual_optimisation (#4976)
* Rely on ddp plugin for blocking sync behaviour, and skip if we're using manual optimization

* debug

* Revert "debug"

This reverts commit ccca6b6

* Expose manual reduce for automatic optimization

* Add input arguments

* Enable parity test

* clean imports

* Expose hook after to ensure we reset

* Fix naming

* add

* fix test

* resolve on comments

* typo

* Update tests/trainer/optimization/test_manual_optimization.py

Co-authored-by: Jirka Borovec <[email protected]>

* Update tests/trainer/optimization/test_manual_optimization.py

Co-authored-by: Jirka Borovec <[email protected]>

* update on comments

* resolve comments

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
4 people authored Dec 7, 2020
1 parent 68ba493 commit 2393474
Showing 8 changed files with 196 additions and 17 deletions.
8 changes: 4 additions & 4 deletions benchmarks/test_sharded_parity.py
@@ -148,7 +148,6 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
)


@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@@ -182,16 +181,17 @@ def training_step(self, batch, batch_idx, optimizer_idx):
loss_1 = self.step(batch)

self.manual_backward(loss_1, opt_a)
self.manual_optimizer_step(opt_a)
opt_a.step()

# fake discriminator
loss_2 = self.step(batch[0])

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)
self.manual_optimizer_step(opt_b)
# todo: understand why synchronization breaks there.
# self.manual_backward(loss_2, opt_a, retain_graph=True)
opt_b.step()

assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0)

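For reference, a condensed sketch of the user-facing pattern this fix targets. It is not code from this commit: it assumes a `LightningModule` with a `self.step(batch)` helper that returns a loss (as in the parity model above) and `Trainer(automatic_optimization=False, accelerator="ddp")`; with the changes in this PR every `manual_backward` triggers its own DDP reduction, so the plain `opt.step()` calls shown in the diff are sufficient (gradient zeroing omitted for brevity).

```python
def training_step(self, batch, batch_idx, optimizer_idx):
    opt_a, opt_b = self.optimizers()

    # first loss/optimizer: gradients are all-reduced on this manual backward
    loss_a = self.step(batch)
    self.manual_backward(loss_a, opt_a)
    opt_a.step()

    # second loss/optimizer in the same training_step gets its own
    # reducer preparation and its own reduction
    loss_b = self.step(batch)
    self.manual_backward(loss_b, opt_b)
    opt_b.step()
```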
21 changes: 21 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union

@@ -86,6 +87,12 @@ def process_dataloader(self, dataloader):
return dataloader

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
automatic_optimization = self.trainer.train_loop.automatic_optimization

if not automatic_optimization and self.ddp_plugin is not None:
# Manually prepare for reduce, as the user is calling backward manually
self.ddp_plugin.on_before_manual_backward(self.trainer.model, closure_loss)

if self.trainer.precision == 16:
closure_loss = self.trainer.precision_connector.backend.backward(
closure_loss, optimizer, opt_idx, *args, **kwargs
@@ -97,6 +104,10 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

# once backward has been applied, release graph
closure_loss = closure_loss.detach()

if not automatic_optimization and self.ddp_plugin is not None:
# Reset the reducer state now that the manual backward has run
self.ddp_plugin.on_after_manual_backward(self.trainer.model)
return closure_loss

def clip_gradients(self, optimizer, clip_val=None):
@@ -211,6 +222,16 @@ def __setstate__(self, d):
def on_save(self, checkpoint):
return checkpoint

@contextmanager
def block_ddp_plugin_sync_behaviour(self):
"""
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead.
Returns: context manager with sync behaviour off
"""
cm = self.ddp_plugin.block_backward_sync(self.trainer.model) if self.ddp_plugin else None
yield cm


# TODO: allow user to compare with string even though internally we shall use these Enums to prevent typos...
class BackendType(Enum):
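The two hooks added to `Accelerator.backward` above bracket every user-triggered backward. A minimal sketch of that bracketing as a hypothetical helper (only `on_before_manual_backward` and `on_after_manual_backward` come from this diff; the helper name and arguments are illustrative):

```python
from contextlib import contextmanager


@contextmanager
def prepared_manual_backward(ddp_plugin, model, output):
    # register the DDP reducer hooks for the graph that produced `output`
    ddp_plugin.on_before_manual_backward(model, output)
    try:
        yield
    finally:
        # clear the prepared flag so the next backward can register fresh hooks
        ddp_plugin.on_after_manual_backward(model)


# roughly what the accelerator does when automatic optimization is off:
#   with prepared_manual_backward(self.ddp_plugin, self.trainer.model, closure_loss):
#       closure_loss.backward(*args, **kwargs)
```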
15 changes: 12 additions & 3 deletions pytorch_lightning/overrides/data_parallel.py
@@ -161,6 +161,7 @@ def parallel_apply(self, replicas, inputs, kwargs):

def forward(self, *inputs, **kwargs): # pragma: no-cover
self._sync_params()
self.reducer_reset_hooks()
fx_called: str = ''

if self.device_ids:
@@ -194,6 +195,15 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
else:
output = self.module.validation_step(*inputs, **kwargs)

if not self._reducer_prepared_for_backwards:
self.reducer_prepare_for_backwards(output)

if output is None:
warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')
return output

def reducer_prepare_for_backwards(self, output):
self._reducer_prepared_for_backwards = True
if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
@@ -205,9 +215,8 @@ def forward(self, *inputs, **kwargs): # pragma: no-cover
else:
self.reducer.prepare_for_backward([])

if output is None:
warn_missing_output(f'{fx_called} returned None. Did you forget to re')
return output
def reducer_reset_hooks(self):
self._reducer_prepared_for_backwards = False


def warn_missing_output(fx_called):
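To make the flag semantics above explicit, a toy stand-in (not the Lightning class; the real override delegates to torch's internal `reducer.prepare_for_backward`) showing when preparation happens and when it is reset:

```python
class ReducerFlagSketch:
    """Toy model of the one-shot preparation flag used by the override above."""

    def __init__(self):
        self._reducer_prepared_for_backwards = False

    def reducer_reset_hooks(self):
        # called at the top of every forward, and after each manual backward
        self._reducer_prepared_for_backwards = False

    def reducer_prepare_for_backwards(self, output):
        # called at the end of forward (automatic optimization) or from
        # on_before_manual_backward (manual optimization)
        self._reducer_prepared_for_backwards = True
        # the real class calls self.reducer.prepare_for_backward(...) here
```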
18 changes: 17 additions & 1 deletion pytorch_lightning/plugins/ddp_plugin.py
@@ -1,5 +1,6 @@
import os
from typing import Any, Dict, List, Optional, Union
from contextlib import contextmanager
from typing import Any, Dict, List, Union, Optional

import torch.distributed as torch_distrib
from torch.optim import Optimizer
@@ -132,3 +133,18 @@ def get_model_from_plugin(
if isinstance(model, LightningDistributedDataParallel):
return model.module
return model

@contextmanager
def block_backward_sync(self, model: LightningDistributedDataParallel):
"""
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead.
Returns: context manager with sync behaviour off
"""
yield model.no_sync()

def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
model.reducer_prepare_for_backwards(output)

def on_after_manual_backward(self, model: LightningDistributedDataParallel):
model.reducer_reset_hooks()
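As a usage illustration of `block_backward_sync`, a self-contained sketch of the gradient-accumulation pattern it enables. It assumes an already initialized process group and a DDP-wrapped model, and it enters `no_sync()` directly rather than going through the plugin:

```python
from contextlib import contextmanager

from torch.nn.parallel import DistributedDataParallel


@contextmanager
def block_backward_sync(model: DistributedDataParallel):
    # same idea as DDPPlugin.block_backward_sync: suppress the gradient all-reduce
    with model.no_sync():
        yield


def accumulation_loop(model, optimizer, batches, accumulate_grad_batches=2):
    for i, batch in enumerate(batches):
        if (i + 1) % accumulate_grad_batches != 0:
            # accumulation step: forward + backward under no_sync keeps
            # gradients rank-local and skips the communication
            with block_backward_sync(model):
                model(batch).sum().backward()
        else:
            # boundary step: this forward/backward all-reduces the
            # accumulated gradients across ranks
            model(batch).sum().backward()
            optimizer.step()
            optimizer.zero_grad()
```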
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/sharded_plugin.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from typing import List, Optional, Union, Any

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
@@ -94,3 +94,9 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list:
if amp_backend == AMPType.NATIVE:
return [ShardedNativeAMPPlugin(trainer=trainer)]
return []

def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any):
pass

def on_after_manual_backward(self, model: 'LightningShardedDataParallel'):
pass
28 changes: 23 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -679,9 +679,15 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# calculate loss (train step + train step end)
# -------------------

# perform dpp sync only when performing optimizer_step
# automatic_optimization=True: perform ddp sync only when performing optimizer_step
# automatic_optimization=False: don't block synchronization here
with self.block_ddp_sync_behaviour():
self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
self.training_step_and_backward(
split_batch,
batch_idx,
opt_idx,
optimizer,
self.trainer.hiddens)

batch_outputs = self._process_closure_result(
batch_outputs=batch_outputs,
@@ -743,10 +749,22 @@ def train_step_and_backward_closure():

@contextmanager
def block_ddp_sync_behaviour(self):
if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
yield self.trainer.model.no_sync()
"""
automatic_optimization = True
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead.
automatic_optimization = False
Do not block ddp gradient sync when using manual optimization,
as gradients are needed within the training step.
Returns: context manager with sync behaviour off
"""
if self.trainer.accelerator_backend is not None and self.automatic_optimization:
yield self.trainer.accelerator_backend.block_ddp_plugin_sync_behaviour()
else:
yield
yield None

def _process_closure_result(
self, batch_outputs: list, opt_idx: int
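For contrast with the automatic path, a short plain-PyTorch sketch (assuming an initialized process group and a DDP-wrapped model) of why the loop leaves synchronization alone under manual optimization: the `step()` that follows a manual backward inside the same `training_step` needs gradients that have already been reduced across ranks.

```python
from torch.nn.parallel import DistributedDataParallel


def manual_optimization_step(model: DistributedDataParallel, optimizer, batch):
    loss = model(batch).sum()
    # deliberately NOT wrapped in model.no_sync(): the all-reduce must run on
    # this backward, because step() consumes the gradients right below
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```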
4 changes: 2 additions & 2 deletions tests/special_tests.sh
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export PL_RUNNING_SPECIAL_TESTS=1
# Running special tests
export PL_RUNNING_SPECIAL_TESTS=1
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
111 changes: 110 additions & 1 deletion tests/trainer/optimization/test_manual_optimization.py
@@ -18,6 +18,7 @@

import pytest
import torch
import torch.distributed as torch_distrib
import torch.nn.functional as F

from pytorch_lightning import Trainer, seed_everything
@@ -862,7 +863,7 @@ def dis_closure():
self.manual_backward(loss_dis, opt_dis)

# this will accumulate gradients for 2 batches and then call opt_gen.step()
opt_gen.step(closure=gen_closure, make_optimizer_step=batch_idx % 2 == 0, optim='sgd')
opt_gen.step(closure=gen_closure, make_optimizer_step=(batch_idx % 2 == 0), optim='sgd')

# update discriminator every 4 batches
# therefore, no gradient accumulation for discriminator
@@ -904,6 +905,114 @@ def configure_optimizers(self):
mock_adam_step.assert_has_calls(expected_calls)


@patch("torch.optim.Adam.step")
@patch("torch.optim.SGD.step")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_step, mock_adam_step, tmpdir):
"""
Tests that `step` works with optimizer_closure and different accumulated_gradient frequency
"""
os.environ['PL_DEV_DEBUG'] = '1'

class TestModel(BoringModel):

def loss_ones(self, batch, prediction):
# An arbitrary loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

def loss_zeros(self, batch, prediction):
# An arbitrary loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(prediction, torch.zeros_like(prediction))

def manual_sync_grad(self) -> bool:
torch_distrib.all_reduce(self.layer.weight.grad.data, async_op=False)
return True

def training_step(self, batch, batch_idx, optimizer_idx):

# emulate gans training
opt_gen, opt_dis = self.optimizers()

# Note: Be careful not to log on the same key with self.log in both closures,
# as they will be aggregated together on epoch_end

world_size = torch_distrib.get_world_size(torch_distrib.group.WORLD)
assert world_size == 2

def compute_loss():
x = batch[0]
x = F.dropout(x, 0.1)
predictions = self(x)
predictions = F.dropout(predictions, 0.1)
loss_ones = self.loss_ones(None, predictions)
loss_zeros = self.loss_zeros(None, predictions)
return loss_ones, loss_zeros

def make_manual_backward(loss, opt, retain_graph=False):
self.manual_backward(loss, opt, retain_graph=retain_graph)
grad_clone = self.layer.weight.grad.clone()
assert self.manual_sync_grad()
self.layer.weight.grad /= world_size
assert torch.equal(self.layer.weight.grad, grad_clone)

def gen_closure():
loss_ones_gen, loss_zeros = compute_loss()
make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True)
make_manual_backward(loss_ones_gen, opt_gen)

def dis_closure():
loss_ones_gen, loss_zeros = compute_loss()
make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True)
make_manual_backward(loss_ones_gen, opt_dis)

# this will accumulate gradients for 2 batches and then call opt_gen.step()
opt_gen.step(closure=gen_closure, make_optimizer_step=batch_idx % 2 == 0, optim='sgd')

# update discriminator every 4 batches
# therefore, no gradient accumulation for discriminator
if batch_idx % 4 == 0:
# Note: set make_optimizer_step to True, otherwise it will default to
# Trainer(accumulate_grad_batches=x) behaviour
opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam')

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
assert len(outputs) == 2

def configure_optimizers(self):
optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
return [optimizer_gen, optimizer_dis]

seed_everything(42)

model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None

limit_train_batches = 8
trainer = Trainer(
automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
accumulate_grad_batches=2,
enable_pl_optimizer=True,
gpus=2,
accelerator="ddp",
)

trainer.fit(model)
expected_calls = [call(closure=ANY, optim='sgd')] * 4
mock_sgd_step.assert_has_calls(expected_calls)

expected_calls = [call(closure=ANY, optim='adam')] * 2
mock_adam_step.assert_has_calls(expected_calls)


def test_step_with_misconfiguraiton_error_when_overriding_optimizer_zero_grad(tmpdir):
"""
Tests that `optimizer_zero_grad` in manual_optimization triggers a MisconfigurationException
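The `make_manual_backward` helper in the new test double-checks the reduction by re-averaging the gradient by hand: because DDP has already averaged it, every rank holds the same value, so an explicit all-reduce followed by a division by the world size must be a no-op. A standalone sketch of that check (assumes an initialized process group and a parameter whose `.grad` was just produced by a DDP backward):

```python
import torch
import torch.distributed as dist


def assert_grad_already_averaged(param: torch.nn.Parameter) -> None:
    world_size = dist.get_world_size()
    before = param.grad.clone()              # value written by DDP's reducer
    dist.all_reduce(param.grad.data, async_op=False)
    param.grad /= world_size                 # manual average across ranks
    # equal only if the manual backward was actually reduced by DDP
    assert torch.equal(param.grad, before)
```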
