diff --git a/CHANGELOG.md b/CHANGELOG.md
index db1f3970e0e6f2..08081e1dd76aa7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
 
+- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
+
+
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
 
 
diff --git a/pytorch_lightning/plugins/environments/__init__.py b/pytorch_lightning/plugins/environments/__init__.py
index 10d9bf50a4b844..70c1f8da90f13d 100644
--- a/pytorch_lightning/plugins/environments/__init__.py
+++ b/pytorch_lightning/plugins/environments/__init__.py
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment  # noqa: F401
+from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment  # noqa: F401
 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment  # noqa: F401
diff --git a/pytorch_lightning/plugins/environments/cluster_environment.py b/pytorch_lightning/plugins/environments/cluster_environment.py
index c9e054c0328043..f3fb2fbeabaa2b 100644
--- a/pytorch_lightning/plugins/environments/cluster_environment.py
+++ b/pytorch_lightning/plugins/environments/cluster_environment.py
@@ -11,24 +11,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Optional
 
 
-class ClusterEnvironment:
+class ClusterEnvironment(ABC):
+    """ Specification of a cluster environment. """
 
-    def __init__(self):
-        self._world_size = None
+    @abstractmethod
+    def creates_children(self) -> bool:
+        """ Whether the environment creates the subprocesses or not. """
 
-    def master_address(self):
-        pass
+    @abstractmethod
+    def master_address(self) -> str:
+        """ The master address through which all processes connect and communicate. """
 
-    def master_port(self):
-        pass
+    @abstractmethod
+    def master_port(self) -> int:
+        """ An open and configured port in the master node through which all processes communicate. """
 
-    def world_size(self) -> int:
-        return self._world_size
+    @abstractmethod
+    def world_size(self) -> Optional[int]:
+        """ The number of processes across all devices and nodes. """
 
+    @abstractmethod
     def local_rank(self) -> int:
-        pass
+        """ The rank (index) of the currently running process inside of the current node. """
 
+    @abstractmethod
     def node_rank(self) -> int:
-        pass
+        """ The rank (index) of the node on which the current process runs. """
diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py
new file mode 100644
index 00000000000000..6b71122b065bf2
--- /dev/null
+++ b/pytorch_lightning/plugins/environments/lightning_environment.py
@@ -0,0 +1,71 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import socket
+from typing import Optional
+
+from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+
+
+class LightningEnvironment(ClusterEnvironment):
+    """
+    The default environment used by Lightning for a single node or free cluster (not managed).
+
+    The master process must be launched by the user and Lightning will spawn new
+    worker processes for distributed training, either in a single node or across multiple nodes.
+
+    If the master address and port are not provided, the default environment will choose them
+    automatically. It is recommended to use this default environment for single-node distributed
+    training as it provides the most convenient way to launch the training script.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._master_port = None
+
+    def creates_children(self) -> bool:
+        return False
+
+    def master_address(self) -> str:
+        return os.environ.get("MASTER_ADDR", "127.0.0.1")
+
+    def master_port(self) -> int:
+        if self._master_port is None:
+            self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
+        return int(self._master_port)
+
+    def world_size(self) -> Optional[int]:
+        return None
+
+    def local_rank(self) -> int:
+        return int(os.environ.get("LOCAL_RANK", 0))
+
+    def node_rank(self) -> int:
+        group_rank = os.environ.get("GROUP_RANK", 0)
+        return int(os.environ.get("NODE_RANK", group_rank))
+
+
+def find_free_network_port() -> int:
+    """
+    Finds a free port on localhost.
+    It is useful in single-node training when we don't want to connect to a real master node but
+    have to set the `MASTER_PORT` environment variable.
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    s.bind(("", 0))
+    s.listen(1)
+    port = s.getsockname()[1]
+    s.close()
+    return port
diff --git a/pytorch_lightning/plugins/environments/slurm_environment.py b/pytorch_lightning/plugins/environments/slurm_environment.py
index 7f9586cab0aceb..3cba5d101a1598 100644
--- a/pytorch_lightning/plugins/environments/slurm_environment.py
+++ b/pytorch_lightning/plugins/environments/slurm_environment.py
@@ -26,7 +26,10 @@ class SLURMEnvironment(ClusterEnvironment):
     def __init__(self):
         super().__init__()
 
-    def master_address(self):
+    def creates_children(self) -> bool:
+        return True
+
+    def master_address(self) -> str:
         # figure out the root node addr
         slurm_nodelist = os.environ.get("SLURM_NODELIST")
         if slurm_nodelist:
@@ -39,7 +42,7 @@ def master_address(self):
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
         return root_node
 
-    def master_port(self):
+    def master_port(self) -> int:
         # -----------------------
         # SLURM JOB = PORT number
         # -----------------------
@@ -64,18 +67,18 @@ def master_port(self):
 
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        return default_port
+        return int(default_port)
 
     def world_size(self):
-        return self._world_size
+        return None
 
-    def local_rank(self):
+    def local_rank(self) -> int:
         return int(os.environ['SLURM_LOCALID'])
 
-    def node_rank(self):
+    def node_rank(self) -> int:
         return int(os.environ['SLURM_NODEID'])
 
-    def resolve_root_node_address(self, root_node):
+    def resolve_root_node_address(self, root_node: str) -> str:
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)
             number = numbers.split(',', maxsplit=1)[0]
diff --git a/pytorch_lightning/plugins/environments/torchelastic_environment.py b/pytorch_lightning/plugins/environments/torchelastic_environment.py
index 5ac7d9f1c9a40a..c3a59fbfd75bc1 100644
--- a/pytorch_lightning/plugins/environments/torchelastic_environment.py
+++ b/pytorch_lightning/plugins/environments/torchelastic_environment.py
@@ -14,6 +14,7 @@
 
 import logging
 import os
+from typing import Optional
 
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import rank_zero_warn
@@ -26,7 +27,10 @@ class TorchElasticEnvironment(ClusterEnvironment):
     def __init__(self):
         super().__init__()
 
-    def master_address(self):
+    def creates_children(self) -> bool:
+        return True
+
+    def master_address(self) -> str:
         if "MASTER_ADDR" not in os.environ:
             rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ["MASTER_ADDR"] = "127.0.0.1"
@@ -34,19 +38,20 @@ def master_address(self):
         master_address = os.environ.get('MASTER_ADDR')
         return master_address
 
-    def master_port(self):
+    def master_port(self) -> int:
         if "MASTER_PORT" not in os.environ:
             rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
             os.environ["MASTER_PORT"] = "12910"
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        port = os.environ.get('MASTER_PORT')
+        port = int(os.environ.get('MASTER_PORT'))
         return port
 
-    def world_size(self):
-        return os.environ.get('WORLD_SIZE')
+    def world_size(self) -> Optional[int]:
+        world_size = os.environ.get('WORLD_SIZE')
+        return int(world_size) if world_size is not None else world_size
 
