-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add ClusterEnvironment for LSF systems * update init file * add available cluster environments * clean up LSFEnvironment * add ddp_hpc as a distributed backend * clean up SLURMEnvironment * remove extra blank line * init device for DDPHPCAccelerator We need to do this so we don't send the model to the same device from multiple ranks * committing current state * add additional methods to ClusterEnvironments * add NVIDIA mixin for setting up CUDA envars * remove troubleshooting prints * cleanup SLURMEnvironment * fix docstring * cleanup TorchElasticEnvironment and add documentation * PEP8 puts a cork in it * add set_ranks_to_trainer * remove unused import * move to new location * update LSF environment * remove mixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changelog * reset slurm env * add tests * add licence * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test node_rank * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add lsf env to docs * add auto detection for lsf environment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix is_using_lsf() and test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
1b06edf
commit 3102922
Showing
10 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
160 changes: 160 additions & 0 deletions
160
pytorch_lightning/plugins/environments/lsf_environment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import socket | ||
|
||
from pytorch_lightning import _logger as log | ||
from pytorch_lightning.plugins.environments import ClusterEnvironment | ||
|
||
|
||
class LSFEnvironment(ClusterEnvironment):
    """
    An environment for running on clusters managed by the LSF resource manager.

    It is expected that any execution using this ClusterEnvironment was executed
    using the Job Step Manager i.e. ``jsrun``.

    This plugin expects the following environment variables.

    LSB_JOBID:
        The LSF assigned job ID

    LSB_HOSTS:
        The hosts used in the job. This string is expected to have the format "batch <rank_0_host> ...."

    JSM_NAMESPACE_LOCAL_RANK:
        The node local rank for the task. This environment variable is set by jsrun

    JSM_NAMESPACE_RANK:
        The global rank for the task. Read by :meth:`global_rank`

    JSM_NAMESPACE_SIZE:
        The world size for the task. This environment variable is set by jsrun
    """

    def __init__(self):
        # Resolve address and port eagerly so that a missing ``LSB_HOSTS`` or
        # ``LSB_JOBID`` is reported as a ``ValueError`` at construction time.
        self._master_address = self._get_master_address()
        self._master_port = self._get_master_port()
        log.debug(f"MASTER_ADDR: {self._master_address}")
        log.debug(f"MASTER_PORT: {self._master_port}")

    @staticmethod
    def is_using_lsf() -> bool:
        """ Returns ``True`` if the current process was launched using the jsrun command. """
        # NOTE(review): ``JSM_NAMESPACE_RANK`` is required by ``global_rank`` but is not
        # part of this detection check - confirm whether it should be added.
        required_env_vars = (
            "LSB_JOBID",
            "LSB_HOSTS",
            "JSM_NAMESPACE_LOCAL_RANK",
            "JSM_NAMESPACE_SIZE",
        )
        return all(v in os.environ for v in required_env_vars)

    def creates_children(self) -> bool:
        """ Always returns ``True`` for this environment. """
        return True

    def master_address(self):
        """ The master address is read from a list of hosts contained in the environment variable `LSB_HOSTS`. """
        return self._master_address

    def master_port(self):
        """ The master port gets calculated from the LSF job ID, unless `MASTER_PORT` is set by the user. """
        return self._master_port

    def world_size(self):
        """
        The world size is read from the environment variable `JSM_NAMESPACE_SIZE`.

        Raises:
            ValueError: if the environment variable is not set.
        """
        return self._read_int_env("JSM_NAMESPACE_SIZE", "world size")

    def set_world_size(self, size: int) -> None:
        """ No-op: the world size is dictated by the environment and cannot be overridden. """
        log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

    def global_rank(self):
        """
        The global rank is read from the environment variable `JSM_NAMESPACE_RANK`.

        Raises:
            ValueError: if the environment variable is not set.
        """
        return self._read_int_env("JSM_NAMESPACE_RANK", "global rank")

    def set_global_rank(self, rank: int) -> None:
        """ No-op: the global rank is dictated by the environment and cannot be overridden. """
        log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

    def local_rank(self):
        """
        The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`.

        Raises:
            ValueError: if the environment variable is not set.
        """
        return self._read_int_env("JSM_NAMESPACE_LOCAL_RANK", "local rank")

    def node_rank(self):
        """
        The node rank is determined by the position of the current hostname in the list of hosts stored in
        the environment variable `LSB_HOSTS`.

        Entries whose name contains ``batch`` or ``login`` are skipped and repeated host names are
        counted once, so the rank is the index of the host among the remaining unique entries.

        Raises:
            ValueError: if `LSB_HOSTS` is missing or cannot be parsed.
            KeyError: if the current hostname is not among the hosts in `LSB_HOSTS`.
        """
        node_index = {}
        for host in self._read_hosts():
            if "batch" in host or "login" in host:
                continue
            # first occurrence wins; the value is the number of hosts seen before it
            node_index.setdefault(host, len(node_index))

        hostname = socket.gethostname()
        if hostname not in node_index:
            raise KeyError(
                f"Hostname {hostname!r} was not found in the hosts listed in the environment variable LSB_HOSTS"
            )
        return node_index[hostname]

    @staticmethod
    def _read_int_env(var: str, description: str) -> int:
        """
        Read the environment variable ``var`` and convert it to ``int``.

        Args:
            var: name of the environment variable to read.
            description: human readable name of the quantity, used in the error message.

        Raises:
            ValueError: if the variable is not set (or its value is not an integer).
        """
        value = os.environ.get(var)
        if value is None:
            raise ValueError(
                f"Cannot determine {description} from environment variable {var}."
                " Make sure you run your executable with `jsrun`"
            )
        return int(value)

    @staticmethod
    def _read_hosts():
        """
        Split the environment variable `LSB_HOSTS` on whitespace and return the resulting list.

        Raises:
            ValueError: if the variable is unset/empty or holds fewer than two entries.
        """
        hosts = os.environ.get("LSB_HOSTS")
        if not hosts:
            raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
        hosts = hosts.split()
        if len(hosts) < 2:
            raise ValueError(
                "Cannot parse hosts from LSB_HOSTS environment variable."
                " Expected format: \"batch <rank_0_host> ...\""
            )
        return hosts

    def _get_master_address(self):
        """ Return the second entry of `LSB_HOSTS`, i.e. the host following the leading ``batch`` entry. """
        hosts = self._read_hosts()
        return hosts[1]

    @staticmethod
    def _get_master_port():
        """
        A helper function for accessing the master port.

        Uses the LSF job ID so all ranks can compute the master port.
        A non-empty `MASTER_PORT` environment variable takes precedence.

        Raises:
            ValueError: if neither `MASTER_PORT` nor `LSB_JOBID` is set.
        """
        # check for user-specified master port
        port = os.environ.get("MASTER_PORT")
        if not port:
            jobid = os.environ.get("LSB_JOBID")
            if not jobid:
                raise ValueError("Could not find job id in environment variable LSB_JOBID")
            # all ports should be in the 10k+ range: maps the job id into 10000-10999
            port = int(jobid) % 1000 + 10000
            log.debug(f"calculated LSF master port: {port}")
        else:
            log.debug(f"using externally specified master port: {port}")
        return int(port)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
