diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index 7af53bc896b46..af215f6accf27 100755
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -113,7 +113,6 @@ def __init__(
             self.gpus = pick_multiple_gpus(gpus)
 
         self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
-        self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids)
 
         self.set_distributed_mode()
         self.configure_slurm_ddp()
@@ -276,6 +275,10 @@ def parallel_devices(self):
             devices = [torch.device("cpu")] * self.num_processes
         return devices
 
+    @property
+    def root_gpu(self) -> int:
+        return self.accelerator.root_device.index
+
     @property
     def is_using_torchelastic(self):
         te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ)
@@ -375,7 +378,8 @@ def select_training_type_plugin(self):
         elif self.on_tpu:
             plugin = SingleTPUPlugin(self.tpu_id)
         else:
-            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
+            single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
+            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
         return plugin
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
@@ -525,7 +529,6 @@ def _set_horovod_backend(self):
         if self.on_gpu:
             # Horovod assigns one local GPU per process
             self.parallel_device_ids = list(range(hvd.local_size()))
-            self.root_gpu = hvd.local_rank()
         else:
             self.num_processes = hvd.local_size()
 
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 4f0f548a32d28..1c3e4b284b2e2 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -68,6 +68,10 @@ def mocked_device_count(monkeypatch):
     def device_count():
         return PRETEND_N_OF_GPUS
 
+    def is_available():
+        return True
+
+    monkeypatch.setattr(torch.cuda, 'is_available', is_available)
     monkeypatch.setattr(torch.cuda, 'device_count', device_count)
 
 