From 60192dcd794e420cd40e75afee06374bb5be3954 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 10 Feb 2021 18:28:05 +0000
Subject: [PATCH 1/2] Move root GPU to property, remove horovod set as this is
 handled in horovod plugin, ensure we mock correctly to set GPU accelerator

---
 pytorch_lightning/accelerators/accelerator_connector.py | 9 ++++++---
 tests/models/test_gpu.py                                | 9 ++++-----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index 7af53bc896b46..af215f6accf27 100755
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -113,7 +113,6 @@ def __init__(
             self.gpus = pick_multiple_gpus(gpus)
 
         self.parallel_device_ids = device_parser.parse_gpu_ids(self.gpus)
-        self.root_gpu = device_parser.determine_root_gpu_device(self.parallel_device_ids)
 
         self.set_distributed_mode()
         self.configure_slurm_ddp()
@@ -276,6 +275,10 @@ def parallel_devices(self):
             devices = [torch.device("cpu")] * self.num_processes
         return devices
 
+    @property
+    def root_gpu(self) -> int:
+        return self.accelerator.root_device.index
+
     @property
     def is_using_torchelastic(self):
         te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ)
@@ -375,7 +378,8 @@ def select_training_type_plugin(self):
         elif self.on_tpu:
             plugin = SingleTPUPlugin(self.tpu_id)
         else:
-            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
+            single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
+            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
         return plugin
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
@@ -525,7 +529,6 @@ def _set_horovod_backend(self):
         if self.on_gpu:
             # Horovod assigns one local GPU per process
             self.parallel_device_ids = list(range(hvd.local_size()))
-            self.root_gpu = hvd.local_rank()
         else:
             self.num_processes = hvd.local_size()
 
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 4f0f548a32d28..61c1fb7768393 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -68,6 +68,10 @@ def mocked_device_count(monkeypatch):
     def device_count():
         return PRETEND_N_OF_GPUS
 
+    def is_available():
+        return True
+
+    monkeypatch.setattr(torch.cuda, 'is_available', is_available)
     monkeypatch.setattr(torch.cuda, 'device_count', device_count)
 
 
@@ -104,12 +108,7 @@ def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distr
 
 @pytest.mark.gpus_param_tests
 @pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], [
-    pytest.param(None, None, "ddp", id="None is None"),
-    pytest.param(0, None, "ddp", id="O gpus, expect gpu root device to be None."),
     pytest.param(1, 0, "ddp", id="1 gpu, expect gpu root device to be 0."),
-    pytest.param(-1, 0, "ddp", id="-1 - use all gpus, expect gpu root device to be 0."),
-    pytest.param('-1', 0, "ddp", id="'-1' - use all gpus, expect gpu root device to be 0."),
-    pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")
 ])
 def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distributed_backend):
     assert Trainer(gpus=gpus, accelerator=distributed_backend).root_gpu == expected_root_gpu

From 07b836ab69eda8bd566150d390b53f15f8333a5a Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 10 Feb 2021 18:48:17 +0000
Subject: [PATCH 2/2] Add missing tests back

---
 tests/models/test_gpu.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 61c1fb7768393..1c3e4b284b2e2 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -108,7 +108,12 @@ def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distr
 
 @pytest.mark.gpus_param_tests
 @pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], [
+    pytest.param(None, None, "ddp", id="None is None"),
+    pytest.param(0, None, "ddp", id="O gpus, expect gpu root device to be None."),
     pytest.param(1, 0, "ddp", id="1 gpu, expect gpu root device to be 0."),
+    pytest.param(-1, 0, "ddp", id="-1 - use all gpus, expect gpu root device to be 0."),
+    pytest.param('-1', 0, "ddp", id="'-1' - use all gpus, expect gpu root device to be 0."),
+    pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")
 ])
 def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distributed_backend):
     assert Trainer(gpus=gpus, accelerator=distributed_backend).root_gpu == expected_root_gpu
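
Note: the following is a minimal, self-contained sketch of the pattern PATCH 1/2 introduces, not the real Lightning classes; `_StubAccelerator` and `_StubConnector` are hypothetical stand-ins. The point it illustrates is that `root_gpu` is no longer stored as mutable state on the connector but derived on demand from the accelerator's root device, so a CPU root device naturally yields `None` and `cuda:N` yields `N`.

import torch


class _StubAccelerator:
    """Stand-in for an accelerator that knows its root device."""

    def __init__(self, root_device: torch.device):
        self.root_device = root_device


class _StubConnector:
    """Stand-in for the accelerator connector after this patch."""

    def __init__(self, accelerator: _StubAccelerator):
        self.accelerator = accelerator

    @property
    def root_gpu(self):
        # Derived, not stored: the GPU ordinal of the accelerator's root
        # device, or None when the root device is the CPU.
        return self.accelerator.root_device.index


# Usage: constructing torch.device objects needs no CUDA runtime, so this
# sketch runs on a CPU-only machine.
assert _StubConnector(_StubAccelerator(torch.device("cuda:1"))).root_gpu == 1
assert _StubConnector(_StubAccelerator(torch.device("cpu"))).root_gpu is None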