[Enhancement] Enable timeout in dist training #877

Merged
merged 10 commits on Feb 3, 2023
11 changes: 11 additions & 0 deletions docs/en/advanced_tutorials/distributed.md
@@ -23,6 +23,17 @@ We will detail on these APIs in the following chapters.

- [init_dist](mmengine.dist.init_dist): Launch function of distributed training. Currently it supports 3 launchers: pytorch, slurm and MPI. It also sets up the given communication backend, which defaults to NCCL.

If you need to change the runtime timeout (default: 30 minutes) for distributed operations that take very long, you can specify a different timeout in the `env_cfg` configuration passed to [Runner](mmengine.runner.Runner) like this:

```python
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl', timeout=10800),  # Sets the timeout to 3h (10800 seconds)
)
runner = Runner(xxx, env_cfg=env_cfg)
```

## Query and control

The query and control functions are all argument-free.
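For reference, a minimal sketch of a few of these query functions (a hedged example: it assumes the `mmengine.dist` helpers `get_rank`, `get_world_size` and `is_main_process`, which fall back to single-process defaults outside an initialized process group):

```python
from mmengine import dist

# Argument-free queries; in a non-distributed run they return the
# single-process defaults noted in the comments.
print(dist.get_rank())         # 0
print(dist.get_world_size())   # 1
print(dist.is_main_process())  # True
```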
14 changes: 14 additions & 0 deletions mmengine/dist/utils.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import functools
import os
import subprocess
@@ -50,6 +51,19 @@ def init_dist(launcher, backend='nccl', **kwargs) -> None:
            'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
    """
    timeout = kwargs.get('timeout', None)
    if timeout is not None:
        # If a timeout (in seconds) is specified, it must be converted
        # to a timedelta object before forwarding the call to
        # the respective backend, because they expect a timedelta object.
        try:
            kwargs['timeout'] = datetime.timedelta(seconds=timeout)
        except TypeError as exception:
            raise TypeError(
                f'Timeout for distributed training must be provided as '
                f"timeout in seconds, but we've received the type "
                f'{type(timeout)}. Please specify the timeout like this: '
                f"dist_cfg=dict(backend='nccl', timeout=1800)") from exception
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
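To make the conversion concrete, here is a minimal standalone sketch (not part of the PR; it only mirrors the logic above) of why seconds must become a `datetime.timedelta` before reaching `torch.distributed.init_process_group`:

```python
import datetime

# Users write the timeout as plain seconds in dist_cfg ...
seconds = 1800
# ... but init_process_group expects a datetime.timedelta, so the seconds
# are converted before the kwargs are forwarded to the backend.
timeout = datetime.timedelta(seconds=seconds)
assert timeout == datetime.timedelta(minutes=30)

# A non-numeric value raises TypeError, which the PR re-raises with a hint
# about the expected dist_cfg format.
try:
    datetime.timedelta(seconds='half an hour')
except TypeError:
    print("timeout must be given in seconds, e.g. "
          "dist_cfg=dict(backend='nccl', timeout=1800)")
```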
2 changes: 1 addition & 1 deletion mmengine/runner/runner.py
@@ -625,7 +625,7 @@ def setup_env(self, env_cfg: Dict) -> None:
        mp_start_method='fork',
        opencv_num_threads=0
    ),
    dist_cfg=dict(backend='nccl'),
    dist_cfg=dict(backend='nccl', timeout=1800),
    resource_limit=4096
)
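As a usage note (a sketch assuming the standard config-driven workflow): with this change, omitting `timeout` from `dist_cfg` yields the 30-minute default above, and users who need longer-running collectives only override that one key (values below are illustrative):

```python
# Default after this PR: dist_cfg=dict(backend='nccl', timeout=1800),
# i.e. 30 minutes for collective operations.

# Overriding just the timeout in env_cfg (illustrative values):
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl', timeout=7200),  # 2 hours
)
```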
