[Core][Distributed] use cpu/gloo to initialize pynccl #4248

Merged
merged 14 commits on Apr 24, 2024
15 changes: 10 additions & 5 deletions tests/distributed/test_pynccl.py
@@ -5,6 +5,7 @@

from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.utils import update_environment_variables


@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
for p in processes:
p.join()

for p in processes:
assert p.exitcode == 0


def update_env(fn):
def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
def wrapped_fn(env):
update_environment_variables(env)
init_distributed_environment()
fn()

return wrapper
return wrapped_fn


@update_env
@worker_fn_wrapper
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
@@ -53,7 +58,7 @@ def test_pynccl():
distributed_run(worker_fn, 2)


@update_env
@worker_fn_wrapper
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
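For reference, a minimal sketch of how the new worker_fn_wrapper/distributed_run pattern is meant to be used. This is a hypothetical worker, not part of the PR; it assumes the helpers defined in this test file, and the explicit stream synchronization is an assumption rather than something shown in the hunk above.

import torch

from vllm.distributed.device_communicators.pynccl import NCCLCommunicator


@worker_fn_wrapper              # defined in the hunk above
def my_worker_fn():
    # The wrapper has already set the env vars and called
    # init_distributed_environment(), so no init arguments are needed here.
    comm = NCCLCommunicator()   # binds to cuda:{rank % device_count} by default
    x = torch.ones(1024, dtype=torch.float32, device=comm.device)
    comm.all_reduce(x)          # ReduceOp.SUM on comm.stream by default
    comm.stream.synchronize()   # assumption: wait before checking the result
    assert x.mean().item() == comm.world_size


def test_my_worker():
    distributed_run(my_worker_fn, 2)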
119 changes: 68 additions & 51 deletions vllm/distributed/device_communicators/pynccl.py
@@ -20,14 +20,15 @@
# variable in the code.

import ctypes
import datetime
import platform
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch.distributed import ProcessGroup, ReduceOp

from vllm.distributed.parallel_state import get_cpu_world_group
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check

@@ -59,6 +60,18 @@

ncclResult_t = ctypes.c_int

_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.restype = ctypes.c_char_p
_c_ncclGetErrorString.argtypes = [ncclResult_t]


def NCCL_CHECK(result: ncclResult_t) -> None:
if result != 0:
error_str = _c_ncclGetErrorString(result)
error_str = error_str.decode("utf-8")
raise RuntimeError(f"NCCL error: {error_str}")


# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
@@ -68,8 +81,7 @@

def ncclGetVersion() -> str:
version = ctypes.c_int()
result = _c_ncclGetVersion(ctypes.byref(version))
assert result == 0
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
@@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure):

def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
assert result == 0
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
return unique_id


@@ -199,66 +210,72 @@ class NCCLCommunicator:

def __init__(
self,
backend=None,
init_method=None,
timeout=datetime.timedelta(seconds=10),
world_size: int = -1,
rank: int = -1,
store=None,
group_name: str = "",
pg_options=None,
local_rank: int = -1,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
):
if not dist.is_initialized():
backend = backend or "nccl"
assert backend == 'nccl', (
"only use nccl backend for starting the NCCL communicator")
dist.init_process_group(backend=backend,
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=rank,
store=store,
group_name=group_name,
pg_options=pg_options)
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
if local_rank == -1:
local_rank = self.rank
self.local_rank = local_rank
# don't use these args, as they can be -1
# use `self.rank`, `self.local_rank` and `self.world_size` instead
del world_size, rank, local_rank
torch.cuda.set_device(self.local_rank)
"""
NCCLCommunicator has to be bound to a specific device.
It is the caller's responsibility to make sure each communicator
is bound to a unique device.
By default, it will be bound to f"cuda:{local_rank}".
"""
assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"NCCLCommunicator should be attached to a non-NCCL group.")
self.group = group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
self.local_rank)
dist.broadcast(tensor, src=0)
byte_list = tensor.cpu().tolist()
tensor = torch.ByteTensor(list(self.unique_id.internal))
dist.broadcast(tensor, src=0, group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank)
assert result == 0
self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
if device is None:
local_rank = self.rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
current_device = torch.cuda.current_device()
try:
torch.cuda.set_device(device)
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
finally:
torch.cuda.set_device(current_device)
Review comment on lines +253 to +261:

Member: Why do we need a try...finally block here? Can the program continue to run when there is an exception in the try block?

Member (author): It can, but I think it would be better for a function to be pure, i.e. don't implicitly modify some global state.
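One way to express the same "keep the function pure" idea is a small context manager. The sketch below is illustrative only and is not part of the PR; it uses only standard contextlib and torch.cuda calls.

import contextlib

import torch


@contextlib.contextmanager
def cuda_device(device: torch.device):
    """Temporarily switch the current CUDA device, restoring it even on error."""
    prev = torch.cuda.current_device()
    torch.cuda.set_device(device)
    try:
        yield
    finally:
        torch.cuda.set_device(prev)


# The constructor body above could then read:
#   with cuda_device(self.device):
#       NCCL_CHECK(_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
#                                      self.unique_id, self.rank))
#       self.stream = torch.cuda.Stream()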


def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
NCCL_CHECK(
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream)))

def __del__(self):
# `dist` module might have been already destroyed
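For reference, a hedged usage sketch of the new constructor. Per the assert above, the group must be an already-initialized non-NCCL (e.g. gloo) group, and construction is collective, so every rank in that group must run this code; the sketch assumes the process was launched with the usual distributed env vars (e.g. via torchrun).

import torch

from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
from vllm.distributed.parallel_state import (get_cpu_world_group,
                                             init_distributed_environment)

init_distributed_environment()   # creates the gloo CPU world group

# Default: attach to the CPU world group and bind to cuda:{rank % device_count}.
comm = NCCLCommunicator()

# Explicit form (device may be an int, a string, or a torch.device), e.g.:
#   comm = NCCLCommunicator(group=get_cpu_world_group(), device="cuda:0")

# all_reduce() asserts tensor.device == comm.device, so tensors created on a
# different device are rejected up front instead of causing an
# "illegal memory access" inside NCCL.
t = torch.ones(8, device=comm.device)
comm.all_reduce(t)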
12 changes: 3 additions & 9 deletions vllm/distributed/device_communicators/pynccl_utils.py
@@ -2,7 +2,7 @@
from typing import Optional

import torch
from torch.distributed import ReduceOp
from torch.distributed import ProcessGroup, ReduceOp

from vllm.logger import init_logger

@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass


def init_process_group(world_size: int,
rank: int,
init_method: str,
local_rank: int = -1) -> None:
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
comm = NCCLCommunicator(init_method=init_method,
world_size=world_size,
local_rank=local_rank,
rank=rank)
comm = NCCLCommunicator(group=group)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
6 changes: 6 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
import os
from typing import Optional

import torch
@@ -73,6 +74,11 @@ def init_distributed_environment(
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1 and distributed_init_method == "env://":
local_rank = int(os.environ['LOCAL_RANK'])
global _LOCAL_RANK
_LOCAL_RANK = local_rank
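For context, LOCAL_RANK in the "env://" case comes from the launcher: torchrun exports it, together with RANK and WORLD_SIZE, for every worker process. This is standard PyTorch behavior, not something introduced by this PR; a sketch of what each worker sees:

import os

# Set by `torchrun --nproc_per_node=N script.py` for every spawned process:
rank = int(os.environ["RANK"])              # global rank across all nodes
local_rank = int(os.environ["LOCAL_RANK"])  # rank within this node
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes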

9 changes: 3 additions & 6 deletions vllm/worker/worker.py
@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
pynccl_utils.init_process_group(
world_size=parallel_config.world_size,
local_rank=local_rank,
rank=rank,
init_method=distributed_init_method,
)
# NOTE(kaichao): By default, pynccl will use information inside
# `parallel_state` for initialization.
pynccl_utils.init_process_group()

ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
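Putting the pieces together, the worker-side setup now follows this order. The sketch below is assembled from the hunks above; the exact init_distributed_environment call and the surrounding function signature are assumptions, since the top of init_worker_distributed_environment is not shown in this diff.

from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)


def init_worker_distributed_environment_sketch(parallel_config, rank,
                                               distributed_init_method,
                                               local_rank):
    # 1. Global torch.distributed init, the gloo CPU world group, and the
    #    local rank (parallel_state.py hunk above).
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

    # 2. pynccl no longer takes init_method/world_size/rank; it reads them
    #    from parallel_state (pynccl_utils.py and pynccl.py hunks above).
    if parallel_config.world_size > 1:
        pynccl_utils.init_process_group()

    # 3. Tensor/pipeline parallel groups, unchanged.
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)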