Fixes around Strategy.set_world_ranks #16966

Merged · 7 commits · Apr 13, 2023
6 changes: 6 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))


- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))

### Deprecated

-
@@ -39,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue where running on TPUs would select the wrong device index ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))


- Fixed issue where Fabric would not initialize the global rank, world size, and rank-zero-only rank after initialization and before launch ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))


## [2.0.1.post0] - 2023-04-11

No changes
4 changes: 2 additions & 2 deletions src/lightning/fabric/connector.py
@@ -507,8 +507,8 @@ def _lazy_init_strategy(self) -> None:
self.strategy.parallel_devices = self._parallel_devices
if hasattr(self.strategy, "num_nodes"):
self.strategy._num_nodes = self._num_nodes_flag
if hasattr(self.strategy, "set_world_ranks"):
self.strategy.set_world_ranks()
if hasattr(self.strategy, "_set_world_ranks"):
self.strategy._set_world_ranks()
self.strategy._configure_launcher()

if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
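
A minimal sketch of the user-visible effect of this connector change (the flags below are illustrative and assume a CPU machine running a non-interactive script): because `_set_world_ranks()` now runs while the connector is built, rank information is available on a `Fabric` object before `launch()` is ever called.

```python
from lightning.fabric import Fabric

# The connector has already invoked `strategy._set_world_ranks()` by the time the
# constructor returns, so these values are valid before any process is launched.
fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
print(fabric.world_size)   # 2
print(fabric.global_rank)  # 0 in the main, not-yet-launched process
```
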
6 changes: 3 additions & 3 deletions src/lightning/fabric/plugins/environments/lsf.py
@@ -88,7 +88,7 @@ def world_size(self) -> int:
if world_size is None:
raise ValueError(
"Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
"Make sure you run your executable with `jsrun`."
" Make sure you run your executable with `jsrun`."
)
return int(world_size)

@@ -101,7 +101,7 @@ def global_rank(self) -> int:
if global_rank is None:
raise ValueError(
"Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
"Make sure you run your executable with `jsrun`."
" Make sure you run your executable with `jsrun`."
)
return int(global_rank)

@@ -114,7 +114,7 @@ def local_rank(self) -> int:
if local_rank is None:
raise ValueError(
"Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
"Make sure you run your executable with `jsrun`."
" Make sure you run your executable with `jsrun`."
)
return int(local_rank)

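
All three hunks above fix the same problem: each error message is assembled from adjacent string literals, so without an explicit leading space on the second literal the two sentences run together. A quick illustration:

```python
# Implicit concatenation of adjacent string literals inserts nothing between them,
# so the separating space has to be written into one of the literals.
broken = (
    "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
    "Make sure you run your executable with `jsrun`."
)
fixed = (
    "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
    " Make sure you run your executable with `jsrun`."
)
print(broken)  # "...not found.Make sure you run..."
print(fixed)   # "...not found. Make sure you run..."
```
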
12 changes: 6 additions & 6 deletions src/lightning/fabric/strategies/ddp.py
@@ -177,7 +177,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:

def _setup_distributed(self) -> None:
self._set_world_ranks()
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -186,11 +185,12 @@ def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

def _set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = self.global_rank

def _determine_ddp_device_ids(self) -> Optional[List[int]]:
if self.root_device.type == "cpu":
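
The comment added above alludes to cluster environments whose rank setter is a no-op. A small illustrative sketch of why reading the rank back through the getter is the robust choice; the classes below are stand-ins, not the library's `ClusterEnvironment` implementations:

```python
import os


class WritableEnv:
    """Stand-in for a LightningEnvironment-style environment: the setter is honoured."""

    def __init__(self) -> None:
        self._global_rank = 0

    def set_global_rank(self, rank: int) -> None:
        self._global_rank = rank  # stored

    def global_rank(self) -> int:
        return self._global_rank


class SchedulerEnv:
    """Stand-in for a SLURM/torchelastic-style environment: the launcher owns the rank."""

    def set_global_rank(self, rank: int) -> None:
        pass  # deliberately a no-op

    def global_rank(self) -> int:
        return int(os.environ.get("SLURM_PROCID", "0"))


# Reading back through the getter is correct for both kinds of environment,
# which is why the strategy now does `rank_zero_only.rank = self.global_rank`.
for env in (WritableEnv(), SchedulerEnv()):
    env.set_global_rank(3)
    print(type(env).__name__, env.global_rank())
```
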
3 changes: 1 addition & 2 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -33,7 +33,7 @@
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.distributed import log
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH

@@ -580,7 +580,6 @@ def _setup_distributed(self) -> None:
)
reset_seed()
self._set_world_ranks()
rank_zero_only.rank = self.global_rank
self._init_deepspeed_distributed()
if not self._config_initialized:
self._format_config()
12 changes: 6 additions & 6 deletions src/lightning/fabric/strategies/fsdp.py
@@ -320,7 +320,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
def _setup_distributed(self) -> None:
reset_seed()
self._set_world_ranks()
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -329,11 +328,12 @@ def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

def _set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = self.global_rank


def _setup_activation_checkpointing(module: "FullyShardedDataParallel", layers: List[Type[Module]]) -> None:
6 changes: 0 additions & 6 deletions src/lightning/fabric/strategies/xla.py
@@ -93,7 +93,6 @@ def _configure_launcher(self) -> None:

def setup_environment(self) -> None:
self._launched = True
self._set_world_ranks()
rank_zero_only.rank = self.global_rank
super().setup_environment()

@@ -203,8 +202,3 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("xla", cls, description=cls.__class__.__name__)

def _set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
rank_zero_only.rank = self.cluster_environment.global_rank()
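
A sketch of the ordering constraint behind removing this `_set_world_ranks` override (the torch_xla calls are illustrative of what the XLA environment does internally, not the library code): querying the process ordinal initializes the XLA/PJRT computation client in the calling process, so it must only happen inside the spawned workers, never in the main process before launch.

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index: int) -> None:
    # Safe: we are inside a spawned worker, so initializing the PJRT client here is fine.
    print(f"worker {index}: global rank {xm.get_ordinal()}")


if __name__ == "__main__":
    # Calling xm.get_ordinal() here, before xmp.spawn, would initialize the PJRT
    # computation client in the main process and break the spawned workers,
    # which is what setting the ranks from the connector used to do.
    xmp.spawn(_worker)
```
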
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))


- On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))

### Deprecated

-
12 changes: 6 additions & 6 deletions src/lightning/pytorch/strategies/ddp.py
@@ -183,7 +183,6 @@ def setup_distributed(self) -> None:
log.debug(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()
self.set_world_ranks()
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
@@ -192,11 +191,12 @@ def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

def set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = self.global_rank

def _register_ddp_hooks(self) -> None:
log.debug(f"{self.__class__.__name__}: registering ddp hooks")
3 changes: 1 addition & 2 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -43,7 +43,7 @@
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn, WarningCache
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
from lightning.pytorch.utilities.types import LRSchedulerConfig, STEP_OUTPUT

log = logging.getLogger(__name__)
@@ -326,7 +326,6 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
def setup_distributed(self) -> None:
reset_seed()
self.set_world_ranks()
rank_zero_only.rank = self.global_rank
self._init_deepspeed_distributed()
if not self._config_initialized:
self._format_config()
14 changes: 6 additions & 8 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -178,9 +178,6 @@ def setup_environment(self) -> None:
# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -190,11 +187,12 @@ def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

def set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
# `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
# additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
rank_zero_only.rank = self.global_rank

def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
8 changes: 4 additions & 4 deletions src/lightning/pytorch/strategies/xla.py
@@ -201,13 +201,13 @@ def reduce(

def setup_distributed(self) -> None:
self._launched = True
self.set_world_ranks()
[Review comment from the PR author]: the line below already covers what was done in this method.

rank_zero_only.rank = self.global_rank

def set_world_ranks(self) -> None:
[Review comment from the PR author]: The accelerator connector will call this method. But we don't want the XLA strategy setting the rank (yet). Since XLAStrategy inherits from DDPStrategy, we need to leave it empty here.

if self.cluster_environment is None:
return
rank_zero_only.rank = self.cluster_environment.global_rank()
# accessing global_rank will initialize the XLA computation client. since this is called outside of the spawned
# processes (by the accelerator connector), we cannot run the code that would normally be here.
# instead it's done in `setup_distributed`
pass

def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
10 changes: 6 additions & 4 deletions tests/tests_fabric/test_connector.py
@@ -128,6 +128,8 @@ def creates_processes_externally(self) -> bool:
assert isinstance(connector.accelerator, CPUAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.cluster_environment, CustomCluster)
# this checks that `strategy._set_world_ranks` was called by the connector
assert connector.strategy.world_size == 2
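
A sketch of the pattern this test exercises (class and argument names are illustrative reconstructions; only the assertions shown in the hunk above are taken from the diff): a custom cluster environment is passed as a plugin, and the new assertion verifies that the connector already propagated the world size via `strategy._set_world_ranks()`.

```python
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.environments import LightningEnvironment


class CustomCluster(LightningEnvironment):
    @property
    def creates_processes_externally(self) -> bool:
        return True  # pretend an external launcher creates the worker processes


connector = _Connector(plugins=[CustomCluster()], accelerator="cpu", strategy="ddp", devices=2)
# `_set_world_ranks` has already run, so the environment knows the world size
# before anything is launched.
assert connector.strategy.world_size == 2
```
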


@RunIf(mps=False)
@@ -230,10 +232,10 @@ class Strat(DDPStrategy):
@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment):
with mock.patch.dict(os.environ, env_vars, clear=True):
trainer = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.strategy, DDPStrategy)
assert isinstance(trainer.strategy.cluster_environment, expected_environment)
connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
assert isinstance(connector.accelerator, CPUAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.cluster_environment, expected_environment)


@RunIf(mps=False)