refactor deepspeed setup devices (huggingface#9880)
stas00 authored and Qbiwan committed Jan 31, 2021
1 parent daddaf7 commit a02d250
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/transformers/training_args.py
@@ -535,6 +535,20 @@ def _setup_devices(self) -> "torch.device":
             self.local_rank = dist.get_local_rank()
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
+        elif self.deepspeed:
+            # deepspeed performs its own DDP internally, and requires the program to be started with:
+            # deepspeed ./program.py
+            # rather than:
+            # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
+            from .integrations import is_deepspeed_available
+
+            if not is_deepspeed_available():
+                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
+            import deepspeed
+
+            deepspeed.init_distributed()
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -549,21 +563,7 @@ def _setup_devices(self) -> "torch.device":
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            #
-            # deepspeed performs its own DDP internally, and requires the program to be started with:
-            # deepspeed ./program.py
-            # rather than:
-            # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
-            if self.deepspeed:
-                from .integrations import is_deepspeed_available
-
-                if not is_deepspeed_available():
-                    raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
-                import deepspeed
-
-                deepspeed.init_distributed()
-            else:
-                torch.distributed.init_process_group(backend="nccl")
+            torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
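
The net effect of the diff is that the DeepSpeed path is hoisted out of the generic torch.distributed branch into its own top-level elif, checked before the local_rank tests, so deepspeed.init_distributed() replaces torch.distributed.init_process_group() on that path. A minimal, illustrative sketch of the post-refactor branch order (the no_cuda/TPU/SageMaker branches are elided, the class name TrainingArgsSketch is hypothetical, and the availability check is approximated with importlib instead of the library-internal is_deepspeed_available helper):

# Illustrative sketch, not the verbatim file: post-refactor branch order of
# TrainingArguments._setup_devices. Attribute names mirror the diff.
import importlib.util

import torch


class TrainingArgsSketch:  # hypothetical stand-in for TrainingArguments
    def __init__(self, deepspeed=None, local_rank=-1):
        self.deepspeed = deepspeed  # e.g. path to a DeepSpeed config, or None
        self.local_rank = local_rank  # set by the launcher via --local_rank

    def _setup_devices(self) -> "torch.device":
        if self.deepspeed:
            # DeepSpeed performs its own DDP internally; the program must be
            # started with `deepspeed ./program.py`, not torch.distributed.launch.
            if importlib.util.find_spec("deepspeed") is None:
                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
            import deepspeed

            deepspeed.init_distributed()  # stands in for init_process_group
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        elif self.local_rank == -1:
            # Single-process path; nn.DataParallel handles n_gpu > 1.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self._n_gpu = torch.cuda.device_count()
        else:
            # Plain distributed path: one process per GPU, nccl backend.
            torch.distributed.init_process_group(backend="nccl")
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        return device

The launcher determines which branch fires: `deepspeed ./program.py` reaches deepspeed.init_distributed(), while `python -m torch.distributed.launch --nproc_per_node=2 ./program.py` still falls through to the nccl process group.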

