Add distributed multi-node CPU-only support (MULTI_CPU) (#63)
ddkalamk authored May 4, 2021
1 parent 78b7753 commit df260fa
Showing 6 changed files with 93 additions and 7 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -151,6 +151,15 @@ For instance, here is how you would run the GLUE example on the MRPC task (from
accelerate launch examples/nlp_example.py
```

## Launching a multi-CPU run using MPI

🤗 Here is another way to launch a multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
Once you have MPI set up on your cluster, just run:

```bash
mpirun -np 2 python examples/nlp_example.py
```
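
For a multi-node CPU run, the state detection added in this commit expects `MASTER_ADDR` to be exported with the hostname of the rank 0 machine (and defaults `MASTER_PORT` to 29500) unless the `mpi` backend is in use. A hedged sketch with Open MPI, where the host name `node0` and the machine list `hostfile` are placeholders:

```bash
export MASTER_ADDR=node0   # hostname or IP of the rank 0 machine (placeholder)
mpirun -np 16 -hostfile hostfile -x MASTER_ADDR python examples/nlp_example.py
```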

## Launching your training from a notebook

🤗 Accelerate also provides a `notebook_launcher` function you can use in a notebook to launch distributed training. This is especially useful for Colab or Kaggle notebooks with a TPU backend. Just define your training loop in a `training_function`, then in your last cell, add:
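
A minimal sketch of what that last cell could look like, assuming `training_function` takes no arguments:

```python
from accelerate import notebook_launcher

# Runs training_function in multiple processes (e.g. one per TPU core on Colab).
notebook_launcher(training_function)
```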
@@ -188,6 +197,8 @@ pip install accelerate
## Supported integrations

- CPU only
- multi-CPU on one node (machine)
- multi-CPU on several nodes (machines)
- single GPU
- multi-GPU on one node (machine)
- multi-GPU on several nodes (machines)
3 changes: 3 additions & 0 deletions src/accelerate/accelerator.py
@@ -223,6 +223,9 @@ def prepare_model(self, model):
output_device=self.local_process_index,
**kwargs,
)
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
if self.native_amp:
model.forward = torch.cuda.amp.autocast()(model.forward)
return model
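
Unlike the GPU branch above, the CPU path wraps the model in `DistributedDataParallel` without `device_ids` or `output_device`, which is how PyTorch expects CPU replicas to be built. A standalone sketch of the same pattern, assuming the launcher has already populated the usual process-group environment variables:

```python
import torch
import torch.distributed as dist

# Sketch only: assumes RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are set.
dist.init_process_group(backend="gloo")
model = torch.nn.Linear(10, 2)  # stays on CPU, no .to("cuda")
ddp_model = torch.nn.parallel.DistributedDataParallel(model)  # no device_ids on CPU
```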
6 changes: 3 additions & 3 deletions src/accelerate/commands/config/cluster.py
@@ -22,16 +22,16 @@

def get_cluster_input():
distributed_type = _ask_field(
"Which type of machine are you using? ([0] No distributed training, [1] multi-GPU, [2] TPU): ",
"Which type of machine are you using? ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU): ",
_convert_distributed_mode,
error_message="Please enter 0, 1 or 2.",
error_message="Please enter 0, 1, 2 or 3.",
)

machine_rank = 0
num_machines = 1
main_process_ip = None
main_process_port = None
if distributed_type == DistributedType.MULTI_GPU:
if distributed_type == DistributedType.MULTI_GPU or distributed_type == DistributedType.MULTI_CPU:
num_machines = _ask_field(
"How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
lambda x: int(x),
2 changes: 1 addition & 1 deletion src/accelerate/commands/config/config_utils.py
@@ -37,7 +37,7 @@ def _convert_compute_environment(value):

def _convert_distributed_mode(value):
value = int(value)
return DistributedType(["NO", "MULTI_GPU", "TPU"][value])
return DistributedType(["NO", "MULTI_CPU", "MULTI_GPU", "TPU"][value])


def _convert_sagemaker_distributed_mode(value):
64 changes: 63 additions & 1 deletion src/accelerate/state.py
@@ -19,6 +19,14 @@
import torch


try:
import torch_ccl # noqa: F401

_ccl_available = True
except ImportError:
_ccl_available = False


try:
import torch_xla.core.xla_model as xm

@@ -27,6 +35,19 @@
_tpu_available = False


def get_int_from_env(env_keys, default):
"""Returns the first positive env value found in the `env_keys` list or the default."""
for e in env_keys:
val = int(os.environ.get(e, -1))
if val >= 0:
return val
return default
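
A small usage sketch for this helper, assuming a hypothetical environment where only the Open MPI key is set among the candidates:

```python
import os
from accelerate.state import get_int_from_env

# Hypothetical environment, as an Open MPI launcher would set it.
os.environ["OMPI_COMM_WORLD_RANK"] = "3"

# "RANK" and "PMI_RANK" are unset here, so the Open MPI key wins;
# if none of the keys were set, the default (0) would be returned.
print(get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0))  # 3
```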


def is_ccl_available():
return _ccl_available


def is_tpu_available():
return _tpu_available

@@ -43,12 +64,14 @@ class DistributedType(str, Enum):
Values:
- **NO** -- Not a distributed environment, just a single process.
- **MULTI_CPU** -- Distributed on multiple CPU nodes.
- **MULTI_GPU** -- Distributed on multiple GPUs.
- **TPU** -- Distributed on TPUs.
"""

# Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box.
NO = "NO"
MULTI_CPU = "MULTI_CPU"
MULTI_GPU = "MULTI_GPU"
TPU = "TPU"

@@ -107,6 +130,7 @@ class AcceleratorState:
def __init__(self, fp16: bool = None, cpu: bool = False, _from_accelerator: bool = False):
self.__dict__ = self._shared_state
if not getattr(self, "initialized", False):
self.backend = None
if not _from_accelerator:
raise ValueError(
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
@@ -123,12 +147,50 @@ def __init__(self, fp16: bool = None, cpu: bool = False, _from_accelerator: bool
self.distributed_type = DistributedType.MULTI_GPU
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
self.backend = "nccl"
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
self.device = torch.device("cuda", self.local_process_index)
torch.cuda.set_device(self.device)
self.use_fp16 = parse_flag_from_env("USE_FP16", False) if fp16 is None else fp16
elif get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1:
self.distributed_type = DistributedType.MULTI_CPU
if is_ccl_available() and get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0:
backend = "ccl"
elif torch.distributed.is_mpi_available():
backend = "mpi"
else:
backend = "gloo"
# Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1)
local_rank = get_int_from_env(
["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0
)
local_size = get_int_from_env(
["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
)
self.local_process_index = local_rank
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(size)
os.environ["LOCAL_RANK"] = str(local_rank)
if not os.environ.get("MASTER_PORT", None):
os.environ["MASTER_PORT"] = "29500"
if not os.environ.get("MASTER_ADDR", None):
if local_size != size and backend != "mpi":
raise ValueError(
"Looks like distributed multinode run but MASTER_ADDR env not set, "
"please try exporting rank 0's hostname as MASTER_ADDR"
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend, rank=rank, world_size=size)
self.backend = backend
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
self.local_process_index = local_rank
self.device = torch.device("cpu")
self.use_fp16 = False
else:
self.distributed_type = DistributedType.NO
self.num_processes = 1
@@ -139,7 +201,7 @@ def __init__(self, fp16: bool = None, cpu: bool = False, _from_accelerator: bool

def __repr__(self):
return (
f"Distributed environment: {self.distributed_type}\n"
f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n"
f"Num processes: {self.num_processes}\n"
f"Process index: {self.process_index}\n"
f"Local process index: {self.local_process_index}\n"
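
Taken together, these additions mean a CPU-only cluster needs no extra accelerate configuration: under an MPI launcher, `Accelerator()` selects a backend (ccl, mpi or gloo) and bridges the launcher's environment variables to `torch.distributed`. A hedged end-to-end sketch (the file name is illustrative), assuming two MPI ranks with neither CUDA nor `torch_ccl` available:

```python
# check_env.py  (run as: mpirun -np 2 python check_env.py)
from accelerate import Accelerator

accelerator = Accelerator()
state = accelerator.state
# Expected on this setup: MULTI_CPU with the "gloo" backend and 2 processes.
print(state.distributed_type, state.backend, state.num_processes)
```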
14 changes: 12 additions & 2 deletions src/accelerate/utils.py
@@ -80,6 +80,8 @@ def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optiona
rng_state = rng_state.to(state.device)
torch.distributed.broadcast(rng_state, 0)
rng_state = rng_state.cpu()
elif state.distributed_type == DistributedType.MULTI_CPU:
torch.distributed.broadcast(rng_state, 0)

# Set the broadcast rng state
if rng_type == RNGType.TORCH:
@@ -156,6 +158,9 @@ def _gpu_gather(tensor):
return torch.cat(output_tensors, dim=0)


_cpu_gather = _gpu_gather


def gather(tensor):
"""
Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.
@@ -171,6 +176,8 @@ def gather(tensor):
return _tpu_gather(tensor, name="accelerate.utils.gather")
elif AcceleratorState().distributed_type == DistributedType.MULTI_GPU:
return _gpu_gather(tensor)
elif AcceleratorState().distributed_type == DistributedType.MULTI_CPU:
return _cpu_gather(tensor)
else:
return tensor
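
Reusing the GPU gather helper for the CPU path works because gloo, like nccl, implements `all_gather`. A small usage sketch under the same two-rank MPI launch as above:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import gather

accelerator = Accelerator()
local = torch.tensor([float(accelerator.process_index)])
# With 2 CPU processes, every rank ends up with tensor([0., 1.]).
print(gather(local))
```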

@@ -230,7 +237,10 @@ def wait_for_everyone():
Make sure all processes will reach this instruction otherwise one of your processes will hang forever.
"""
if AcceleratorState().distributed_type == DistributedType.MULTI_GPU:
if (
AcceleratorState().distributed_type == DistributedType.MULTI_GPU
or AcceleratorState().distributed_type == DistributedType.MULTI_CPU
):
torch.distributed.barrier()
elif AcceleratorState().distributed_type == DistributedType.TPU:
xm.rendezvous("accelerate.utils.wait_for_everyone")
@@ -266,7 +276,7 @@ def __init__(self, launcher, distributed_type="NO"):
self.distributed_type = DistributedType(distributed_type)

def __call__(self, index, *args):
if self.distributed_type == DistributedType.MULTI_GPU:
if self.distributed_type == DistributedType.MULTI_GPU or self.distributed_type == DistributedType.MULTI_CPU:
# Prepare the environment for torch.distributed
os.environ["LOCAL_RANK"] = str(index)
os.environ["RANK"] = str(index)
