Skip to content

Commit

Permalink
Fix environment variable order for global rank determination (#11406)
Browse files Browse the repository at this point in the history

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
3 people authored and rohitgr7 committed Feb 17, 2022
1 parent 851d449 commit 60825c0
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))


- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))


## [1.5.9] - 2022-01-20

### Fixed
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/utilities/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:

# TODO: this should be part of the cluster environment
def _get_rank() -> int:
rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK")
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
# therefore LOCAL_RANK needs to be checked first
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID")
for key in rank_keys:
rank = os.environ.get(key)
if rank is not None:
Expand Down
16 changes: 16 additions & 0 deletions tests/utilities/rank_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,19 @@ def foo():

x = foo()
assert x is None


@pytest.mark.parametrize(
"environ,expected_rank",
[
({"SLURM_PROCID": "2"}, 2),
({"SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1),
({"SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0),
],
)
def test_rank_zero_priority(environ, expected_rank):
"""Test the priority in which the rank gets determined when multiple environment variables are available."""
with mock.patch.dict(os.environ, environ):
from pytorch_lightning.utilities.rank_zero import _get_rank

assert _get_rank() == expected_rank

0 comments on commit 60825c0

Please sign in to comment.