ref: refactored gpu backend __step (#3120)
* refactored gpu backend __step

williamFalcon authored Aug 24, 2020
1 parent 527b9dc commit 8d7ca5c
Showing 5 changed files with 43 additions and 20 deletions.
36 changes: 34 additions & 2 deletions pytorch_lightning/accelerators/gpu_backend.py
@@ -15,18 +15,19 @@
 import torch
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.accelerators.base_backend import Accelerator

 try:
     from apex import amp
 except ImportError:
     amp = None


-class GPUBackend(object):
+class GPUBackend(Accelerator):
     amp_backend: AMPType

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)

     def setup(self, model):

@@ -51,6 +52,37 @@ def train(self, model):
         results = self.trainer.run_pretrain_routine(model)
         return results

+    def training_step(self, args):
+        batch = args[0]
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.training_step(*args)
+        return output
+
+    def validation_step(self, args):
+        batch = args[0]
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.validation_step(*args)
+        return output
+
+    def test_step(self, args):
+        batch = args[0]
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.test_step(*args)
+        return output
+
+    def to_device(self, batch):
+        gpu_id = 0
+        if isinstance(self.trainer.data_parallel_device_ids, list):
+            gpu_id = self.trainer.data_parallel_device_ids[0]
+
+        # Don't copy the batch since there is a single gpu that the batch could
+        # be referenced from and if there are multiple optimizers the batch will
+        # wind up copying it to the same device repeatedly.
+        return self.batch_to_device(batch, gpu_id)
+
     def _setup_nvidia_apex(self, model: LightningModule):
         model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
         self.trainer.optimizers = optimizers
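The net effect of this file is that GPUBackend now owns the single-GPU step dispatch that the loops previously did inline: move the batch once, then forward the positional args to the LightningModule hook. Below is a minimal, self-contained sketch of that pattern, not the real Lightning classes: ToyModule, ToyTrainer and ToyGPUBackend are illustrative stand-ins, and the device move that the real backend inherits from Accelerator (batch_to_device) is approximated with a plain .to() call.

import torch
import torch.nn.functional as F


class ToyModule(torch.nn.Module):
    """Stand-in for a LightningModule with a training_step hook."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)


class ToyTrainer:
    """Stand-in exposing only the attributes the backend reads."""

    def __init__(self, model, data_parallel_device_ids=None):
        self.model = model
        self.data_parallel_device_ids = data_parallel_device_ids


class ToyGPUBackend:
    """Owns device placement and forwards *args to the module's step hooks."""

    def __init__(self, trainer):
        self.trainer = trainer
        gpu_id = 0
        if isinstance(trainer.data_parallel_device_ids, list):
            gpu_id = trainer.data_parallel_device_ids[0]
        self.device = (torch.device('cuda', gpu_id)
                       if torch.cuda.is_available() else torch.device('cpu'))
        trainer.model.to(self.device)  # the real backend moves the model in setup()

    def to_device(self, batch):
        # single device: move the batch once so every optimizer reuses it
        return [t.to(self.device) for t in batch]

    def training_step(self, args):
        args[0] = self.to_device(args[0])              # args[0] is the batch
        return self.trainer.model.training_step(*args)


trainer = ToyTrainer(ToyModule(), data_parallel_device_ids=[0])
backend = ToyGPUBackend(trainer)
batch = [torch.randn(8, 4), torch.randn(8, 2)]
loss = backend.training_step([batch, 0])  # [batch, batch_idx], mirroring the args list
print(loss.item())

Keeping the device move inside the backend gives the train, validation and test loops a single shared code path instead of each repeating the root-GPU lookup.
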
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -668,12 +668,11 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:

         # single GPU data transfer
         if self.use_single_gpu:
-            # for single GPU put inputs on gpu manually
-            root_gpu = 0
-            if isinstance(self.data_parallel_device_ids, list):
-                root_gpu = self.data_parallel_device_ids[0]
-            batch = self.transfer_batch_to_gpu(batch, root_gpu)
-            args[0] = batch
+            if test_mode:
+                output = self.accelerator_backend.test_step(args)
+            else:
+                output = self.accelerator_backend.validation_step(args)
+            return output

         # TPU data transfer
         if self.use_tpu:
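With the backend owning the transfer, the single-GPU branch of evaluation_forward reduces to choosing a hook and returning its output. A rough standalone illustration of that dispatch (run_evaluation_step and its backend argument are assumptions made for the sketch, not Trainer API):

def run_evaluation_step(backend, args, test_mode):
    """Dispatch one evaluation batch to an accelerator-style backend.

    `backend` is any object exposing validation_step/test_step that accept the
    positional-args list the loop already builds (args[0] is the batch).
    """
    if test_mode:
        return backend.test_step(args)
    return backend.validation_step(args)
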
11 changes: 1 addition & 10 deletions pytorch_lightning/trainer/training_loop.py
@@ -1201,16 +1201,7 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):

         # single GPU forward
         elif self.use_single_gpu:
-            gpu_id = 0
-            if isinstance(self.data_parallel_device_ids, list):
-                gpu_id = self.data_parallel_device_ids[0]
-
-            # Don't copy the batch since there is a single gpu that the batch could
-            # be referenced from and if there are multiple optimizers the batch will
-            # wind up copying it to the same device repeatedly.
-            batch = self.transfer_batch_to_gpu(batch, gpu_id)
-            args[0] = batch
-            output = self.model.training_step(*args)
+            output = self.accelerator_backend.training_step(args)

         # TPU support
         elif self.use_tpu:
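The ten removed lines collapse to a single delegation because the batch move now lives in GPUBackend.to_device. One detail behind the "don't copy repeatedly" comment, shown with an illustrative check that is not part of the diff: Tensor.to() returns the tensor unchanged when it already sits on the target device, so re-running the step for multiple optimizers does not copy the batch again.

import torch

x = torch.randn(2, 2)
if torch.cuda.is_available():
    moved_once = x.to('cuda:0')
    moved_twice = moved_once.to('cuda:0')
    print(moved_once is moved_twice)  # True: second call returns the same tensor
else:
    print(x.to('cpu') is x)           # likewise a no-op when already on CPU
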
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/seed.py
@@ -50,6 +50,7 @@ def seed_everything(seed: Optional[int] = None) -> int:
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
     return seed


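seed_everything now seeds the CUDA generators on all devices as well, so GPU random draws become reproducible between runs. A quick illustrative check, assuming the import path shown in this diff (it falls back to CPU when no GPU is present):

import torch
from pytorch_lightning.utilities.seed import seed_everything


def draw():
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.randn(3, device=device)


seed_everything(1234)
a = draw()
seed_everything(1234)
b = draw()
print(torch.equal(a, b))  # True: both the CPU and CUDA generators are re-seeded
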
4 changes: 2 additions & 2 deletions tests/core/test_results.py
@@ -121,14 +121,14 @@ def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_result_obj_predictions_ddp_spawn(tmpdir):
-    seed_everything(4321)
-
     distributed_backend = 'ddp_spawn'
     option = 0

     import os
     os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

+    seed_everything(4321)
+
     dm = TrialMNISTDataModule(tmpdir)

     prediction_file = Path('predictions.pt')
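The test now exports CUDA_VISIBLE_DEVICES before calling seed_everything, presumably so the newly CUDA-aware seeding lines up with the two GPUs the spawn test will actually use. A sketch of that ordering (the values mirror the test; the reasoning is an assumption, not stated in the commit):

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'   # restrict the visible GPUs first

from pytorch_lightning.utilities.seed import seed_everything

seed_everything(4321)                        # now also seeds the CUDA generators
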