from pytorch_lightning.plugins.environments import LSFEnvironment | ||
|
||
|
||
@mock.patch.dict(
    os.environ,
    {
        "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1",
        "LSB_JOBID": "1234",
    },
)
def test_missing_lsb_hosts():
    """ Constructing the environment without ``LSB_HOSTS`` must raise a ``ValueError``. """
    # The variable is patched in first so that removing it works regardless of
    # what the surrounding process environment contains.
    os.environ.pop("LSB_HOSTS")
    expected_message = "Could not find hosts in environment variable LSB_HOSTS"
    with pytest.raises(ValueError, match=expected_message):
        LSFEnvironment()
|
||
|
||
@mock.patch.dict(
    os.environ,
    {
        "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1",
        "LSB_JOBID": "1234",
    },
)
def test_missing_lsb_job_id():
    """ Constructing the environment without ``LSB_JOBID`` must raise a ``ValueError``. """
    # The variable is patched in first so that removing it works regardless of
    # what the surrounding process environment contains.
    os.environ.pop("LSB_JOBID")
    expected_message = "Could not find job id in environment variable LSB_JOBID"
    with pytest.raises(ValueError, match=expected_message):
        LSFEnvironment()
|
||
|
||
@mock.patch.dict(
    os.environ, {
        "MASTER_PORT": "4321",
        "LSB_JOBID": "1234",
        "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1",
    }
)
def test_manual_master_port_and_address():
    """
    Test a user can set the port manually through the MASTER_PORT env variable,
    while the address is still taken from the first host after ``batch`` in LSB_HOSTS.
    """
    env = LSFEnvironment()
    # MASTER_PORT overrides the port that would otherwise be derived from LSB_JOBID
    assert env.master_port() == 4321
    # the address check that the test name promises
    assert env.master_address() == "10.10.10.0"
|
||
|
||
@mock.patch.dict(
    os.environ,
    {
        "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1 10.10.10.2 10.10.10.3",
        "LSB_JOBID": "1234",
        "JSM_NAMESPACE_SIZE": "4",
        "JSM_NAMESPACE_RANK": "3",
        "JSM_NAMESPACE_LOCAL_RANK": "1",
    },
)
def test_attributes_from_environment_variables():
    """ Check that every attribute of the LSF environment mirrors the patched environment variables. """
    lsf_env = LSFEnvironment()

    assert lsf_env.creates_children()

    # address: first host after "batch"; port: job id 1234 -> 1234 % 1000 + 10000
    assert lsf_env.master_address() == "10.10.10.0"
    assert lsf_env.master_port() == 10234

    assert lsf_env.world_size() == 4
    assert lsf_env.global_rank() == 3
    assert lsf_env.local_rank() == 1

    # the setters are ignored: values keep coming from the environment
    lsf_env.set_global_rank(100)
    assert lsf_env.global_rank() == 3
    lsf_env.set_world_size(100)
    assert lsf_env.world_size() == 4

    assert LSFEnvironment.is_using_lsf()
|
||
|
||
@mock.patch("socket.gethostname", return_value="host2")
@mock.patch.dict(
    os.environ,
    {
        "LSB_HOSTS": "batch host0 host1 host2 host3",
        "LSB_JOBID": "1234",
    },
)
def test_node_rank(_):
    """ The node rank is the index of the current hostname among the non-batch hosts in LSB_HOSTS. """
    lsf_env = LSFEnvironment()
    # "batch" is skipped, so host0 -> 0, host1 -> 1, host2 -> 2
    expected_rank = 2
    assert lsf_env.node_rank() == expected_rank
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
13 changes: 13 additions & 0 deletions
13
tests/plugins/environments/test_torchelastic_environment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters