From 979fca0c0e6f1d36ec3f6bf9d8d825a0137e2d17 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 6 Mar 2023 15:26:34 +0100
Subject: [PATCH 1/6] don't call set_world_ranks in xla strategy

---
 src/lightning/pytorch/strategies/xla.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index 8c60bbaa6608c..3baf91b9c92fc 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -211,14 +211,8 @@ def reduce(
 
     def setup_distributed(self) -> None:
         self._launched = True
-        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
 
-    def set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        rank_zero_only.rank = self.cluster_environment.global_rank()
-
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         assert self.model is not None
         with self.precision_plugin.val_step_context():

From 475a113cdcf0ec9877ff020b3c3df745bd1206de Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 6 Mar 2023 15:28:37 +0100
Subject: [PATCH 2/6] update

---
 src/lightning/pytorch/strategies/xla.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index 3baf91b9c92fc..0f21433953bd3 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -213,6 +213,9 @@ def setup_distributed(self) -> None:
         self._launched = True
         rank_zero_only.rank = self.global_rank
 
+    def set_world_ranks(self) -> None:
+        pass
+
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         assert self.model is not None
         with self.precision_plugin.val_step_context():
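Note on the two XLA patches above: on XLA, merely querying the process rank goes through the runtime, which creates the computation client as a side effect of the first query. That is why `set_world_ranks` becomes an empty stub while the rank bookkeeping stays in `setup_distributed`, which only runs inside the spawned processes. A minimal sketch of the hazard, where `_query_rank` is a hypothetical stand-in for the real runtime call, not an actual API:

    _client = None  # stands in for the lazily created computation client

    def _query_rank() -> int:
        global _client
        if _client is None:
            _client = object()  # the first query initializes the client
        return 0

    class SketchStrategy:
        def set_world_ranks(self) -> None:
            # Runs in the main process before workers are spawned; querying
            # the rank here would initialize the client in the wrong process.
            pass

        def setup_distributed(self) -> None:
            # Runs inside each spawned worker, where initialization is safe.
            rank = _query_rank()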
From f58b3eb85d76b684d5f37a7f8eee305791c4a6ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 13 Apr 2023 16:35:18 +0200
Subject: [PATCH 3/6] fabric and other strategies

---
 src/lightning/fabric/connector.py             |  4 ++--
 src/lightning/fabric/strategies/ddp.py        | 12 ++++++------
 src/lightning/fabric/strategies/deepspeed.py  |  3 +--
 src/lightning/fabric/strategies/fsdp.py       | 12 ++++++------
 src/lightning/fabric/strategies/xla.py        |  6 ------
 src/lightning/pytorch/strategies/ddp.py       | 12 ++++++------
 src/lightning/pytorch/strategies/deepspeed.py |  3 +--
 src/lightning/pytorch/strategies/fsdp.py      | 14 ++++++--------
 src/lightning/pytorch/strategies/xla.py       |  3 +++
 9 files changed, 31 insertions(+), 38 deletions(-)

diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index faf6107ad5a55..811f72d07fe19 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -507,8 +507,8 @@ def _lazy_init_strategy(self) -> None:
             self.strategy.parallel_devices = self._parallel_devices
         if hasattr(self.strategy, "num_nodes"):
             self.strategy._num_nodes = self._num_nodes_flag
-        if hasattr(self.strategy, "set_world_ranks"):
-            self.strategy.set_world_ranks()
+        if hasattr(self.strategy, "_set_world_ranks"):
+            self.strategy._set_world_ranks()
         self.strategy._configure_launcher()
 
         if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index 24b69eefa8509..ad70ee4a4345d 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -177,7 +177,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
 
     def _setup_distributed(self) -> None:
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -186,11 +185,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _determine_ddp_device_ids(self) -> Optional[List[int]]:
         if self.root_device.type == "cpu":
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 747e3bacda3d7..45d3e682a2d3f 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -33,7 +33,7 @@
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
 
@@ -580,7 +580,6 @@ def _setup_distributed(self) -> None:
         )
         reset_seed()
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
         if not self._config_initialized:
             self._format_config()
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index df5074d9d96a2..8e506c3e9aa90 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -320,7 +320,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
     def _setup_distributed(self) -> None:
         reset_seed()
         self._set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -329,11 +328,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
 
 def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index dd01da578f1e5..def5089383059 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -93,7 +93,6 @@ def _configure_launcher(self) -> None:
 
     def setup_environment(self) -> None:
         self._launched = True
-        self._set_world_ranks()
         rank_zero_only.rank = self.global_rank
         super().setup_environment()
 
@@ -203,8 +202,3 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("xla", cls, description=cls.__class__.__name__)
-
-    def _set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        rank_zero_only.rank = self.cluster_environment.global_rank()
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index a6899b1c13307..5b783ccdc3225 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -183,7 +183,6 @@ def setup_distributed(self) -> None:
         log.debug(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
         self.set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -192,11 +191,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _register_ddp_hooks(self) -> None:
         log.debug(f"{self.__class__.__name__}: registering ddp hooks")
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 152b07449a2ec..848b2c14277ac 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -43,7 +43,7 @@
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
-from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn, WarningCache
+from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
 from lightning.pytorch.utilities.types import LRSchedulerConfig, STEP_OUTPUT
 
 log = logging.getLogger(__name__)
@@ -326,7 +326,6 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
     def setup_distributed(self) -> None:
         reset_seed()
         self.set_world_ranks()
-        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
         if not self._config_initialized:
             self._format_config()
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 42a1702d7fbd6..319cecca69ca2 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -178,9 +178,6 @@ def setup_environment(self) -> None:
         # determine which process we are and world size
         self.set_world_ranks()
 
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -190,11 +187,12 @@ def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
     def set_world_ranks(self) -> None:
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
+        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
+        rank_zero_only.rank = self.global_rank
 
     def _configure_launcher(self) -> None:
         assert self.cluster_environment is not None
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index 36c675d2cac7a..e53b2b1c8d8c0 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -204,6 +204,9 @@ def setup_distributed(self) -> None:
         rank_zero_only.rank = self.global_rank
 
     def set_world_ranks(self) -> None:
+        # accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
+        # processes (by the accelerator connector), we cannot run the code that would normally be here.
+        # instead it's done in `setup_distributed`
         pass
 
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
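The comment repeated across the DDP and FSDP hunks above ("for some implementations, the setter is a no-op") is worth a concrete illustration: when a scheduler such as SLURM or LSF owns the rank, the cluster environment derives it from environment variables and ignores whatever the strategy tries to set, so only the getter is trustworthy. A minimal sketch with an invented `SchedulerOwnedEnvironment` (not the actual Lightning class):

    import os

    class SchedulerOwnedEnvironment:
        def set_global_rank(self, rank: int) -> None:
            # Deliberate no-op: the scheduler already fixed the rank.
            pass

        def global_rank(self) -> int:
            return int(os.environ.get("SCHEDULER_RANK", "0"))

    os.environ["SCHEDULER_RANK"] = "5"
    env = SchedulerOwnedEnvironment()
    env.set_global_rank(3)         # silently ignored
    assert env.global_rank() == 5  # the getter reflects the real rank

This is the shape of bug the patched strategies avoid by assigning `rank_zero_only.rank = self.global_rank` from the getter rather than reusing the value passed to the setter.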
From 4b685b59950d581780ca0dd2669ec8b450891ae2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 13 Apr 2023 16:41:13 +0200
Subject: [PATCH 4/6] CHANGELOG

---
 src/lightning/fabric/CHANGELOG.md  | 6 ++++++
 src/lightning/pytorch/CHANGELOG.md | 3 +++
 2 files changed, 9 insertions(+)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 60be083e34733..756a62454c24d 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
 
+
+- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
 ### Deprecated
 
 -
@@ -39,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed issue where running on TPUs would select the wrong device index ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
 
+- Fixed issue where Fabric would not initialize the global rank, world size, and rank-zero-only rank after initialization and before launch ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
+
 ## [2.0.1.post0] - 2023-04-11
 
 No changes
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 16d2bb18a5e82..43191f00add8f 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
 
+
+- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
+
 ### Deprecated
 
 -

From 9f42d7bc1ab38ddb064edf2bafdf44b3faf7f769 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 13 Apr 2023 16:48:14 +0200
Subject: [PATCH 5/6] Typos

---
 src/lightning/fabric/plugins/environments/lsf.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py
index 53f2a61de9b84..8500a3e40fd93 100644
--- a/src/lightning/fabric/plugins/environments/lsf.py
+++ b/src/lightning/fabric/plugins/environments/lsf.py
@@ -88,7 +88,7 @@ def world_size(self) -> int:
         if world_size is None:
             raise ValueError(
                 "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(world_size)
 
@@ -101,7 +101,7 @@ def global_rank(self) -> int:
         if global_rank is None:
             raise ValueError(
                 "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(global_rank)
 
@@ -114,7 +114,7 @@ def local_rank(self) -> int:
         if local_rank is None:
             raise ValueError(
                 "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
-                "Make sure you run your executable with `jsrun`."
+                " Make sure you run your executable with `jsrun`."
             )
         return int(local_rank)
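The "Typos" patch above works because Python concatenates adjacent string literals at compile time: without the leading space, the two message fragments fuse into "...not found.Make sure...". A quick demonstration:

    # Adjacent literals concatenate implicitly; only the second version
    # renders the error message with a space between the sentences.
    before = (
        "Environment variable `JSM_NAMESPACE_SIZE` not found."
        "Make sure you run your executable with `jsrun`."
    )
    after = (
        "Environment variable `JSM_NAMESPACE_SIZE` not found."
        " Make sure you run your executable with `jsrun`."
    )
    assert "found.Make" in before
    assert "found. Make" in after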
From 43449658517587486889e0573491509ca55ea210 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 13 Apr 2023 16:50:10 +0200
Subject: [PATCH 6/6] Reuse test

---
 tests/tests_fabric/test_connector.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index abae1aba7b5eb..681975a22addb 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -128,6 +128,8 @@ def creates_processes_externally(self) -> bool:
     assert isinstance(connector.accelerator, CPUAccelerator)
     assert isinstance(connector.strategy, DDPStrategy)
     assert isinstance(connector.strategy.cluster_environment, CustomCluster)
+    # this checks that `strategy._set_world_ranks` was called by the connector
+    assert connector.strategy.world_size == 2
 
 
 @RunIf(mps=False)
@@ -230,10 +232,10 @@ class Strat(DDPStrategy):
 @mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
 def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment):
     with mock.patch.dict(os.environ, env_vars, clear=True):
-        trainer = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
-        assert isinstance(trainer.accelerator, CPUAccelerator)
-        assert isinstance(trainer.strategy, DDPStrategy)
-        assert isinstance(trainer.strategy.cluster_environment, expected_environment)
+        connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
+        assert isinstance(connector.accelerator, CPUAccelerator)
+        assert isinstance(connector.strategy, DDPStrategy)
+        assert isinstance(connector.strategy.cluster_environment, expected_environment)
 
 
 @RunIf(mps=False)
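For reference, a worked example of the rank arithmetic these patches consolidate (`global_rank = node_rank * num_processes + local_rank`, `world_size = num_nodes * num_processes`); the reused connector test above exercises the single-node case, where two devices give `world_size == 2`:

    num_nodes, num_processes = 2, 4  # e.g. 2 machines with 4 processes each
    node_rank, local_rank = 1, 2     # third process on the second machine

    global_rank = node_rank * num_processes + local_rank
    world_size = num_nodes * num_processes
    assert (global_rank, world_size) == (6, 8)

    # Degenerate case checked by the test above: 1 node, devices=2.
    assert 1 * 2 == 2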