Disable batch transfer in DP mode (#6098)
* add exceptions and test

* hook

* fix

* clean up

* clean up

* regex

* regex

* docs

* rev

* comment and docs

* chlog

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>

* Apply suggestions from code review

Co-authored-by: chaton <[email protected]>

* Monkey-patch device count

* docs

* pep

* api_change

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: chaton <[email protected]>
3 people authored Mar 11, 2021
1 parent e886d55 commit c53edce
Showing 5 changed files with 99 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))


- Disabled batch transfer in DP mode ([#6098](https://github.com/PyTorchLightning/pytorch-lightning/pull/6098))


## [1.2.0] - 2021-02-18

### Added
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/gpu.py
@@ -1,10 +1,11 @@
import logging
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
@@ -48,3 +49,11 @@ def set_nvidia_flags() -> None:
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")

def to_device(self, batch: Any) -> Any:
# no need to transfer batch to device in DP mode
# TODO: Add support to allow batch transfer to device in Lightning for DP mode.
if not isinstance(self.training_type_plugin, DataParallelPlugin):
batch = super().to_device(batch)

return batch
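For context on why the transfer can simply be skipped: `torch.nn.DataParallel` scatters the input batch onto the replica devices inside its own `forward`, so moving the whole batch to a single GPU beforehand is redundant. A minimal sketch outside Lightning, assuming a machine with two CUDA devices (the linear layer and tensor sizes are illustrative):

import torch
import torch.nn as nn

# DataParallel only requires the module's parameters to live on the source device;
# the input batch itself may stay on the CPU because forward() scatters it.
model = nn.Linear(32, 2).to("cuda:0")
dp_model = nn.DataParallel(model, device_ids=[0, 1])

batch = torch.randn(16, 32)   # CPU tensor, no explicit .to(device) needed
out = dp_model(batch)         # chunks are copied to cuda:0 / cuda:1 during scatter
print(out.device)             # outputs are gathered back onto cuda:0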
21 changes: 16 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =
Note:
This hook only runs on single GPU training and DDP (no data-parallel).
If you need multi-GPU support for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
Data-Parallel support will come in the near future.
Args:
batch: A batch of data that needs to be transferred to a new device.
@@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device):
batch = super().transfer_batch_to_device(batch, device)
return batch
Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.
See Also:
- :meth:`move_data_to_device`
- :meth:`apply_to_collection`
@@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
"""
Override to alter or apply batch augmentations to your batch before it is transferred to the device.
.. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
.. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future.
Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.
Args:
batch: A batch of data that needs to be altered or augmented.
@@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
batch['x'] = transforms(batch['x'])
return batch
Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.
See Also:
- :meth:`on_after_batch_transfer`
- :meth:`transfer_batch_to_device`
@@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.
Args:
batch: A batch of data that needs to be altered or augmented.
Expand All @@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
batch['x'] = gpu_transforms(batch['x'])
return batch
Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.
See Also:
- :meth:`on_before_batch_transfer`
- :meth:`transfer_batch_to_device`
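For comparison with the DP restriction documented above, here is a hedged, self-contained sketch of the still-supported path (single GPU or DDP): a custom batch wrapper whose tensors are moved by hand in `transfer_batch_to_device`. `CustomBatch` and `LitTransferModel` are illustrative names, not part of this commit:

from pytorch_lightning import LightningModule


class CustomBatch:
    """A non-tensor batch wrapper that Lightning cannot move automatically."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets


class LitTransferModel(LightningModule):

    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # move the wrapped tensors explicitly
            batch.inputs = batch.inputs.to(device)
            batch.targets = batch.targets.to(device)
            return batch
        # plain tensors and collections fall back to the default behaviour
        return super().transfer_batch_to_device(batch, device)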
31 changes: 17 additions & 14 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -89,6 +89,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
self.attach_datamodule(model, datamodule)
self._validate_data_hooks(model)

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -97,6 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def _validate_data_hooks(self, model):
# Raise a MisconfigurationException since these hooks are not supported in DP mode
# TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')

def attach_dataloaders(
self,
model,
@@ -127,22 +136,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N
if datamodule:

# Override loader hooks
if is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
if is_overridden('val_dataloader', datamodule):
model.val_dataloader = datamodule.val_dataloader
if is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader
if is_overridden('predict_dataloader', datamodule):
model.predict_dataloader = datamodule.predict_dataloader
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
if is_overridden('on_before_batch_transfer', datamodule):
model.on_before_batch_transfer = datamodule.on_before_batch_transfer
if is_overridden('transfer_batch_to_device', datamodule):
model.transfer_batch_to_device = datamodule.transfer_batch_to_device
if is_overridden('on_after_batch_transfer', datamodule):
model.on_after_batch_transfer = datamodule.on_after_batch_transfer
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer
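Both refactored loops and the new `_validate_data_hooks` check hinge on `is_overridden` to detect a user-provided implementation. A rough, standalone sketch of the idea behind such a check (the real Lightning utility compares the methods' underlying code objects and handles more edge cases; `is_overridden_sketch` is an illustrative name only):

def is_overridden_sketch(method_name: str, obj, parent_cls) -> bool:
    """Return True if obj's class defines its own `method_name` rather than inheriting parent_cls's."""
    child_fn = getattr(type(obj), method_name, None)
    parent_fn = getattr(parent_cls, method_name, None)
    if child_fn is None or parent_fn is None:
        return False
    # a subclass override is a different function object than the parent's
    return child_fn is not parent_fn


class Base:
    def transfer_batch_to_device(self, batch, device):
        return batch


class Child(Base):
    def transfer_batch_to_device(self, batch, device):
        return batch.to(device)


assert is_overridden_sketch("transfer_batch_to_device", Child(), Base)
assert not is_overridden_sketch("transfer_batch_to_device", Base(), Base)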
53 changes: 53 additions & 0 deletions tests/accelerators/test_dp.py
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
@@ -132,6 +135,56 @@ def training_epoch_end(self, outputs):
assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5


def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
"""
Test that an exception is raised when overriding batch_transfer_hooks in DP mode.
"""
monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

class CustomModel(BoringModel):

def transfer_batch_to_device(self, batch, device):
batch = batch.to(device)
return batch

trainer_options = dict(
default_root_dir=tmpdir,
max_steps=7,
gpus=[0, 1],
accelerator='dp',
)

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_before_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_after_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'):
trainer.fit(model)


@RunIf(min_gpus=2)
def test_dp_training_step_dict(tmpdir):
""" This test verifies that dp properly reduces dictionaries """
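The new test_dp_raise_exception_with_batch_transfer_hooks repeats the same trainer/raise pattern three times, once per hook. As a follow-up thought, a hedged sketch of how it could be collapsed with `pytest.mark.parametrize` (the dynamically built subclass and the test name are illustrative, not part of this commit):

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


@pytest.mark.parametrize(
    "hook", ["on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer"]
)
def test_dp_batch_transfer_hooks_parametrized(tmpdir, monkeypatch, hook):
    """Sketch: each parametrization overrides exactly one batch-transfer hook and expects the DP error."""
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

    # build a BoringModel subclass whose only change is a passthrough override of `hook`
    CustomModel = type("CustomModel", (BoringModel, ), {hook: lambda self, batch, *args: batch})

    trainer = Trainer(default_root_dir=tmpdir, max_steps=7, gpus=[0, 1], accelerator='dp')
    with pytest.raises(MisconfigurationException, match=f'Overriding `{hook}` is not .* in DP'):
        trainer.fit(CustomModel())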
