[Bug] MMEngine doesn't allow to set a timeout for distributed training #873

Closed · 2 tasks done
apacha opened this issue Jan 12, 2023 · 2 comments · Fixed by #877

apacha (Contributor) commented Jan 12, 2023

Prerequisite

Environment

  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb658fe15dc9b41f0740)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.3
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  - CuDNN 8.2
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.2.0, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON
  - TorchVision: 0.11.1+cu113
  - OpenCV: 4.7.0
  - MMEngine: 0.4.0

Reproduces the problem - code sample

Specify a timeout in the default_runtime.py of the MMOCR project:

env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl', timeout="10800"),
)

Reproduces the problem - command or script

Run the training via train.py on the 1.x branch of the MMOCR project as described in the documentation.

Reproduces the problem - error message

Upon hitting

runner = Runner.from_cfg(cfg)

the application crashes with:

Traceback (most recent call last):
  File "tools/train.py", line 117, in <module>
    main()
  File "tools/train.py", line 106, in main
    runner = Runner.from_cfg(cfg)
  File "/usr/local/lib/python3.8/site-packages/mmengine/runner/runner.py", line 431, in from_cfg
    runner = cls(
  File "/usr/local/lib/python3.8/site-packages/mmengine/runner/runner.py", line 345, in __init__
    self.setup_env(env_cfg)
  File "/usr/local/lib/python3.8/site-packages/mmengine/runner/runner.py", line 644, in setup_env
    init_dist(self.launcher, **dist_cfg)
  File "/usr/local/lib/python3.8/site-packages/mmengine/dist/utils.py", line 56, in init_dist
    _init_dist_pytorch(backend, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/mmengine/dist/utils.py", line 94, in _init_dist_pytorch
    torch_dist.init_process_group(backend=backend, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 539, in init_process_group
    raise RuntimeError(
RuntimeError: Expected timeout argument to be of type datetime.timedelta

The same RuntimeError is raised in every worker process, after which the launcher reports the failure:
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 197) of binary: /usr/local/bin/python
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launch.py", line 193, in <module>
    main()
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launch.py", line 189, in main
    launch(args)
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launch.py", line 174, in launch
    run(args)
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/run.py", line 710, in run
    elastic_launch(
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 259, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

Additional information

What do I expect to happen?

I expect the distributed training to be initialized with the specified timeout (in seconds).
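For reference, PyTorch only accepts this timeout as a datetime.timedelta. A minimal sketch of the call that MMEngine eventually makes (the value and backend are only illustrative; the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment variables are assumed to be set by the launcher):

import datetime
import torch.distributed as dist

# init_process_group rejects a plain int or str for timeout;
# it has to be a datetime.timedelta
dist.init_process_group(
    backend='nccl',
    timeout=datetime.timedelta(seconds=10800),  # 3 hours
)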

Why do you need to set the timeout?

In our training, we have a large validation dataset whose evaluation exceeds the default 30-minute timeout, which then causes the entire training to crash.
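For context, this default is PyTorch's process-group timeout; a quick way to inspect it (assuming torch.distributed.constants.default_pg_timeout is available in your PyTorch version):

from torch.distributed.constants import default_pg_timeout

# prints 0:30:00, i.e. datetime.timedelta(minutes=30)
print(default_pg_timeout)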

What is the reason for the problem?

PyTorch expects a timedelta object for the timeout during initialization. However, we can't simply replace the value in the config with a timedelta object, e.g.

import datetime

timeout = datetime.timedelta(seconds=int(cfg.env_cfg.dist_cfg["timeout"]))
cfg.env_cfg.dist_cfg["timeout"] = timeout

because YAPF validation of the config will crash when it encounters a timedelta object instead of a primitive like a string or an integer.

What is the solution?

Support a timeout in MMEngine by automatically converting an integer (seconds) into the required timedelta, e.g. in

def setup_env(self, env_cfg: Dict) -> None:
    """Setup environment.

    An example of ``env_cfg``::

        env_cfg = dict(
            cudnn_benchmark=True,
            mp_cfg=dict(
                mp_start_method='fork',
                opencv_num_threads=0
            ),
            dist_cfg=dict(backend='nccl'),
            resource_limit=4096
        )

    Args:
        env_cfg (dict): Config for setting environment.
    """
    if env_cfg.get('cudnn_benchmark'):
        torch.backends.cudnn.benchmark = True

    mp_cfg: dict = env_cfg.get('mp_cfg', {})
    set_multi_processing(**mp_cfg, distributed=self.distributed)

    # init distributed env first, since logger depends on the dist info.
    if self.distributed and not is_distributed():
        dist_cfg: dict = env_cfg.get('dist_cfg', {})
        init_dist(self.launcher, **dist_cfg)
    self._rank, self._world_size = get_dist_info()

    timestamp = torch.tensor(time.time(), dtype=torch.float64)
    # broadcast timestamp from 0 process to other processes
    broadcast(timestamp)
    self._timestamp = time.strftime('%Y%m%d_%H%M%S',
                                    time.localtime(timestamp.item()))

    # https://github.com/pytorch/pytorch/issues/973
    # set resource limit
    if platform.system() != 'Windows':
        import resource
        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
        base_soft_limit = rlimit[0]
        hard_limit = rlimit[1]
        soft_limit = min(
            max(env_cfg.get('resource_limit', 4096), base_soft_limit),
            hard_limit)
        resource.setrlimit(resource.RLIMIT_NOFILE,
                           (soft_limit, hard_limit))

like this:

def setup_env(self, env_cfg: Dict) -> None:
    # ...
    if self.distributed and not is_distributed():
        dist_cfg: dict = env_cfg.get('dist_cfg', {})
        if "timeout" in dist_cfg:
            timeout = datetime.timedelta(seconds=int(dist_cfg["timeout"]))
            dist_cfg["timeout"] = timeout
        init_dist(self.launcher, **dist_cfg)
    # ...
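
With that change, the dist_cfg from the reproduction above could keep a plain integer number of seconds, for example:

env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # timeout in seconds; converted to datetime.timedelta before init_dist is called
    dist_cfg=dict(backend='nccl', timeout=10800),
)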
apacha added the bug label on Jan 12, 2023
HAOCHENYE (Collaborator) commented Jan 12, 2023

Thanks for your feedback, the solution sounds great to me 😄! Would you mind posting a PR to support configuring the timeout?

apacha (Contributor, Author) commented Jan 12, 2023

Ok, I will create a PR
