diff --git a/CHANGELOG.md b/CHANGELOG.md index e5bf0c690f545..dfebe17f25a7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -615,6 +615,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827)) +- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406)) + + ## [1.5.9] - 2022-01-20 ### Fixed diff --git a/pytorch_lightning/utilities/rank_zero.py b/pytorch_lightning/utilities/rank_zero.py index df1e6792085c0..513798ff7aae7 100644 --- a/pytorch_lightning/utilities/rank_zero.py +++ b/pytorch_lightning/utilities/rank_zero.py @@ -37,7 +37,9 @@ def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: # TODO: this should be part of the cluster environment def _get_rank() -> int: - rank_keys = ("RANK", "SLURM_PROCID", "LOCAL_RANK") + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID") for key in rank_keys: rank = os.environ.get(key) if rank is not None: diff --git a/tests/utilities/rank_zero.py b/tests/utilities/rank_zero.py index 61bcf61c0ca59..15a55fdd877f2 100644 --- a/tests/utilities/rank_zero.py +++ b/tests/utilities/rank_zero.py @@ -53,3 +53,19 @@ def foo(): x = foo() assert x is None + + +@pytest.mark.parametrize( + "environ,expected_rank", + [ + ({"SLURM_PROCID": "2"}, 2), + ({"SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1), + ({"SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0), + ], +) +def test_rank_zero_priority(environ, expected_rank): + """Test the priority in which the rank gets determined when multiple environment variables are available.""" + with mock.patch.dict(os.environ, environ): + from pytorch_lightning.utilities.rank_zero import _get_rank + + assert _get_rank() == expected_rank