-    def local_rank(self):
+    def local_rank(self) -> int:
         return int(os.environ['LOCAL_RANK'])
 
     def node_rank(self) -> int:
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 748dcdc9e6b680..3e6c618fcf4e22 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -30,12 +30,7 @@
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
-from pytorch_lightning.utilities.distributed import (
-    find_free_network_port,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
 
@@ -43,7 +38,6 @@
     from hydra.core.hydra_config import HydraConfig
     from hydra.utils import get_original_cwd, to_absolute_path
 
-
 log = logging.getLogger(__name__)
 
 
@@ -90,8 +84,7 @@ def setup(self, model):
         self._model = model
 
         # start the other scripts
-        # TODO: refactor and let generic cluster env hold the information about who spawns the processes
-        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
+        if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
             self._call_children_scripts()
 
         # set the task idx
@@ -105,15 +98,12 @@ def _call_children_scripts(self):
         self._has_spawned_children = True
 
         # DDP Environment variables
-        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
 
         # allow the user to pass the node rank
-        node_rank = "0"
-        node_rank = os.environ.get("NODE_RANK", node_rank)
-        node_rank = os.environ.get("GROUP_RANK", node_rank)
-        os.environ["NODE_RANK"] = node_rank
-        os.environ["LOCAL_RANK"] = "0"
+        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
+        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
 
         # when user is using hydra find the absolute path
         path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path
@@ -209,7 +199,6 @@ def determine_ddp_device_ids(self):
         return [self.root_device.index]
 
     def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
-        # TODO: From where to get cluster environment?
         os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 9ff4bb8cd27491..d699dcb690d888 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -30,13 +30,7 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    find_free_network_port,
-    rank_zero_only,
-    rank_zero_warn,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.seed import seed_everything
 
 log = logging.getLogger(__name__)
@@ -84,7 +78,7 @@ def distributed_sampler_kwargs(self):
     def setup(self, model):
         self._model = model
 
-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
 
         # pass in a state q
         smp = mp.get_context("spawn")
@@ -93,7 +87,7 @@ def setup(self, model):
     def set_world_ranks(self, process_idx):
         self.local_rank = process_idx
         self.node_rank = self.cluster_environment.node_rank()
-        self.task_idx = self.cluster_local_rank
+        self.task_idx = self.cluster_environment.local_rank()
         self.global_rank = self.node_rank * self.num_processes + self.local_rank
         self.world_size = self.num_nodes * self.num_processes
 
diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py
index f3c825fe9cd7aa..715c5332e231c0 100644
--- a/pytorch_lightning/plugins/training_type/parallel.py
+++ b/pytorch_lightning/plugins/training_type/parallel.py
@@ -40,13 +40,6 @@ def __init__(
         self.local_rank = 0
         self.cluster_environment = cluster_environment
 
-    @property
-    def cluster_local_rank(self):
-        try:
-            return self.cluster_environment.local_rank()
-        except KeyError:
-            return 0
-
     @property
     @abstractmethod
     def root_device(self):
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 67da309b263ccb..99d716f6b5a8c2 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -42,7 +42,12 @@
     TPUSpawnPlugin,
     TrainingTypePlugin,
 )
-from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.plugins.environments import (
+    ClusterEnvironment,
+    LightningEnvironment,
+    SLURMEnvironment,
+    TorchElasticEnvironment,
+)
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities import (
     _APEX_AVAILABLE,
@@ -451,17 +456,10 @@ def select_cluster_environment(self) -> ClusterEnvironment:
             return self._cluster_environment
         if self.is_slurm_managing_tasks:
             env = SLURMEnvironment()
-            # TODO: decouple DDP from SLURM
-            # refactor and let generic cluster env hold the information about who spawns the processes
-            os.environ["PL_IN_DDP_SUBPROCESS"] = "1"
         elif self.is_using_torchelastic:
             env = TorchElasticEnvironment()
-            # TODO: decouple DDP from TE
-            # refactor and let generic cluster env hold the information about who spawns the processes
-            os.environ["PL_IN_DDP_SUBPROCESS"] = "1"
         else:
-            # TODO: maybe introduce a DefaultEnvironment?
-            env = TorchElasticEnvironment()
+            env = LightningEnvironment()
         return env
 
     def set_distributed_mode(self, distributed_backend: Optional[str] = None):
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index cd3e74e35a8f7c..e797c32bbf917f 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -64,21 +64,6 @@ def _debug(*args, **kwargs):
 rank_zero_warn = rank_zero_only(_warn)
 
 
-def find_free_network_port() -> int:
-    """
-    Finds a free port on localhost.
-    It is useful in single-node training when we don't want to connect to a real master node but
-    have to set the `MASTER_PORT` environment variable.
-    """
-    import socket
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    s.bind(("", 0))
-    s.listen(1)
-    port = s.getsockname()[1]
-    s.close()
-    return port
-
-
 def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
     """
     Function to gather all tensors from several ddp processes onto a list that
diff --git a/setup.cfg b/setup.cfg
index c845499e45304b..9dd06c8b3daa93 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -48,7 +48,6 @@ omit =
     pytorch_lightning/accelerators/ddp2_*.py
    pytorch_lightning/accelerators/dp_*.py
    pytorch_lightning/accelerators/tpu_*.py
-    pytorch_lightning/cluster_environments/*.py
    pytorch_lightning/utilities/xla_device_utils.py
    pytorch_lightning/utilities/distributed.py
    pytorch_lightning/tuner/auto_gpu_select.py
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 3c6d7a094c11fe..42c910cb8078b7 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -33,7 +33,7 @@
     PrecisionPlugin,
     SingleDevicePlugin,
 )
-from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -55,7 +55,7 @@ def test_accelerator_choice_ddp_cpu(tmpdir):
     )
     assert isinstance(trainer.accelerator, CPUAccelerator)
     assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
-    assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
+    assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment)
 
 
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
@@ -69,7 +69,7 @@ def test_accelerator_choice_ddp(cuda_available_mock, device_count_mock):
     )
     assert isinstance(trainer.accelerator, GPUAccelerator)
     assert isinstance(trainer.training_type_plugin, DDPPlugin)
-    assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
+    assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment)
 
 
 @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
@@ -83,7 +83,7 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
     )
     assert isinstance(trainer.accelerator, GPUAccelerator)
     assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
-    assert isinstance(trainer.training_type_plugin.cluster_environment, TorchElasticEnvironment)
+    assert isinstance(trainer.training_type_plugin.cluster_environment, LightningEnvironment)
 
 
 @RunIf(min_gpus=2)
@@ -297,7 +297,7 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
     Test that we choose the custom cluster even when SLURM or TE flags are around
     """
 
-    class CustomCluster(ClusterEnvironment):
+    class CustomCluster(LightningEnvironment):
 
         def master_address(self):
             return 'asdf'
diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py
index 4a2c76ea8f4cad..39f5e0dca50759 100644
--- a/tests/deprecated_api/test_remove_1-4.py
+++ b/tests/deprecated_api/test_remove_1-4.py
@@ -24,7 +24,7 @@
 )
 from pytorch_lightning.overrides.distributed import LightningDistributedModule
 from pytorch_lightning.plugins import DDPSpawnPlugin
-from pytorch_lightning.plugins.environments import TorchElasticEnvironment
+from pytorch_lightning.plugins.environments import LightningEnvironment
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -188,7 +188,7 @@ def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir):
         plugins=[
             CustomDDPPlugin(
                 parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
-                cluster_environment=TorchElasticEnvironment(),
+                cluster_environment=LightningEnvironment(),
             )
         ]
     )
