[WIP][Core] fully composable launcher/task/coordinator/communicator design and implementation #3762

Closed · wants to merge 25 commits
127 changes: 59 additions & 68 deletions tests/distributed/test_pynccl.py
@@ -1,92 +1,83 @@
-import multiprocessing
-import os
-
 import pytest
 import torch
 
-from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
-                                                        ncclGetUniqueId)
-
-
-def distributed_run(fn, world_size):
-    number_of_processes = world_size
-    processes = []
-    for i in range(number_of_processes):
-        env = os.environ.copy()
-        env['RANK'] = str(i)
-        env['LOCAL_RANK'] = str(i)
-        env['WORLD_SIZE'] = str(number_of_processes)
-        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
-        env['MASTER_ADDR'] = 'localhost'
-        env['MASTER_PORT'] = '12345'
-        p = multiprocessing.Process(target=fn, args=(env, ))
-        processes.append(p)
-        p.start()
-
-    for p in processes:
-        p.join()
-
-
-def update_env(fn):
-    # `multiprocessing.Process` cannot accept environment variables directly
-    # so we need to pass the environment variables as arguments
-    # and update the environment variables in the function
-    def wrapper(env):
-        import os
-        os.environ.update(env)
-        fn()
-
-    return wrapper
-
-
-@update_env
-def worker_fn():
-    comm = NCCLCommunicator()
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
-    comm.all_reduce(tensor)
-    result = tensor.mean().cpu().item()
-    assert result == comm.world_size
+from vllm.implementations.communicator import CommunicatorType
+from vllm.implementations.coordinator import CoordinatorType
+from vllm.implementations.distributed_tasks import (
+    GlobalCoordinatorDistributedTask, GroupCoordinatorDistributedTask)
+from vllm.implementations.launcher.mp_launcher import MPLauncher
+
+
+class AllReduceDistributedTask(GlobalCoordinatorDistributedTask):
+
+    def post_init_distributed(self, **kwargs):
+        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(
+            self.coordinator.get_local_rank())
+        self.communicator.all_reduce(tensor_in=tensor)
+        result = tensor.mean().cpu().item()
+        assert result == self.coordinator.get_local_world_size()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 def test_pynccl():
-    distributed_run(worker_fn, 2)
+    MPLauncher(n_tasks=2).launch(
+        task_type=AllReduceDistributedTask,
+        coordinator_type=CoordinatorType.TORCH_DISTRIBUTED,
+        communicator_type=CommunicatorType.PYNCCL,
+    )
 
 
-@update_env
-def worker_fn_with_cudagraph():
-    with torch.no_grad():
+class CUDAGraphAllReduceDistributedTask(GlobalCoordinatorDistributedTask):
+
+    def post_init_distributed(self, **kwargs):
         graph = torch.cuda.CUDAGraph()
-        comm = NCCLCommunicator()
+        device = f'cuda:{self.coordinator.get_rank()}'
+        stream = torch.cuda.Stream(device=device)
+
         # run something in the default stream to initialize torch engine
-        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
+        a = torch.ones((4, 4), device=device)
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph, stream=comm.stream):
+        with torch.cuda.graph(graph, stream=stream):
             # operation during the graph capture is recorded but not executed
             # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
-            comm.all_reduce(a)
-        comm.stream.synchronize()
-        assert a.mean().cpu().item() == comm.world_size**0
+            self.communicator.all_reduce(a, stream=stream)
+        stream.synchronize()
+        assert a.mean().cpu().item() == self.coordinator.get_world_size()**0
         graph.replay()
-        comm.stream.synchronize()
-        assert a.mean().cpu().item() == comm.world_size**1
+        stream.synchronize()
+        assert a.mean().cpu().item() == self.coordinator.get_world_size()**1
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 def test_pynccl_with_cudagraph():
-    distributed_run(worker_fn_with_cudagraph, 2)
+    MPLauncher(n_tasks=2).launch(
+        task_type=CUDAGraphAllReduceDistributedTask,
+        coordinator_type=CoordinatorType.TORCH_DISTRIBUTED,
+        communicator_type=CommunicatorType.PYNCCL,
+    )
 
 
-def test_ncclGetUniqueId():
-    unique_id = ncclGetUniqueId()
-    # `list(unique_id.internal)` is something like this:
-    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
-    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-    #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
-    # as long as the function doesn't raise an exception, we're good
-    assert unique_id is not None
+class GroupedAllReduceDistributedTask(GroupCoordinatorDistributedTask):
+
+    def post_init_distributed(self, **kwargs):
+        rank = self.global_coordinator.get_local_rank()
+        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda() * rank
+        self.communicator.all_reduce(tensor_in=tensor)
+        result = tensor.mean().cpu().item()
+        if rank in [0, 1]:
+            assert result == 1
+        else:
+            assert result == 5
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_grouped_pynccl():
+    MPLauncher(n_tasks=4).launch(
+        task_type=GroupedAllReduceDistributedTask,
+        coordinator_type=CoordinatorType.TORCH_DISTRIBUTED,
+        communicator_type=CommunicatorType.PYNCCL,
+        groups=[[0, 1], [2, 3]],
+    )
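A side note on the grouped test above: each rank fills its tensor with its own rank, and the SUM all-reduce stays within a group, which is where the expected values 1 and 5 come from. A quick sanity check of that arithmetic (illustrative only, not part of this diff):

# groups=[[0, 1], [2, 3]]: the in-group sums are 0 + 1 and 2 + 3
for group in ([0, 1], [2, 3]):
    print(sum(group))  # prints 1, then 5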
17 changes: 17 additions & 0 deletions vllm/implementations/communicator/__init__.py
@@ -0,0 +1,17 @@
from enum import Enum, auto


class CommunicatorType(Enum):
PYNCCL = auto()


def get_communicator_class(communicator_type: CommunicatorType) -> type:
# lazy init
# only import the communicator when it is needed
if communicator_type == CommunicatorType.PYNCCL:
from vllm.implementations.communicator.nccl.pynccl.pynccl_communicator import ( # noqa
NCCLCommunicator)
return NCCLCommunicator
    else:
        raise ValueError(
            f"Communicator type {communicator_type} not recognized.")
Empty file.
8 changes: 8 additions & 0 deletions vllm/implementations/communicator/nccl/pynccl/__init__.py
@@ -0,0 +1,8 @@
from vllm.implementations.communicator.nccl.pynccl.pynccl_communicator import ( # noqa
NCCLCommunicator, get_pynccl_path, set_pynccl_path)

__all__ = [
    "NCCLCommunicator",
    "get_pynccl_path",
    "set_pynccl_path",
]
140 changes: 140 additions & 0 deletions vllm/implementations/communicator/nccl/pynccl/pynccl_communicator.py
@@ -0,0 +1,140 @@
import logging
import os
from contextlib import contextmanager
from typing import Any, Optional

import torch

from vllm.implementations.communicator.nccl.pynccl.wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclDataType_t, ncclRedOp_t,
ncclUniqueId)
from vllm.interfaces.communicator import Communicator, ReduceOp
from vllm.interfaces.coordinator import Coordinator

logger = logging.getLogger(__name__)

# script to manage the path of the nccl library

so_file: Optional[str] = None


def set_pynccl_path(path: str) -> None:
global so_file
so_file = path


def get_pynccl_path() -> Optional[str]:
return so_file


@contextmanager
def change_pynccl_path(path: str) -> None:
global so_file
old_path = so_file
so_file = path
yield
so_file = old_path


class NCCLCommunicator(Communicator):

def __init__(
self,
coordinator: Coordinator,
        path_of_nccl: Optional[str] = None,
):
if "CUDA_VISIBLE_DEVICES" in os.environ:
visible_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
assert len(visible_gpus) >= coordinator.get_group_size(), \
(f"Number of visible GPUs {len(visible_gpus)} is less than"
f" the number of processes in the group {coordinator.group}.")

super().__init__(coordinator)

# search priority:
# 1. path_of_nccl (passed in the constructor of NCCLCommunicator)
# 2. so_file (set by users calling `set_pynccl_path`)
# 3. VLLM_NCCL_SO_PATH environment variable
# 4. default path
path_of_nccl = path_of_nccl or so_file or os.environ.get(
"VLLM_NCCL_SO_PATH", "")
if not path_of_nccl:
# not set yet, try a decent guess as default
if torch.version.cuda is not None:
path_of_nccl = "libnccl.so.2"
elif torch.version.hip is not None:
path_of_nccl = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
        logger.debug(f"Loading nccl from library {path_of_nccl}")

try:
self.lib = NCCLLibrary(path_of_nccl)
except Exception as e:
            logger.error(
                f"Failed to load NCCL library from {path_of_nccl}. "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, please set the environment variable "
                "VLLM_NCCL_SO_PATH to point to the correct nccl library path.")
raise e

logger.info(f"vLLM is using nccl=={self.lib.ncclGetVersion()}")
local_rank = coordinator.get_local_rank()
torch.cuda.set_device(local_rank)
self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}")
if coordinator.is_group_master():
# get a unique id by calling nccl library
self.unique_id = self.lib.ncclGetUniqueId()
else:
# default initialization of unique_id
self.unique_id = ncclUniqueId()
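        # broadcast the unique id as raw bytes from the group master so that
        # every rank below initializes the same NCCL communicator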
data = bytearray(self.unique_id.internal)
coordinator.broadcast(data, src=coordinator.get_group_master_rank())
for i in range(len(data)):
self.unique_id.internal[i] = data[i]
nrank = coordinator.get_group_size()
rank = coordinator.get_group_rank()
self.comm = self.lib.ncclCommInitRank(nrank, self.unique_id, rank)

@staticmethod
def convert_reduce_op(op: ReduceOp) -> ncclRedOp_t:
return {
ReduceOp.SUM: ncclRedOp_t.ncclSum,
ReduceOp.PRODUCT: ncclRedOp_t.ncclProd,
ReduceOp.MAX: ncclRedOp_t.ncclMax,
ReduceOp.MIN: ncclRedOp_t.ncclMin,
ReduceOp.AVG: ncclRedOp_t.ncclAvg,
}[op]

@staticmethod
def convert_data_type(dtype: torch.dtype) -> ncclDataType_t:
return {
torch.int8: ncclDataType_t.ncclInt8,
torch.uint8: ncclDataType_t.ncclUint8,
torch.int32: ncclDataType_t.ncclInt32,
torch.int64: ncclDataType_t.ncclInt64,
torch.float16: ncclDataType_t.ncclFloat16,
torch.float32: ncclDataType_t.ncclFloat32,
torch.float64: ncclDataType_t.ncclFloat64,
torch.bfloat16: ncclDataType_t.ncclBfloat16,
}[dtype]

def all_reduce(self,
tensor_in: torch.Tensor,
tensor_out: Optional[torch.Tensor] = None,
op: ReduceOp = ReduceOp.SUM,
stream: Optional[Any] = None):
assert tensor_in.is_cuda and tensor_in.is_contiguous()
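        # default to an in-place all-reduce when no output tensor is given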
if tensor_out is None:
tensor_out = tensor_in
op = self.convert_reduce_op(op)
dtype = self.convert_data_type(tensor_in.dtype)
if stream is None:
stream = self.stream
self.lib.ncclAllReduce(buffer_type(tensor_in.data_ptr()),
buffer_type(tensor_out.data_ptr()),
tensor_in.numel(), dtype, op, self.comm,
cudaStream_t(stream.cuda_stream))

def __del__(self):
self.lib.ncclCommDestroy(self.comm)
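As a usage note for the search priority documented in the constructor: the NCCL library path can be overridden in several ways. A rough sketch, using a placeholder path:

import os

from vllm.implementations.communicator.nccl.pynccl.pynccl_communicator import (
    change_pynccl_path, set_pynccl_path)

# 1. highest priority: NCCLCommunicator(coordinator, path_of_nccl="...")
# 2. set the module-level path before constructing the communicator
set_pynccl_path("/opt/nccl/lib/libnccl.so.2")  # placeholder path

# ... or only temporarily, restoring the previous value afterwards
with change_pynccl_path("/opt/nccl/lib/libnccl.so.2"):
    pass  # construct NCCLCommunicator(coordinator) inside this block

# 3. environment variable, consulted when nothing was set programmatically
os.environ["VLLM_NCCL_SO_PATH"] = "/opt/nccl/lib/libnccl.so.2"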