[WIP] ref: decoupled ddp, ddp spawn #3733
Changes from 111 commits
File 1 of 3: DDP backend accelerator (class DDPBackend)
@@ -13,17 +13,22 @@
# limitations under the License

import os
import torch
import torch.distributed as torch_distrib
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Optional

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import AMPType


try:
    from hydra.utils import to_absolute_path, get_original_cwd

@@ -34,13 +39,14 @@
    HYDRA_AVAILABLE = True


class DDPBackend(DDPBase):
class DDPBackend(Accelerator):

    def __init__(self, trainer, mode: str = 'ddp'):
        super().__init__(trainer)
        self.task_idx = None
        self._has_spawned_children = False
        self.mode = mode
        self.interactive_ddp_procs = []

    def setup(self, model):
        if self.mode == 'ddp':

@@ -59,6 +65,10 @@ def __torchelastic_setup(self):
        self.task_idx = int(os.environ['LOCAL_RANK'])

    def __ddp_script_mode_setup(self):
        # do nothing when already in a ddp subprocess
        if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') == '1':
            return

        assert self.trainer.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True
@@ -91,21 +101,27 @@ def __ddp_script_mode_setup(self):
        # when the trainer script was called the device has already been scoped by the time
        # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone
        # but forward the GPUs selected via environment variables
        # set the flag for ddp scripts
        os.environ['PL_TRAINER_GPUS'] = ','.join([str(i) for i in self.trainer.data_parallel_device_ids])
        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'

        if self.trainer.logger is not None:
            os.environ['PL_EXP_VERSION'] = str(self.trainer.logger.version)

        gpu_ids = os.environ.get('CUDA_VISIBLE_DEVICES', '')
        if len(gpu_ids) == 1:
            gpu_ids = f'{gpu_ids},'

        num_gpus = max(1, len(gpu_ids.split(',')))

        # set the flag for ddp scripts
        os.environ['PL_TRAINER_GPUS'] = gpu_ids

        os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'

        self.trainer.interactive_ddp_procs = []
        self.interactive_ddp_procs = []
        for local_rank in range(1, self.trainer.num_processes):
            env_copy = os.environ.copy()
            env_copy['LOCAL_RANK'] = f'{local_rank}'
            env_copy['PL_DDP_PID'] = str(self.trainer.data_parallel_device_ids[local_rank])

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly

@@ -114,7 +130,7 @@ def __ddp_script_mode_setup(self):
            if HydraConfig.initialized():
                cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.trainer.interactive_ddp_procs.append(proc)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues
            # with dataloaders delay between 1-10 seconds

@@ -123,14 +139,116 @@ def __ddp_script_mode_setup(self):

        self.task_idx = 0

        # wait for all the procs to start
        sleep(2)
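For orientation, here is roughly what the script-mode relaunch boils down to once the environment is prepared: the parent process (global rank 0) keeps running as rank 0 and starts one copy of its own script per extra local rank. The `command` variable used by the `Popen` call is not visible in this hunk, so rebuilding it from `sys.argv` below is an assumption, and the concrete environment values are only examples.

```python
import os
import subprocess
import sys
from os.path import abspath

# Sketch only: relaunch the current training script once per additional
# local rank. Rebuilding `command` from sys.argv is an assumption here,
# and the environment values are illustrative (2 GPUs, 1 node).
command = [sys.executable, abspath(sys.argv[0])] + sys.argv[1:]

os.environ['PL_IN_DDP_SUBPROCESS'] = '1'   # children skip the spawn branch
os.environ['PL_TRAINER_GPUS'] = '0,1'      # forwarded GPU selection
os.environ['WORLD_SIZE'] = '2'             # num_gpus * num_nodes

procs = []
for local_rank in range(1, 2):             # example: num_processes = 2
    env_copy = os.environ.copy()
    env_copy['LOCAL_RANK'] = str(local_rank)
    env_copy['PL_DDP_PID'] = '1'           # GPU index this child should use
    procs.append(subprocess.Popen(command, env=env_copy))
```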
    def train(self):
        model = self.trainer.model
        if self.mode == 'ddp':
            results = self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
            del os.environ['WORLD_SIZE']
            results = self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
            if 'WORLD_SIZE' in os.environ:
                del os.environ['WORLD_SIZE']
            return results
        else:
            self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
            return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
        """
        Entry point for ddp

        Args:
            process_idx:
            mp_queue: multiprocessing queue
            model:
            is_master:
            proc_offset:

        Returns:
        """
        # offset the process id if requested
        process_idx = process_idx + proc_offset

        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self.trainer
        model.init_ddp_connection(
            self.trainer.global_rank,
            self.trainer.world_size,
            self.trainer.is_slurm_managing_tasks
        )

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on world_size=0 let everyone know training is starting
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
Review comment: just thinking, but is_global_zero should be part of DDPBackend IMO, since this is only needed for this
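If that refactor were done, a minimal, self-contained sketch of the idea (the class and property names here are illustrative, not part of this PR):

```python
# Hypothetical sketch: let the backend answer "am I global rank 0?" itself,
# instead of every call site reaching through trainer.is_global_zero.
class DDPBackendSketch:
    def __init__(self, global_rank: int):
        self.global_rank = global_rank

    @property
    def is_global_zero(self) -> bool:
        return self.global_rank == 0
```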
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
            log.info('-' * 100)

        # call sync_bn before .cuda(), configure_apex and configure_ddp
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # MODEL
        # copy model to each gpu
        self.model_to_device(model, process_idx, is_master)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

        # set model properties before going into wrapper
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # AMP - run through amp wrapper before going to distributed DP
        # DDP uses all GPUs on the machine
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # set up training routine
        self.barrier('ddp_setup')
        self.trainer.train_loop.setup_training(model)

        # train or test
        results = self.train_or_test()

        # clean up memory
        torch.cuda.empty_cache()

        return results

    def training_step(self, args):
Review comment (with a suggested change): changing this to positional args allows for inspection. Since this is also passed as positional args, this should be okay. Just have to check the corresponding calls as well.
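A rough sketch of what that suggestion points at (the suggested diff itself is not shown above, so the exact signature is an assumption): accept `*args` instead of a single `args` tuple so the call signature is inspectable.

```python
# Hypothetical variant of training_step with positional args
# (the reviewer's exact suggestion is not visible in this thread).
def training_step(self, *args):
    if self.trainer.amp_backend == AMPType.NATIVE:
        with torch.cuda.amp.autocast():
            output = self.trainer.model(*args)
    else:
        output = self.trainer.model(*args)
    return output
```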
        if self.trainer.amp_backend == AMPType.NATIVE:
            with torch.cuda.amp.autocast():
                output = self.trainer.model(*args)
        else:
            output = self.trainer.model(*args)
        return output

    def validation_step(self, args):
Review comment (with a suggested change): this looks like the preroutine we had before; can we rename it?
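One possible shape for that rename, sketched with an illustrative name (`_forward_step` is not what the PR uses): route all three step hooks through a single shared forwarding routine.

```python
# Hypothetical refactor: one shared forwarding routine for all step hooks.
def _forward_step(self, args):
    if self.trainer.amp_backend == AMPType.NATIVE:
        with torch.cuda.amp.autocast():
            return self.trainer.model(*args)
    return self.trainer.model(*args)

def training_step(self, args):
    return self._forward_step(args)

def validation_step(self, args):
    return self._forward_step(args)

def test_step(self, args):
    return self._forward_step(args)
```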
        output = self.training_step(args)
        return output

    def test_step(self, args):
        output = self.training_step(args)
        return output

    def barrier(self, name: str = None):
Review comment: why is the name argument not used here?
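If the `name` argument is meant to stay, one way to make it useful is as a debug label; a sketch only, not something this PR implements:

```python
# Hypothetical use of the otherwise-unused `name` argument: tag the barrier
# in debug logs so hangs are easier to attribute.
def barrier(self, name: str = None):
    if torch_distrib.is_initialized():
        if name is not None:
            log.debug(f'waiting at barrier: {name}')
        torch_distrib.barrier()
```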
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
@@ -145,15 +263,7 @@ def set_world_ranks(self, process_idx):
        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

    def model_to_device(self, model, process_idx, is_master):
        gpu_idx = process_idx

        # when using ddp, the master process (proc 0) continues running as the main one
        # this means that the local rank will always be 0
        # (even if cuda visible devices has other visible gpus)
        # this means that the master process needs to pull the 0th visible index as the device number
        if is_master:
            available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
            gpu_idx = int(available_gpus[self.trainer.local_rank])
        gpu_idx = int(os.environ.get('PL_DDP_PID', process_idx))

        self.trainer.root_gpu = gpu_idx
        torch.cuda.set_device(self.trainer.root_gpu)

@@ -162,3 +272,6 @@ def model_to_device(self, model, process_idx, is_master):
    def get_device_ids(self):
        device_ids = [self.trainer.root_gpu]
        return device_ids

    def on_train_end(self):
        pass
File 2 of 3: DDP base accelerator (class DDPBase, ddp_base_backend.py)
@@ -58,7 +58,8 @@ def test_step(self, args):
        return output

    def barrier(self, name: str = None):
Review comment: again, an unused name argument.
Review comment (reply): this class is getting dropped in a few PRs.
        torch_distrib.barrier()
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def early_stopping_should_stop(self, pl_module):
        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)

@@ -132,7 +133,7 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs
        self.trainer.call_setup_hook(model)

        # on world_size=0 let everyone know training is starting
        if self.trainer.is_global_zero:
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
File 3 of 3: DDP CPU spawn accelerator
@@ -95,9 +95,7 @@ def ddp_train(self, process_idx, mp_queue, model):
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        self.trainer.local_rank = process_idx
        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

@@ -116,7 +114,7 @@ def ddp_train(self, process_idx, mp_queue, model):
        self.trainer.call_setup_hook(model)

        # on world_size=0 let everyone know training is starting
        if self.trainer.is_global_zero:
        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')

@@ -126,6 +124,9 @@ def ddp_train(self, process_idx, mp_queue, model):
        if self.trainer.sync_batchnorm:
            model = model.configure_sync_batchnorm(model)

        # move the model to the correct device
        self.model_to_device(model, process_idx)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.setup_optimizers(model)

@@ -137,7 +138,7 @@ def ddp_train(self, process_idx, mp_queue, model):
        model = self.trainer.precision_connector.connect(model)

        # DDP spawn already spawned off each process... no need to do anything
        device_ids = None
        device_ids = self.get_device_ids()

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

@@ -174,7 +175,8 @@ def test_step(self, args):
        return output

    def barrier(self, name: str = None):
        torch_distrib.barrier()
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj, src=0):
        return self.dist.broadcast(obj)

@@ -186,6 +188,19 @@ def early_stopping_should_stop(self, pl_module):
        should_stop = stop == self.trainer.world_size
        return should_stop

    def set_world_ranks(self, process_idx):
        self.trainer.local_rank = process_idx
        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

    def model_to_device(self, model, process_idx):
        # in ddp cpu we don't actually move models to a device
Review comment: we should explicitly move them to cpu here, since we don't know on which device it was initially.
Review comment (reply): good call
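A minimal sketch of that suggestion, assuming the intent is simply to normalize the model onto CPU no matter where it started:

```python
# Hypothetical follow-up to the review: make the move explicit for ddp_cpu
# rather than a no-op.
def model_to_device(self, model, process_idx):
    model.cpu()
```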
        pass

    def get_device_ids(self):
        device_ids = None
        return device_ids
    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
        # track the best model path
        best_model_path = None
Review comment: Do the processes communicate on startup? I feel like a hardcoded sleep is not the optimal solution here.
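For what that comment is getting at, one possible alternative to the hardcoded `sleep(2)` is an explicit readiness handshake; this is a sketch under assumed conventions (flag files written by each child), not anything in this PR:

```python
import os
import time

# Hypothetical readiness handshake: each spawned child touches a flag file
# once it has started, and the parent polls instead of sleeping blindly.
def wait_for_children(num_children: int, flag_dir: str, timeout: float = 60.0) -> None:
    deadline = time.time() + timeout
    while time.time() < deadline:
        ready = sum(
            os.path.exists(os.path.join(flag_dir, f'rank_{rank}.ready'))
            for rank in range(1, num_children + 1)
        )
        if ready == num_children:
            return
        time.sleep(0.1)
    raise RuntimeError('timed out waiting for DDP subprocesses to start')
```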