diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 0a629a6f21c62..0237a1975ccee 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -640,4 +640,8 @@ def determine_root_gpu_device(gpus):
     # set root gpu
     root_gpu = gpus[0]
 
+    # set cuda device to root gpu
+    root_device = (torch.device("cuda", root_gpu) if root_gpu >= 0 else torch.device("cpu"))
+    torch.cuda.set_device(root_device)
+
     return root_gpu
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 1ca088ebbc720..b6e595fb7228d 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -421,9 +421,18 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode:
 
         # single GPU data transfer
        if self.single_gpu:
            # for single GPU put inputs on gpu manually
-            root_gpu = 0
            if isinstance(self.data_parallel_device_ids, list):
                root_gpu = self.data_parallel_device_ids[0]
+
+                # set cuda device to root gpu
+                root_device = (torch.device("cuda", root_gpu)
+                               if root_gpu >= 0 else torch.device("cpu"))
+                torch.cuda.set_device(root_device)
+            else:
+                raise RuntimeError(
+                    'Expected `data_parallel_device_ids` as a list, cannot determine root gpu.'
+                )
+
            batch = self.transfer_batch_to_gpu(batch, root_gpu)
            args[0] = batch
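
Note (not part of the patch): a minimal standalone sketch of the device-selection logic this diff adds, for trying it outside the Trainer. `select_root_device` and `device_ids` are hypothetical names standing in for the inlined logic and `Trainer.data_parallel_device_ids`; the guard around `torch.cuda.set_device` is an added assumption, since calling it with a CPU device would fail.

    import torch

    def select_root_device(device_ids):
        # mirrors the patched check: a list of device ids is required
        if not isinstance(device_ids, list):
            raise RuntimeError('Expected `device_ids` as a list, cannot determine root gpu.')
        root_gpu = device_ids[0]
        root_device = (torch.device("cuda", root_gpu)
                       if root_gpu >= 0 else torch.device("cpu"))
        if root_device.type == "cuda":
            # bind the CUDA context to the root gpu so later allocations
            # (e.g. `.cuda()` with no index) land there instead of on GPU 0
            torch.cuda.set_device(root_device)
        return root_device

    # usage: run on GPU 1 without creating a CUDA context on GPU 0
    # device = select_root_device([1])
    # x = torch.zeros(8).cuda()  # placed on cuda:1 after set_device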