From 8d7ca5cd2ca0c6171747fd5b4dd6789542675db3 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Aug 2020 09:22:05 -0400 Subject: [PATCH] ref: refactored gpu backend __step (#3120) * refactored gpu backend __step * refactored gpu backend __step * refactored gpu backend __step * refactored gpu backend __step --- pytorch_lightning/accelerators/gpu_backend.py | 36 +++++++++++++++++-- pytorch_lightning/trainer/evaluation_loop.py | 11 +++--- pytorch_lightning/trainer/training_loop.py | 11 +----- pytorch_lightning/utilities/seed.py | 1 + tests/core/test_results.py | 4 +-- 5 files changed, 43 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index ea0057dcc13ee..68751cb79c8ca 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -15,6 +15,7 @@ import torch from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import AMPType +from pytorch_lightning.accelerators.base_backend import Accelerator try: from apex import amp @@ -22,11 +23,11 @@ amp = None -class GPUBackend(object): +class GPUBackend(Accelerator): amp_backend: AMPType def __init__(self, trainer): - self.trainer = trainer + super().__init__(trainer) def setup(self, model): @@ -51,6 +52,37 @@ def train(self, model): results = self.trainer.run_pretrain_routine(model) return results + def training_step(self, args): + batch = args[0] + batch = self.to_device(batch) + args[0] = batch + output = self.trainer.model.training_step(*args) + return output + + def validation_step(self, args): + batch = args[0] + batch = self.to_device(batch) + args[0] = batch + output = self.trainer.model.validation_step(*args) + return output + + def test_step(self, args): + batch = args[0] + batch = self.to_device(batch) + args[0] = batch + output = self.trainer.model.test_step(*args) + return output + + def to_device(self, batch): + gpu_id = 0 + if isinstance(self.trainer.data_parallel_device_ids, list): + gpu_id = self.trainer.data_parallel_device_ids[0] + + # Don't copy the batch since there is a single gpu that the batch could + # be referenced from and if there are multiple optimizers the batch will + # wind up copying it to the same device repeatedly. + return self.batch_to_device(batch, gpu_id) + def _setup_nvidia_apex(self, model: LightningModule): model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.trainer.optimizers = optimizers diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index ef48f5a81b32e..ecee841382d9c 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -668,12 +668,11 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: # single GPU data transfer if self.use_single_gpu: - # for single GPU put inputs on gpu manually - root_gpu = 0 - if isinstance(self.data_parallel_device_ids, list): - root_gpu = self.data_parallel_device_ids[0] - batch = self.transfer_batch_to_gpu(batch, root_gpu) - args[0] = batch + if test_mode: + output = self.accelerator_backend.test_step(args) + else: + output = self.accelerator_backend.validation_step(args) + return output # TPU data transfer if self.use_tpu: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index cec3ff6e61a3e..aa1dc3df8cd70 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -1201,16 +1201,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens): # single GPU forward elif self.use_single_gpu: - gpu_id = 0 - if isinstance(self.data_parallel_device_ids, list): - gpu_id = self.data_parallel_device_ids[0] - - # Don't copy the batch since there is a single gpu that the batch could - # be referenced from and if there are multiple optimizers the batch will - # wind up copying it to the same device repeatedly. - batch = self.transfer_batch_to_gpu(batch, gpu_id) - args[0] = batch - output = self.model.training_step(*args) + output = self.accelerator_backend.training_step(args) # TPU support elif self.use_tpu: diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index adaecf11f5514..0eebe629a6c2c 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -50,6 +50,7 @@ def seed_everything(seed: Optional[int] = None) -> int: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) return seed diff --git a/tests/core/test_results.py b/tests/core/test_results.py index bbae6e61145be..dbd96d077d5bd 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -121,14 +121,14 @@ def test_result_obj_predictions(tmpdir, test_option, do_train, gpus): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_result_obj_predictions_ddp_spawn(tmpdir): + seed_everything(4321) + distributed_backend = 'ddp_spawn' option = 0 import os os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' - seed_everything(4321) - dm = TrialMNISTDataModule(tmpdir) prediction_file = Path('predictions.pt')