diff --git a/tests/models/test_sync_batchnorm.py b/tests/models/test_sync_batchnorm.py
index 42d95d4f21aded..5750bb66a75b63 100644
--- a/tests/models/test_sync_batchnorm.py
+++ b/tests/models/test_sync_batchnorm.py
@@ -19,7 +19,7 @@
 
 from pytorch_lightning import LightningModule, seed_everything, Trainer
 from pytorch_lightning.plugins import DDPSpawnPlugin
-from pytorch_lightning.plugins.environments import TorchElasticEnvironment
+from pytorch_lightning.plugins.environments import LightningEnvironment
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import FLOAT16_EPSILON
 from tests.helpers.datamodules import MNISTDataModule
@@ -109,7 +109,7 @@ def test_sync_batchnorm_ddp(tmpdir):
         parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
         num_nodes=1,
         sync_batchnorm=True,
-        cluster_environment=TorchElasticEnvironment(),
+        cluster_environment=LightningEnvironment(),
         find_unused_parameters=True
     )
 
diff --git a/tests/plugins/environments/__init__.py b/tests/plugins/environments/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/plugins/environments/test_lightning_environment.py b/tests/plugins/environments/test_lightning_environment.py
new file mode 100644
index 00000000000000..83d26cb0fcf91c
--- /dev/null
+++ b/tests/plugins/environments/test_lightning_environment.py
@@ -0,0 +1,52 @@
+import os
+from unittest import mock
+
+from pytorch_lightning.plugins.environments import LightningEnvironment
+
+
+@mock.patch.dict(os.environ, {})
+def test_default_attributes():
+    """ Test the default attributes when no environment variables are set. """
+    env = LightningEnvironment()
+    assert not env.creates_children()
+    assert env.master_address() == "127.0.0.1"
+    assert isinstance(env.master_port(), int)
+    assert env.world_size() is None
+    assert env.local_rank() == 0
+    assert env.node_rank() == 0
+
+
+@mock.patch.dict(os.environ, {
+    "MASTER_ADDR": "1.2.3.4",
+    "MASTER_PORT": "500",
+    "LOCAL_RANK": "2",
+    "NODE_RANK": "3",
+})
+def test_attributes_from_environment_variables():
+    """ Test that the default cluster environment takes the attributes from the environment variables. """
+    env = LightningEnvironment()
+    assert env.master_address() == "1.2.3.4"
+    assert env.master_port() == 500
+    assert env.world_size() is None
+    assert env.local_rank() == 2
+    assert env.node_rank() == 3
+
+
+@mock.patch.dict(os.environ, {
+    "GROUP_RANK": "1",
+})
+def test_node_rank_from_group_rank():
+    """ Test that the GROUP_RANK substitutes NODE_RANK. """
+    env = LightningEnvironment()
+    assert "NODE_RANK" not in os.environ
+    assert env.node_rank() == 1
+
+
+@mock.patch.dict(os.environ, {})
+def test_random_master_port():
+    """ Test randomly chosen master port when no master port was given by user. """
+    env = LightningEnvironment()
+    port = env.master_port()
+    assert isinstance(port, int)
+    # repeated calls do not generate a new port number
+    assert env.master_port() == port
diff --git a/tests/plugins/environments/test_slurm_environment.py b/tests/plugins/environments/test_slurm_environment.py
new file mode 100644
index 00000000000000..8e82434846e68b
--- /dev/null
+++ b/tests/plugins/environments/test_slurm_environment.py
@@ -0,0 +1,55 @@
+import os
+from unittest import mock
+
+import pytest
+
+from pytorch_lightning.plugins.environments import SLURMEnvironment
+
+
+@mock.patch.dict(os.environ, {})
+def test_default_attributes():
+    """ Test the default attributes when no environment variables are set. """
+    env = SLURMEnvironment()
+    assert env.creates_children()
+    assert env.master_address() == "127.0.0.1"
+    assert env.master_port() == 12910
+    assert env.world_size() is None
+    with pytest.raises(KeyError):
+        # local rank is required to be passed as env variable
+        env.local_rank()
+    with pytest.raises(KeyError):
+        # node_rank is required to be passed as env variable
+        env.node_rank()
+
+
+@mock.patch.dict(
+    os.environ, {
+        "SLURM_NODELIST": "1.1.1.1, 1.1.1.2",
+        "SLURM_JOB_ID": "0001234",
+        "WORLD_SIZE": "20",
+        "SLURM_LOCALID": "2",
+        "SLURM_NODEID": "3",
+    }
+)
+def test_attributes_from_environment_variables():
+    """ Test that the SLURM cluster environment takes the attributes from the environment variables. """
+    env = SLURMEnvironment()
+    assert env.master_address() == "1.1.1.1"
+    assert env.master_port() == 15000 + 1234
+    assert env.world_size() is None
+    assert env.local_rank() == 2
+    assert env.node_rank() == 3
+
+
+@pytest.mark.parametrize(
+    "slurm_node_list,expected", [
+        ("alpha,beta,gamma", "alpha"),
+        ("alpha beta gamma", "alpha"),
+        ("1.2.3.[100-110]", "1.2.3.100"),
+    ]
+)
+def test_master_address_from_slurm_node_list(slurm_node_list, expected):
+    """ Test extracting the master node from different formats for the SLURM_NODELIST. """
+    with mock.patch.dict(os.environ, {"SLURM_NODELIST": slurm_node_list}):
+        env = SLURMEnvironment()
+        assert env.master_address() == expected
diff --git a/tests/plugins/environments/test_torchelastic_environment.py b/tests/plugins/environments/test_torchelastic_environment.py
new file mode 100644
index 00000000000000..55cfc25adde3c9
--- /dev/null
+++ b/tests/plugins/environments/test_torchelastic_environment.py
@@ -0,0 +1,39 @@
+import os
+from unittest import mock
+
+import pytest
+
+from pytorch_lightning.plugins.environments import TorchElasticEnvironment
+
+
+@mock.patch.dict(os.environ, {})
+def test_default_attributes():
+    """ Test the default attributes when no environment variables are set. """
+    env = TorchElasticEnvironment()
+    assert env.creates_children()
+    assert env.master_address() == "127.0.0.1"
+    assert env.master_port() == 12910
+    assert env.world_size() is None
+    with pytest.raises(KeyError):
+        # local rank is required to be passed as env variable
+        env.local_rank()
+    assert env.node_rank() == 0
+
+
+@mock.patch.dict(
+    os.environ, {
+        "MASTER_ADDR": "1.2.3.4",
+        "MASTER_PORT": "500",
+        "WORLD_SIZE": "20",
+        "LOCAL_RANK": "2",
+        "GROUP_RANK": "3",
+    }
+)
+def test_attributes_from_environment_variables():
+    """ Test that the torchelastic cluster environment takes the attributes from the environment variables. """
+    env = TorchElasticEnvironment()
+    assert env.master_address() == "1.2.3.4"
+    assert env.master_port() == 500
+    assert env.world_size() == 20
+    assert env.local_rank() == 2
+    assert env.node_rank() == 3