diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 29782045130a6..b0b87ae83e68b 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,92 +1,83 @@ -import multiprocessing -import os - import pytest import torch -from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetUniqueId) - - -def distributed_run(fn, world_size): - number_of_processes = world_size - processes = [] - for i in range(number_of_processes): - env = os.environ.copy() - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) - processes.append(p) - p.start() - - for p in processes: - p.join() +from vllm.implementations.communicator import CommunicatorType +from vllm.implementations.coordinator import CoordinatorType +from vllm.implementations.distributed_tasks import ( + GlobalCoordinatorDistributedTask, GroupCoordinatorDistributedTask) +from vllm.implementations.launcher.mp_launcher import MPLauncher -def update_env(fn): - # `multiprocessing.Process` cannot accept environment variables directly - # so we need to pass the environment variables as arguments - # and update the environment variables in the function - def wrapper(env): - import os - os.environ.update(env) - fn() +class AllReduceDistributedTask(GlobalCoordinatorDistributedTask): - return wrapper - - -@update_env -def worker_fn(): - comm = NCCLCommunicator() - tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == comm.world_size + def post_init_distributed(self, **kwargs): + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda( + self.coordinator.get_local_rank()) + self.communicator.all_reduce(tensor_in=tensor) + result = tensor.mean().cpu().item() + assert result == self.coordinator.get_local_world_size() @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") def test_pynccl(): - distributed_run(worker_fn, 2) + MPLauncher(n_tasks=2).launch( + task_type=AllReduceDistributedTask, + coordinator_type=CoordinatorType.TORCH_DISTRIBUTED, + communicator_type=CommunicatorType.PYNCCL, + ) -@update_env -def worker_fn_with_cudagraph(): - with torch.no_grad(): +class CUDAGraphAllReduceDistributedTask(GlobalCoordinatorDistributedTask): + + def post_init_distributed(self, **kwargs): graph = torch.cuda.CUDAGraph() - comm = NCCLCommunicator() + device = f'cuda:{self.coordinator.get_rank()}' + stream = torch.cuda.Stream(device=device) + # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + a = torch.ones((4, 4), device=device) torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=comm.stream): + with torch.cuda.graph(graph, stream=stream): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - comm.all_reduce(a) - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**0 + self.communicator.all_reduce(a, stream=stream) + stream.synchronize() + assert a.mean().cpu().item() == self.coordinator.get_world_size()**0 graph.replay() - comm.stream.synchronize() - assert a.mean().cpu().item() == 
comm.world_size**1 + stream.synchronize() + assert a.mean().cpu().item() == self.coordinator.get_world_size()**1 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") def test_pynccl_with_cudagraph(): - distributed_run(worker_fn_with_cudagraph, 2) - - -def test_ncclGetUniqueId(): - unique_id = ncclGetUniqueId() - # `list(unique_id.internal)` is something like this: - # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, - # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - # as long as the function doesn't raise an exception, we're good - assert unique_id is not None + MPLauncher(n_tasks=2).launch( + task_type=CUDAGraphAllReduceDistributedTask, + coordinator_type=CoordinatorType.TORCH_DISTRIBUTED, + communicator_type=CommunicatorType.PYNCCL, + ) + + +class GroupedAllReduceDistributedTask(GroupCoordinatorDistributedTask): + + def post_init_distributed(self, **kwargs): + rank = self.global_coordinator.get_local_rank() + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda() * rank + self.communicator.all_reduce(tensor_in=tensor) + result = tensor.mean().cpu().item() + if rank in [0, 1]: + assert result == 1 + else: + assert result == 5 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_grouped_pynccl(): + MPLauncher(n_tasks=4).launch( + task_type=GroupedAllReduceDistributedTask, + coordinator_type=CoordinatorType.TORCH_DISTRIBUTED, + communicator_type=CommunicatorType.PYNCCL, + groups=[[0, 1], [2, 3]], + ) diff --git a/vllm/implementations/communicator/__init__.py b/vllm/implementations/communicator/__init__.py new file mode 100644 index 0000000000000..ef9c2d241e1f5 --- /dev/null +++ b/vllm/implementations/communicator/__init__.py @@ -0,0 +1,17 @@ +from enum import Enum, auto + + +class CommunicatorType(Enum): + PYNCCL = auto() + + +def get_communicator_class(communicator_type: CommunicatorType) -> type: + # lazy init + # only import the communicator when it is needed + if communicator_type == CommunicatorType.PYNCCL: + from vllm.implementations.communicator.nccl.pynccl.pynccl_communicator import ( # noqa + NCCLCommunicator) + return NCCLCommunicator + else: + raise ValueError( + f"Communicator type {communicator_type} not recognized.") diff --git a/vllm/implementations/communicator/nccl/__init__.py b/vllm/implementations/communicator/nccl/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/implementations/communicator/nccl/pynccl/__init__.py b/vllm/implementations/communicator/nccl/pynccl/__init__.py new file mode 100644 index 0000000000000..5ab64774e45eb --- /dev/null +++ b/vllm/implementations/communicator/nccl/pynccl/__init__.py @@ -0,0 +1,8 @@ +from vllm.implementations.communicator.nccl.pynccl.pynccl_communicator import ( # noqa + NCCLCommunicator, get_pynccl_path, set_pynccl_path) + +__all__ = [ + "NCCLCommunicator", + "get_pynccl_path", + "set_pynccl_path", +] diff --git a/vllm/implementations/communicator/nccl/pynccl/pynccl_communicator.py b/vllm/implementations/communicator/nccl/pynccl/pynccl_communicator.py new file mode 100644 index 0000000000000..c7e4aa723ce06 --- /dev/null +++ 
b/vllm/implementations/communicator/nccl/pynccl/pynccl_communicator.py @@ -0,0 +1,140 @@ +import logging +import os +from contextlib import contextmanager +from typing import Any, Optional + +import torch + +from vllm.implementations.communicator.nccl.pynccl.wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclDataType_t, ncclRedOp_t, + ncclUniqueId) +from vllm.interfaces.communicator import Communicator, ReduceOp +from vllm.interfaces.coordinator import Coordinator + +logger = logging.getLogger(__name__) + +# script to manage the path of the nccl library + +so_file: Optional[str] = None + + +def set_pynccl_path(path: str) -> None: + global so_file + so_file = path + + +def get_pynccl_path() -> Optional[str]: + return so_file + + +@contextmanager +def change_pynccl_path(path: str) -> None: + global so_file + old_path = so_file + so_file = path + yield + so_file = old_path + + +class NCCLCommunicator(Communicator): + + def __init__( + self, + coordinator: Coordinator, + path_of_nccl: Optional[str] = None, + ): + if "CUDA_VISIBLE_DEVICES" in os.environ: + visible_gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + assert len(visible_gpus) >= coordinator.get_group_size(), \ + (f"Number of visible GPUs {len(visible_gpus)} is less than" + f" the number of processes in the group {coordinator.group}.") + + super().__init__(coordinator) + + # search priority: + # 1. path_of_nccl (passed in the constructor of NCCLCommunicator) + # 2. so_file (set by users calling `set_pynccl_path`) + # 3. VLLM_NCCL_SO_PATH environment variable + # 4. default path + path_of_nccl = path_of_nccl or so_file or os.environ.get( + "VLLM_NCCL_SO_PATH", "") + if not path_of_nccl: + # not set yet, try a decent guess as default + if torch.version.cuda is not None: + path_of_nccl = "libnccl.so.2" + elif torch.version.hip is not None: + path_of_nccl = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.debug(f"Loading nccl from library {path_of_nccl}") + + try: + self.lib = NCCLLibrary(path_of_nccl) + except Exception as e: + logger.error( + f"Failed to load NCCL library from {path_of_nccl}. " + "This is expected if you are not running on NVIDIA/AMD GPUs." 
+ "Otherwise please set environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.") + raise e + + logger.info(f"vLLM is using nccl=={self.lib.ncclGetVersion()}") + local_rank = coordinator.get_local_rank() + torch.cuda.set_device(local_rank) + self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}") + if coordinator.is_group_master(): + # get a unique id by calling nccl library + self.unique_id = self.lib.ncclGetUniqueId() + else: + # default initialization of unique_id + self.unique_id = ncclUniqueId() + data = bytearray(self.unique_id.internal) + coordinator.broadcast(data, src=coordinator.get_group_master_rank()) + for i in range(len(data)): + self.unique_id.internal[i] = data[i] + nrank = coordinator.get_group_size() + rank = coordinator.get_group_rank() + self.comm = self.lib.ncclCommInitRank(nrank, self.unique_id, rank) + + @staticmethod + def convert_reduce_op(op: ReduceOp) -> ncclRedOp_t: + return { + ReduceOp.SUM: ncclRedOp_t.ncclSum, + ReduceOp.PRODUCT: ncclRedOp_t.ncclProd, + ReduceOp.MAX: ncclRedOp_t.ncclMax, + ReduceOp.MIN: ncclRedOp_t.ncclMin, + ReduceOp.AVG: ncclRedOp_t.ncclAvg, + }[op] + + @staticmethod + def convert_data_type(dtype: torch.dtype) -> ncclDataType_t: + return { + torch.int8: ncclDataType_t.ncclInt8, + torch.uint8: ncclDataType_t.ncclUint8, + torch.int32: ncclDataType_t.ncclInt32, + torch.int64: ncclDataType_t.ncclInt64, + torch.float16: ncclDataType_t.ncclFloat16, + torch.float32: ncclDataType_t.ncclFloat32, + torch.float64: ncclDataType_t.ncclFloat64, + torch.bfloat16: ncclDataType_t.ncclBfloat16, + }[dtype] + + def all_reduce(self, + tensor_in: torch.Tensor, + tensor_out: Optional[torch.Tensor] = None, + op: ReduceOp = ReduceOp.SUM, + stream: Optional[Any] = None): + assert tensor_in.is_cuda and tensor_in.is_contiguous() + if tensor_out is None: + tensor_out = tensor_in + op = self.convert_reduce_op(op) + dtype = self.convert_data_type(tensor_in.dtype) + if stream is None: + stream = self.stream + self.lib.ncclAllReduce(buffer_type(tensor_in.data_ptr()), + buffer_type(tensor_out.data_ptr()), + tensor_in.numel(), dtype, op, self.comm, + cudaStream_t(stream.cuda_stream)) + + def __del__(self): + self.lib.ncclCommDestroy(self.comm) diff --git a/vllm/implementations/communicator/nccl/pynccl/wrapper.py b/vllm/implementations/communicator/nccl/pynccl/wrapper.py new file mode 100644 index 0000000000000..0e7e7d10f74b9 --- /dev/null +++ b/vllm/implementations/communicator/nccl/pynccl/wrapper.py @@ -0,0 +1,157 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. 
It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +from dataclasses import dataclass +from typing import Any, List + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + + +# enums +class ncclDataType_t(ctypes.c_int): + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + +class ncclRedOp_t(ctypes.c_int): + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + def __init__(self, so_file: str): + self.lib = ctypes.CDLL(so_file) + self._funcs = {} + for func in self.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + self._funcs[func.name] = f + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + result = self._funcs["ncclGetVersion"](ctypes.byref(version)) + assert result == 0 + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + result = self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)) + assert result == 0 + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + result = self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, rank) + assert result == 0 + return comm 
+ + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: ncclDataType_t, op: ncclRedOp_t, + comm: ncclComm_t, stream: cudaStream_t) -> None: + result = self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, stream) + assert result == 0 + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + result = self._funcs["ncclCommDestroy"](comm) + assert result == 0 + + +__all__ = [ + "NCCLLibrary", "ncclDataType_t", "ncclRedOp_t", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/vllm/implementations/coordinator/__init__.py b/vllm/implementations/coordinator/__init__.py new file mode 100644 index 0000000000000..add6906fd2bde --- /dev/null +++ b/vllm/implementations/coordinator/__init__.py @@ -0,0 +1,16 @@ +from enum import Enum, auto + + +class CoordinatorType(Enum): + TORCH_DISTRIBUTED = auto() + + +def get_coordinator_class(coordinator_type: CoordinatorType) -> type: + # lazy init + # only import the coordinator when it is needed + if coordinator_type == CoordinatorType.TORCH_DISTRIBUTED: + from vllm.implementations.coordinator.torch_distributed.torch_distributed_coordinator import ( # noqa + TorchDistributedCoordinator) + return TorchDistributedCoordinator + else: + raise ValueError(f"Coordinator type {coordinator_type} not recognized.") diff --git a/vllm/implementations/coordinator/torch_distributed/torch_distributed_coordinator.py b/vllm/implementations/coordinator/torch_distributed/torch_distributed_coordinator.py new file mode 100644 index 0000000000000..1c7d18e23b2a3 --- /dev/null +++ b/vllm/implementations/coordinator/torch_distributed/torch_distributed_coordinator.py @@ -0,0 +1,75 @@ +# Implementation of the Coordinator interface based on +# PyTorch's distributed package. + +import os +from typing import List, Optional + +import torch +import torch.distributed as dist + +from vllm.interfaces.coordinator import Coordinator + + +class TorchDistributedCoordinator(Coordinator): + + def __init__(self, groups: Optional[List[List[int]]] = None): + assert 'RANK' in os.environ, \ + 'RANK not found in environment' + assert 'WORLD_SIZE' in os.environ, \ + 'WORLD_SIZE not found in environment' + assert 'LOCAL_RANK' in os.environ, \ + 'LOCAL_RANK not found in environment' + assert 'LOCAL_WORLD_SIZE' in os.environ, \ + 'LOCAL_WORLD_SIZE not found in environment' + assert 'MASTER_ADDR' in os.environ, \ + 'MASTER_ADDR not found in environment' + assert 'MASTER_PORT' in os.environ, \ + 'MASTER_PORT not found in environment' + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + local_rank = int(os.environ['LOCAL_RANK']) + local_world_size = int(os.environ['LOCAL_WORLD_SIZE']) + groups = groups or [list(range(world_size))] + super().__init__(rank=rank, + world_size=world_size, + local_rank=local_rank, + local_world_size=local_world_size, + groups=groups) + self.process_group = None + + def initialize(self): + # in `torch.distributed`, we can only initialize the process group once. + # if the process group is already initialized, we should not initialize + # it again, but need to use `new_group` to create a new group. + # in either case, `self.group` contains all the processes. We just + # use the `gloo` backend ourselves inside this coordinator, + # to avoid interfering with other process groups. 
+ if not dist.is_initialized(): + dist.init_process_group(backend='gloo') + self.process_group = dist.group.WORLD + else: + # each time we create a new group, we need **all the processes** + # to call `new_group`. For example, when we have 4 processes, + # and we want to create two groups, [0, 1] and [2, 3], + # 1. [0, 1, 2, 3] should call `new_group` with ranks=[0, 1] + # 2. [0, 1, 2, 3] should call `new_group` with ranks=[2, 3] + for group in self.groups: + result_group = dist.new_group(ranks=group, backend='gloo') + if self.rank in group: + self.process_group = result_group + super().initialize() + + def barrier(self): + dist.barrier(group=self.process_group) + + def broadcast(self, message: bytearray, src: int = 0): + tensor = torch.tensor(list(message), dtype=torch.uint8) + dist.broadcast(tensor, src=src, group=self.process_group) + data = tensor.tolist() + for i in range(len(message)): + message[i] = data[i] + + def __del__(self): + # `dist` module might have been already destroyed + if hasattr(dist, 'destroy_process_group'): + dist.destroy_process_group(self.process_group) diff --git a/vllm/implementations/distributed_tasks/__init__.py b/vllm/implementations/distributed_tasks/__init__.py new file mode 100644 index 0000000000000..31b73bbd4b8ff --- /dev/null +++ b/vllm/implementations/distributed_tasks/__init__.py @@ -0,0 +1,8 @@ +from vllm.implementations.distributed_tasks.global_coordinator_task import ( + GlobalCoordinatorDistributedTask) +from vllm.implementations.distributed_tasks.group_coordinator_task import ( + GroupCoordinatorDistributedTask) + +__all__ = [ + 'GlobalCoordinatorDistributedTask', 'GroupCoordinatorDistributedTask' +] diff --git a/vllm/implementations/distributed_tasks/global_coordinator_task.py b/vllm/implementations/distributed_tasks/global_coordinator_task.py new file mode 100644 index 0000000000000..27c1264290d23 --- /dev/null +++ b/vllm/implementations/distributed_tasks/global_coordinator_task.py @@ -0,0 +1,27 @@ +from vllm.implementations.communicator import (CommunicatorType, + get_communicator_class) +from vllm.implementations.coordinator import (CoordinatorType, + get_coordinator_class) +from vllm.interfaces.communicator import Communicator +from vllm.interfaces.coordinator import Coordinator +from vllm.interfaces.launcher import DistributedTask + + +class GlobalCoordinatorDistributedTask(DistributedTask): + + def run(self, *, coordinator_type: CoordinatorType, + communicator_type: CommunicatorType, **kwargs): + coordinator_cls = get_coordinator_class(coordinator_type) + communicator_cls = get_communicator_class(communicator_type) + self.coordinator: Coordinator = coordinator_cls() + self.coordinator.initialize() + self.communicator: Communicator = communicator_cls(self.coordinator) + self.post_init_distributed(**kwargs) + + def post_init_distributed(self, **kwargs): + """Subclasses can override this method to do whatever they want. + They can use `self.coordinator` for global communication over the + whole process group. + They can use `self.communicator` for communication between devices. 
+ """ + return diff --git a/vllm/implementations/distributed_tasks/group_coordinator_task.py b/vllm/implementations/distributed_tasks/group_coordinator_task.py new file mode 100644 index 0000000000000..26eca330a564e --- /dev/null +++ b/vllm/implementations/distributed_tasks/group_coordinator_task.py @@ -0,0 +1,38 @@ +from typing import List + +from vllm.implementations.communicator import (CommunicatorType, + get_communicator_class) +from vllm.implementations.coordinator import (CoordinatorType, + get_coordinator_class) +from vllm.interfaces.communicator import Communicator +from vllm.interfaces.coordinator import Coordinator +from vllm.interfaces.launcher import DistributedTask + + +class GroupCoordinatorDistributedTask(DistributedTask): + + def run(self, *, coordinator_type: CoordinatorType, + communicator_type: CommunicatorType, groups: List[List[int]], + **kwargs): + coordinator_cls = get_coordinator_class(coordinator_type) + communicator_cls = get_communicator_class(communicator_type) + self.global_coordinator: Coordinator = coordinator_cls() + self.global_coordinator.initialize() + + self.group_coordinator: Coordinator = coordinator_cls(groups) + self.group_coordinator.initialize() + + self.communicator: Communicator = communicator_cls( + self.group_coordinator) + self.post_init_distributed(**kwargs) + + def post_init_distributed(self, **kwargs): + """Subclasses can override this method to do whatever they want. + They can use `self.global_coordinator` for global communication + over the whole process group. + They can use `self.group_coordinator` for communication within a + subgroup. + They can use `self.communicator` for communication between devices + in a subgroup. + """ + return diff --git a/vllm/implementations/launcher/__init__.py b/vllm/implementations/launcher/__init__.py new file mode 100644 index 0000000000000..a50a65c35798a --- /dev/null +++ b/vllm/implementations/launcher/__init__.py @@ -0,0 +1,14 @@ +from enum import Enum, auto + + +class LauncherType(Enum): + MPLauncher = auto() + + +def get_launcher_class(launcher_type: LauncherType): + if launcher_type == LauncherType.MPLauncher: + from vllm.implementations.launcher.mp_launcher import MPLauncher + return MPLauncher + else: + raise NotImplementedError( + f"Launcher type {launcher_type} not implemented") diff --git a/vllm/implementations/launcher/mp_launcher.py b/vllm/implementations/launcher/mp_launcher.py new file mode 100644 index 0000000000000..0089dae719ba5 --- /dev/null +++ b/vllm/implementations/launcher/mp_launcher.py @@ -0,0 +1,40 @@ +import uuid +from multiprocessing import Process + +from vllm.interfaces.launcher import Launcher, SubClassOfDistributedTask +from vllm.utils import get_open_port + + +class MPLauncher(Launcher): + # this is intended to work in single node + def __init__(self, n_tasks: int): + self.n_tasks = n_tasks + + def launch(self, *, task_type: SubClassOfDistributedTask, **kwargs): + # be cautious that `kwargs` might well be serialized + # and deserialized before being passed to tasks + launch_id = str(uuid.uuid4()) + envs = [{} for _ in range(self.n_tasks)] + port = str(get_open_port()) + for i, env in enumerate(envs): + env['LAUNCH_ID'] = launch_id + env['WORLD_SIZE'] = str(self.n_tasks) + env['RANK'] = str(i) + env['LOCAL_WORLD_SIZE'] = str(self.n_tasks) + env['LOCAL_RANK'] = str(i) + env['MASTER_ADDR'] = 'localhost' + env['MASTER_PORT'] = port + tasks = [] + for i in range(self.n_tasks): + p = Process(target=task_type, args=(envs[i], (), kwargs)) + p.start() + tasks.append(p) + for 
task in tasks: + task.join() + msg = "" + for i, task in enumerate(tasks): + if task.exitcode != 0: + msg += f"Task {i} exited with code {task.exitcode}" + # if no error, `msg` should be empty + # if there is an error, `msg` should contain the error message + assert msg == "", msg diff --git a/vllm/interfaces/communicator.py b/vllm/interfaces/communicator.py new file mode 100644 index 0000000000000..d65e25ab4b387 --- /dev/null +++ b/vllm/interfaces/communicator.py @@ -0,0 +1,69 @@ +# communicator interface, as proposed in +# https://github.com/vllm-project/vllm/issues/3587 +# `Communicator` is responsible for communicating **large tensor data** +# between multiple devices. This functionality is usually provided by +# vendors, e.g. NCCL from NVIDIA, RCCL from AMD. +# Put it simple, this is for data-plane communication. + +from typing import Any, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.interfaces.coordinator import Coordinator + + +class Communicator(object): + """ + `coordinator` is the object used to initialize the communicator. + `group` is the list of **global** ranks to identify communication groups. + For functions with a `src` or `dst` argument, the rank is also global. + If the communicator needs to know the local rank inside a group, it should + convert the rank by searching over the group. + + The interfaces are designed to be as general as possible. They contain + out-of-place operations, and stream arguments to launch the operations. + However, subclasses are free to implement only in-place operations without + the stream argument, i.e. raising NotImplementedError for the out-of-place + operations and not-None stream argument, as long as they satisfy the + requirements of the corresponding application. + """ + + def __init__(self, coordinator: Coordinator): + self.coordinator = coordinator + assert self.coordinator.is_initialized() + + def broadcast(self, + tensor_in: torch.Tensor, + tensor_out: Optional[torch.Tensor] = None, + src: int = 0, + stream: Optional[Any] = None): + raise NotImplementedError( + f"broadcast is not implemented in {self.__class__.__name__}") + + def all_reduce(self, + tensor_in: torch.Tensor, + tensor_out: Optional[torch.Tensor] = None, + op: ReduceOp = ReduceOp.SUM, + stream: Optional[Any] = None): + raise NotImplementedError( + f"all_reduce is not implemented in {self.__class__.__name__}") + + def reduce(self, + tensor_in: torch.Tensor, + tensor_out: Optional[torch.Tensor] = None, + dst: int = 0, + op: ReduceOp = ReduceOp.SUM, + stream: Optional[Any] = None): + raise NotImplementedError( + f"reduce is not implemented in {self.__class__.__name__}") + + def all_gather(self, + tensor_in: torch.Tensor, + tensor_out: Optional[torch.Tensor] = None, + stream: Optional[Any] = None): + raise NotImplementedError( + f"all_gather is not implemented in {self.__class__.__name__}") + + def __del__(self): + pass diff --git a/vllm/interfaces/coordinator.py b/vllm/interfaces/coordinator.py new file mode 100644 index 0000000000000..fc64b353278d8 --- /dev/null +++ b/vllm/interfaces/coordinator.py @@ -0,0 +1,121 @@ +# coordinator interface, as proposed in +# https://github.com/vllm-project/vllm/issues/3587 +# `Coordinator` is responsible for communicating **tiny control messages** +# between multiple processes. This functionality is usually provided by +# PyTorch (gloo backend) or MPI, implemented using CPU. +# Put it simple, this is for control-plane communication. 
+ +from abc import ABC, abstractmethod +from typing import List + + +class Coordinator(ABC): + """This is the abstract interface for the coordinator. + The least requirement for the coordinator is to provide: + 1. The world size of the distributed environment. + 2. The rank of the current process inside the world. + 3. The local rank of the current process inside the node. + 4. The local world size inside the node. + 5. The current group the process belongs to. + + Note that if the `group` is provided, the coordinator should only + synchronize the processes in the group. If you want to not only + coordinate **all** the processes but also coordinate **subgroups** + of processes, you can create multiple coordinators. In that case, + the coordinator with only one group must be initialized first. + + To avoid confusion in argument passing, all arguments are set + to be keyword-only. + + Usually subclasses need to implement the following methods: + 1. `__init__`: Initialize the coordinator, only set the necessary + attributes with sanity checks. + 2. `initialize`: Initialize the coordinator. This is set to be a + separate method for lazy initialization. In addition, subclasses + should call this method after their `initialize` method. + 3. `barrier`: Synchronize all the processes. + 4. `broadcast`: Broadcast a message from the source process to all + other processes. + """ + + def __init__(self, *, rank: int, world_size: int, local_rank: int, + local_world_size: int, groups: List[List[int]], **kwargs): + self.rank = rank + self.world_size = world_size + self.local_rank = local_rank + self.local_world_size = local_world_size + self._initialize = False + self.groups = groups + self.group = [g for g in groups if self.rank in g][0] + + def initialize(self): + """Initialize the coordinator. This is set to be a separate method + so that the coordinator can be initialized after the object is created. + + This method is supposed to be called by all the participating processes. 
+ """ + self._initialize = True + + def is_initialized(self) -> bool: + """Check if the coordinator has been initialized.""" + return self._initialize + + def get_world_size(self) -> int: + """Get the world size of the distributed environment.""" + return self.world_size + + def get_rank(self) -> int: + """Get the rank of the current process inside the world.""" + return self.rank + + def is_master(self) -> bool: + """Check if the current process is the master process.""" + return self.rank == 0 + + def get_local_world_size(self) -> int: + """Get the local world size inside the node.""" + return self.local_world_size + + def get_local_rank(self) -> int: + """Get the local rank of the current process inside the node.""" + return self.local_rank + + def is_local_master(self) -> bool: + """Check if the current process is the local master process.""" + return self.local_rank == 0 + + def get_group_size(self) -> int: + """Get the size of the group.""" + return len(self.group) + + def get_group_rank(self) -> int: + """Get the rank of the current process inside the group.""" + return self.group.index(self.rank) + + def is_group_master(self) -> bool: + """Check if the current process is the group master process.""" + return self.get_group_rank() == 0 + + def get_group_master_rank(self) -> int: + """Get the rank of the group master process.""" + return self.group[0] + + @abstractmethod + def barrier(self): + """Synchronize all the processes.""" + raise NotImplementedError + + @abstractmethod + def broadcast(self, message: bytearray, src: int = 0) -> None: + """Broadcast a message from the source process to all other processes. + Note that the message type is explicitly set to `bytearray`, to + indicate that this is a tiny control message. + + Note: this is an in-place operation, the message is modified in-place. + """ + raise NotImplementedError + + @abstractmethod + def __del__(self): + """Release the resources.""" + raise NotImplementedError diff --git a/vllm/interfaces/launcher.py b/vllm/interfaces/launcher.py new file mode 100644 index 0000000000000..4c5e7648a29da --- /dev/null +++ b/vllm/interfaces/launcher.py @@ -0,0 +1,60 @@ +# launcher interface, as proposed in +# https://github.com/vllm-project/vllm/issues/3587 +# `Launcher` is responsible for creating workers. + +import os +import warnings +from abc import ABC, abstractmethod +from typing import Dict, Type, TypeVar + + +class DistributedTask(ABC): + + def __init__(self, env: Dict[str, str], args, kwargs): + self.update_env(env) + self.run(*args, **kwargs) + + def update_env(self, env: Dict[str, str]): + for k, v in env.items(): + if k in os.environ: + warnings.warn( + f"Overwriting environment variable {k} " + f"from {os.environ[k]} to {v}", + stacklevel=2) + os.environ[k] = v + + @abstractmethod + def run(self, *args, **kwargs): + # usually: + # initialize coordinator and communicator + # initialize device + # initialize model + # warmup model + # run model + pass + + +T = TypeVar('T', bound=DistributedTask) +SubClassOfDistributedTask = Type[T] + + +class Launcher(ABC): + + @abstractmethod + def launch(self, *, task_type: SubClassOfDistributedTask, **kwargs): + # only keyword arguments are allowed, to avoid confusion + + # this is a dummy implementation, but captures the idea + # 1. prepare environment variables, args, kwargs for each task + n_tasks = 4 + envs = [{} for _ in range(n_tasks)] + # 2. create tasks (typically these tasks should be run in parallel) + # note that creating a task will also run it. 
This is designed for + # simple launchers like multiprocessing, where we can only pass a + # function to run, and cannot do any further operations on the task. + + # this parameter-passing happens across processes, and we use three + # args to pass the env, args, and kwargs to each task. + for env in envs: + task_type(env, (), kwargs) + # 3. wait for tasks to finish
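
For reference, a minimal usage sketch of the interfaces introduced above, assembled from the test changes in this diff; the task name MyAllReduceTask and the __main__ guard are illustrative and not part of the PR:

    import torch

    from vllm.implementations.communicator import CommunicatorType
    from vllm.implementations.coordinator import CoordinatorType
    from vllm.implementations.distributed_tasks import (
        GlobalCoordinatorDistributedTask)
    from vllm.implementations.launcher.mp_launcher import MPLauncher


    class MyAllReduceTask(GlobalCoordinatorDistributedTask):

        # runs after the coordinator (control plane, gloo) and the
        # communicator (data plane, NCCL) have been initialized
        def post_init_distributed(self, **kwargs):
            rank = self.coordinator.get_local_rank()
            tensor = torch.ones(4, 4, dtype=torch.float32).cuda(rank)
            self.communicator.all_reduce(tensor_in=tensor)
            # all-reducing ones across the world sums to the world size
            assert tensor.mean().cpu().item() == \
                self.coordinator.get_world_size()


    if __name__ == "__main__":
        # spawn 2 single-node processes; each constructs and runs the task
        MPLauncher(n_tasks=2).launch(
            task_type=MyAllReduceTask,
            coordinator_type=CoordinatorType.TORCH_DISTRIBUTED,
            communicator_type=CommunicatorType.PYNCCL,
        )

MPLauncher fills in RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each process before constructing the task, so no manual environment setup is needed.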