Commit ea5cfd2

move batch to device before sending it to hooks (#7378)
* update train step

* test

* x

* limits

* val

* typeo

* x

* x

* step

* min gpus

* run all loops

* x

* limit test

* profiler

* clean up accelerator code

* move files

* rename

* move tests

* changelog

* reorder callbacks and model hooks

* add test description

* replace unneccessary method

* fix chlog

* adjust batch_to_device for DP Plugin

* update tests for dataloader idx

* unused imports

* hook change

* switch None

* clear memory

* change to None

* None

* None

* memory savings

* remove redundant todo

* hack

* cheat

* Revert "cheat"

This reverts commit a8433bd.

* Revert "hack"

This reverts commit 43a6d1e.

* update new epoch loop

* remove from old loop code

* update chlog

* update hook test

* changelog

* teardown

* integrate changes in new eval loop

* fix hook calls

* add prediction step

* bad merge

* Revert "bad merge"

This reverts commit 4880808.

* fix train batch hook test

* rm -rf _notebooks

* update chlog

* release memory

* fix type

* notebooks mess

* debug

* Revert "debug"

This reverts commit eec4ee2.

* teardown

* fix teardown bug

* debug

* x

* debug

* Revert "debug"

This reverts commit a6e6101.

Revert "debug"

This reverts commit 5ddeaec.

debug


debug


Revert "debug"

This reverts commit 605be74.

Revert "Revert "debug""

This reverts commit a7612d5.

debug


x


x


x


s


tol


x


tol

* Fix changelog

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored Jul 5, 2021
1 parent 8193bae commit ea5cfd2
Showing 11 changed files with 133 additions and 50 deletions.
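
In short, this commit moves the batch-to-device transfer out of the per-step Accelerator wrappers and into the training/evaluation/prediction loops, so the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks already receive a device-resident batch. The sketch below illustrates the resulting call order for a training batch; it is a paraphrase, not the exact Trainer internals, and `run_training_batch`/`call_hook` are stand-ins for the real loop code.

def run_training_batch(trainer, batch, batch_idx, dataloader_idx=0):
    # the batch-transfer hooks (on_before_batch_transfer ->
    # transfer_batch_to_device -> on_after_batch_transfer) now run first,
    # inside Accelerator.batch_to_device
    batch = trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

    # the batch-start callbacks and hooks therefore see the transferred batch
    trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx)

    # the Accelerator step wrapper no longer calls to_device() itself
    output = trainer.accelerator.training_step(dict(batch=batch, batch_idx=batch_idx))

    trainer.call_hook("on_train_batch_end", output, batch, batch_idx, dataloader_idx)
    return output
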
17 changes: 9 additions & 8 deletions CHANGELOG.md
@@ -345,6 +345,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))


- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))


- Fixed moving batch to device before sending it to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks ([#7378](https://github.com/PyTorchLightning/pytorch-lightning/pull/7378))


- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))


## [1.3.8] - 2021-07-01

@@ -361,13 +369,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed SWA to also work with `IterableDataset` ([#8172](https://github.com/PyTorchLightning/pytorch-lightning/pull/8172))



- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))


- Fixed passing a custom `DDPPlugin` when choosing `accelerator="ddp_cpu"` for the accelerator ([#6208](https://github.com/PyTorchLightning/pytorch-lightning/pull/6208))


## [1.3.7] - 2021-06-22

### Fixed
@@ -377,6 +378,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed setting a `DistributedSampler` when using a distributed plugin in a custom accelerator ([#7814](https://github.com/PyTorchLightning/pytorch-lightning/pull/7814))
- Improved `PyTorchProfiler` chrome traces names ([#8009](https://github.com/PyTorchLightning/pytorch-lightning/pull/8009))
- Fixed moving the best score to device in `EarlyStopping` callback for TPU devices ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959))
- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))


## [1.3.6] - 2021-06-15
@@ -387,7 +389,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
- Fixes access to `callback_metrics` in ddp_spawn ([#7916](https://github.com/PyTorchLightning/pytorch-lightning/pull/7916))


## [1.3.5] - 2021-06-08
4 changes: 2 additions & 2 deletions benchmarks/test_basic_parity.py
@@ -45,8 +45,8 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f
@pytest.mark.parametrize(
'cls_model,max_diff_speed,max_diff_memory',
[
(ParityModuleRNN, 0.05, 0.0),
(ParityModuleMNIST, 0.25, 0.0), # todo: lower this thr
(ParityModuleRNN, 0.05, 0.001),
(ParityModuleMNIST, 0.25, 0.001), # todo: lower this thr
]
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
20 changes: 3 additions & 17 deletions pytorch_lightning/accelerators/accelerator.py
@@ -22,6 +22,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
@@ -173,8 +174,8 @@ def batch_to_device(
dataloader_idx: The index of the dataloader to which the batch belongs.
"""
model = self.lightning_module

if model is not None:
if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
# no need to transfer batch to device in DP mode
return model._apply_batch_transfer_handler(batch, device, dataloader_idx)

return move_data_to_device(batch, device)
@@ -195,8 +196,6 @@ def training_step(
- hiddens(:class:`~torch.Tensor`): Passed in if
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
return self.training_type_plugin.training_step(*step_kwargs.values())

@@ -215,8 +214,6 @@ def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[S
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple val dataloaders used)
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
return self.training_type_plugin.validation_step(*step_kwargs.values())

@@ -232,8 +229,6 @@ def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OU
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple test dataloaders used).
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
return self.training_type_plugin.test_step(*step_kwargs.values())

@@ -249,8 +244,6 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
- dataloader_idx (int): The index of the dataloader that produced this batch
(only if multiple predict dataloaders used).
"""
step_kwargs = self.to_device(step_kwargs)

with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*step_kwargs.values())

@@ -371,13 +364,6 @@ def setup_precision_plugin(self) -> None:
self.optimizers = optimizers
self.schedulers = schedulers

def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]:
"""Pushes the batch to the root device"""
step_kwargs['batch'] = self.batch_to_device(
step_kwargs['batch'], self.root_device, dataloader_idx=step_kwargs.get('dataloader_idx', None)
)
return step_kwargs

@property
def amp_backend(self) -> Optional[LightningEnum]:
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
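
Together with the gpu.py change below, the DP special case now lives in `Accelerator.batch_to_device` itself rather than in a `GPUAccelerator.to_device` override, and the removed `to_device` helper is no longer needed because the loops call `batch_to_device` directly. Roughly, the resulting method reads as follows; this is a paraphrase of the hunks above with the class body and docstring trimmed, and the `device` default handling is assumed.

def batch_to_device(self, batch, device=None, dataloader_idx=None):
    # fall back to the accelerator's root device when none is given (assumed)
    device = device or self.root_device
    model = self.lightning_module
    if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
        # no need to transfer the batch to the device in DP mode; otherwise
        # delegate to the LightningModule so that on_before_batch_transfer,
        # transfer_batch_to_device and on_after_batch_transfer all run
        return model._apply_batch_transfer_handler(batch, device, dataloader_idx)
    return move_data_to_device(batch, device)
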
10 changes: 0 additions & 10 deletions pytorch_lightning/accelerators/gpu.py
@@ -13,13 +13,11 @@
# limitations under the License.
import logging
import os
from typing import Any, Dict, Union

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)
@@ -51,11 +49,3 @@ def set_nvidia_flags(local_rank: int) -> None:
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def to_device(self, step_kwargs: Dict[str, Union[Any, int]]) -> Dict[str, Union[Any, int]]:
# no need to transfer batch to device in DP mode
# TODO: Add support to allow batch transfer to device in Lightning for DP mode.
if not isinstance(self.training_type_plugin, DataParallelPlugin):
step_kwargs = super().to_device(step_kwargs)

return step_kwargs
4 changes: 4 additions & 0 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -139,6 +139,10 @@ def advance(self, batch, batch_idx, dataloader_idx):
if result:
self.batch_outputs[0].append(result.training_step_output)

def teardown(self) -> None:
# release memory
self._remaining_splits = None

def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""Gets the number of active optimizers based on their frequency"""
return len(self.get_active_optimizers(batch_idx))
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -100,6 +100,9 @@ def advance(
if batch is None:
raise StopIteration

with self.trainer.profiler.profile("evaluation_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

# hook
self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -83,6 +83,9 @@ def advance(
if batch is None:
raise StopIteration

with self.trainer.profiler.profile("predict_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

with self.trainer.profiler.profile("predict_step"):
self._predict_step(batch, batch_idx, dataloader_idx)

3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -104,6 +104,9 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx)

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx)
self.batches_seen += 1
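
Each loop also wraps the new transfer in its own profiler action ("training_batch_to_device", "evaluation_batch_to_device", "predict_batch_to_device"), so the transfer cost shows up separately in profiler reports. A hedged usage sketch, assuming the BoringModel helper from this repository's tests:

from pytorch_lightning import Trainer
from tests.helpers import BoringModel

trainer = Trainer(profiler="simple", max_steps=2, limit_train_batches=2, limit_val_batches=2)
trainer.fit(BoringModel())
# the printed summary should now include rows such as
# "training_batch_to_device" and "evaluation_batch_to_device"
# next to run_training_batch and the step timings
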
24 changes: 12 additions & 12 deletions tests/models/test_hooks.py
@@ -287,13 +287,13 @@ def _train_batch(trainer, model, batches, current_epoch=0):
out = []
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
# TODO: `on_batch_{start,end}`
dict(name='Callback.on_batch_start', args=(trainer, model)),
dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_train_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
@@ -338,12 +338,12 @@ def _eval_batch(fn, trainer, model, batches, key):
outputs = {key: ANY}
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name=f'Callback.on_{fn}_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name=f'on_{fn}_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name=f'{fn}_step', args=(ANY, i)),
dict(name=f'{fn}_step_end', args=(outputs, )),
@@ -358,11 +358,11 @@ def _predict_batch(trainer, model, batches):
for i in range(batches):
out.extend([
# TODO: `{,Callback}.on_batch_{start,end}`
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
dict(name='Callback.on_predict_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_predict_batch_start', args=(ANY, i, 0)),
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='forward', args=(ANY, )),
dict(name='predict_step', args=(ANY, i)),
# TODO: `predict_step_end`
@@ -777,9 +777,9 @@ def call(hook, fn, *args, **kwargs):
dm = HookedDataModule(called)
trainer.fit(model, datamodule=dm)
batch_transfer = [
dict(name='on_before_batch_transfer', args=(ANY, None)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), None)),
dict(name='on_after_batch_transfer', args=(ANY, None)),
dict(name='on_before_batch_transfer', args=(ANY, 0)),
dict(name='transfer_batch_to_device', args=(ANY, torch.device('cpu'), 0)),
dict(name='on_after_batch_transfer', args=(ANY, 0)),
]
expected = [
dict(name='prepare_data'),
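
One behavioural detail is visible in the updated expectations above: because the loops now call `batch_to_device` themselves, the batch-transfer hooks fire before `on_*_batch_start` rather than after, and they receive an integer `dataloader_idx` (0 for the first dataloader) instead of `None` during training. A small illustrative sketch of what user overrides can rely on (the class name is hypothetical):

from tests.helpers import BoringModel

class TransferAwareModel(BoringModel):

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # previously None while training; now 0 for the first dataloader
        assert dataloader_idx == 0
        return batch

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # the transfer hooks have already run, so the batch is on self.device
        assert batch.device == self.device
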
93 changes: 93 additions & 0 deletions tests/trainer/loops/test_all.py
@@ -0,0 +1,93 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


class BatchHookObserverCallback(Callback):

def on_train_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_train_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device

def on_validation_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device

def on_test_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device

def on_predict_batch_start(self, trainer, pl_module, batch, *args):
assert batch.device == pl_module.device

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, *args):
assert batch.device == pl_module.device


class BatchHookObserverModel(BoringModel):

def on_train_batch_start(self, batch, *args):
assert batch.device == self.device

def on_train_batch_end(self, outputs, batch, *args):
assert batch.device == self.device

def on_validation_batch_start(self, batch, *args):
assert batch.device == self.device

def on_validation_batch_end(self, outputs, batch, *args):
assert batch.device == self.device

def on_test_batch_start(self, batch, *args):
assert batch.device == self.device

def on_test_batch_end(self, outputs, batch, *args):
assert batch.device == self.device

def on_predict_batch_start(self, batch, *args):
assert batch.device == self.device

def on_predict_batch_end(self, outputs, batch, *args):
assert batch.device == self.device


@RunIf(min_gpus=1)
def test_callback_batch_on_device(tmpdir):
""" Test that the batch object sent to the on_*_batch_start/end hooks is on the right device."""

batch_callback = BatchHookObserverCallback()

model = BatchHookObserverModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=1,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
limit_predict_batches=1,
gpus=1,
callbacks=[batch_callback],
)
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)
2 changes: 1 addition & 1 deletion tests/trainer/test_dataloaders.py
@@ -1533,7 +1533,7 @@ def __init__(self):

def assert_dataloader_idx_hook(self, dataloader_idx):
if self.trainer.training:
assert dataloader_idx is None
assert dataloader_idx == 0
elif self.trainer.validating:
assert dataloader_idx == (0 if self.val_call_count <= 5 else 1)
elif self.trainer.testing:
