Add Splitwise implementation to vLLM #2809

Closed
wants to merge 6 commits
29 changes: 29 additions & 0 deletions csrc/kv_comm_kernels.cu
@@ -0,0 +1,29 @@
#include <mscclpp/proxy_channel_device.hpp>

extern "C" __global__ void __launch_bounds__(1024, 1)
nw_cache_in_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel) {
int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
if (globalIndex == 0) {
proxyChannel[0].wait(100000000);
}
}

extern "C" __global__ void __launch_bounds__(1024, 1)
nw_cache_out_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel, int dst_mem, int src_mem, int kv_block_offset, int dataSize, int flush) {
int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
if (globalIndex == 0) {
proxyChannel[0].put(dst_mem, kv_block_offset, src_mem, kv_block_offset, dataSize);
if (flush) {
proxyChannel[0].flush();
}
}
}

extern "C" __global__ void __launch_bounds__(1024, 1)
nw_cache_out_signal_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel) {
int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
if (globalIndex == 0) {
proxyChannel[0].signal();
proxyChannel[0].flush();
}
}
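For orientation, these three kernels implement a one-sided put/signal/wait handshake over an MSCCL++ proxy channel: the sending (prompt-side) worker puts KV-cache blocks and then signals, while the receiving (token-side) worker waits on that signal. The sketch below mirrors only the ordering of those operations using plain Python threading; it is purely illustrative and does not use the mscclpp API.

import queue
import threading

channel = queue.Queue()      # stands in for the mscclpp proxy channel
done = threading.Event()     # stands in for the channel's signal/semaphore

def prompt_worker(kv_blocks):
    channel.put(kv_blocks)   # nw_cache_out_kernel: put() KV blocks to the peer
    done.set()               # nw_cache_out_signal_kernel: signal() completion

def token_worker():
    done.wait(timeout=10)    # nw_cache_in_kernel: wait() for the sender's signal
    return channel.get()     # the KV blocks are now usable on the token side

sender = threading.Thread(target=prompt_worker, args=(b"kv-cache-bytes",))
sender.start()
received = token_worker()
sender.join()
assert received == b"kv-cache-bytes"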
7 changes: 7 additions & 0 deletions docs/source/index.rst
@@ -89,6 +89,13 @@ Documentation

quantization/auto_awq

.. toctree::
:maxdepth: 1
:caption: Splitwise

splitwise/getting_started
splitwise/installing_mscclpp

.. toctree::
:maxdepth: 2
:caption: Developer Documentation
16 changes: 16 additions & 0 deletions docs/source/splitwise/getting_started.rst
@@ -0,0 +1,16 @@
.. _getting_started:

Getting Started with Splitwise
==============================

`Splitwise <https://www.microsoft.com/en-us/research/publication/splitwise-efficient-generative-llm-inference-using-phase-splitting/>`_ is a technique that splits the two phases of an LLM inference request, prompt processing and token generation, onto separate machines for efficient inference.

Installing MSCCL++
-------------------------

Please follow the :ref:`MSCCL++ installation instructions <installing_mscclpp>` to install the MSCCL++ communication library, which is used to transfer KV caches from prompt workers to token workers.

Running inference with Splitwise
--------------------------------

Simply add the ``--sep-prompt-token`` flag to the vLLM command in order to use Splitwise.
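For example, a launch might look like the following (a sketch: the API server entrypoint, model, and tensor-parallel degree are placeholders; only ``--sep-prompt-token`` is added by this PR):

$ python -m vllm.entrypoints.api_server \
    --model meta-llama/Llama-2-70b-hf \
    --tensor-parallel-size 2 \
    --sep-prompt-token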
18 changes: 18 additions & 0 deletions docs/source/splitwise/installing_mscclpp.rst
@@ -0,0 +1,18 @@
.. _installing_mscclpp:

Installing MSCCL++
============================

`MSCCL++ <https://github.com/microsoft/mscclpp>`_ is a GPU-driven communication stack for scalable AI applications.
It is used to implement KV cache communication in Splitwise.

To install MSCCL++, please follow the instructions at `MSCCL++ Quickstart <https://github.com/microsoft/mscclpp/blob/main/docs/quickstart.md>`_ or follow the steps below to install it from source:
Contributor:

Just a heads up that the system may need `apt install libnuma-dev` (libnuma1) prior to running `make` (I hit this error during installation).

Author:

Thanks for the heads up. I will add this to the installation instructions both here and in the MSCCL++ repo. Also, feel free to reach out to me if you have any other problems with the MSCCL++ setup.

MSCCL++ requires libnuma, which can be installed with ``apt install libnuma-dev`` on Debian-based systems.

.. code-block:: console

$ git clone https://github.com/microsoft/mscclpp;
$ mkdir mscclpp/build; cd mscclpp/build; cmake -DCMAKE_BUILD_TYPE=Release ..; make -j;
$ conda install -c conda-forge mpi4py
$ cd ../python; pip install -r requirements_c12.txt;
$ cd ..; pip install -e .
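As a quick sanity check after the build (assuming the Python bindings install under the module name ``mscclpp``):

$ python -c "import mscclpp; print('MSCCL++ Python bindings OK')"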
1 change: 1 addition & 0 deletions requirements.txt
@@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette]
pynvml == 11.5.0
triton >= 2.1.0
gputil
42 changes: 42 additions & 0 deletions tests/distributed/test_kvcache_comm.py
@@ -0,0 +1,42 @@
"""Test the KV cache communication operators.

Run `python test_kvcache_comm.py`.
"""
import argparse
import ray

from vllm import EngineArgs, LLMEngine


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
"""Initialize the LLMEngine from the command line arguments."""
engine_args = EngineArgs.from_cli_args(args)
return LLMEngine.from_engine_args(engine_args)


def run_all_workers(engine: LLMEngine, method: str, *args):
"""Run all the workers."""
ray_worker_outputs = [
worker.execute_method.remote(method, *args)
for worker in engine.workers
]
_ = getattr(engine.driver_worker, method)(*args)
ray.get(ray_worker_outputs)


"""Test the kv cache communication."""
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = "meta-llama/Llama-2-70b-hf"
args.tensor_parallel_size = 2
args.sep_prompt_token = True
engine = initialize_engine(args)

run_all_workers(engine, "set_gpucache")
run_all_workers(engine, "send_recv_kvcache_all")
run_all_workers(engine, "check_gpucache")

engine.destroy_kvcache_comm()
8 changes: 8 additions & 0 deletions vllm/config.py
@@ -369,14 +369,22 @@ def __init__(
worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
sep_prompt_token: bool = False,
Contributor:

The modification to the config looks good to me. I expect a future extension may pass in the number of prompt / token workers explicitly, but so far the abstraction looks good.

) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.sep_prompt_token = sep_prompt_token

self.world_size = pipeline_parallel_size * tensor_parallel_size
if sep_prompt_token:
# Half of the workers are prompt workers and the other half are token workers
Contributor:

Should this be a fixed split? If the prompt and token workers use the exact same GPUs, it seems the prompt phase may need more workers.

self.num_prompt_workers = self.world_size
self.num_token_workers = self.world_size
self.world_size = self.num_prompt_workers + self.num_token_workers

if self.world_size > 1:
self.worker_use_ray = True
self._verify_args()
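As a worked example of the doubling logic above (a sketch assuming the constructor signature shown in this diff, with positional arguments mirroring the call in vllm/engine/arg_utils.py):

from vllm.config import ParallelConfig

# pipeline_parallel_size=1, tensor_parallel_size=2, worker_use_ray=False
config = ParallelConfig(1, 2, False, sep_prompt_token=True)

assert config.num_prompt_workers == 2   # one prompt worker per TP rank
assert config.num_token_workers == 2    # one token worker per TP rank
assert config.world_size == 4           # doubled from the original 2
assert config.worker_use_ray            # forced on because world_size > 1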
42 changes: 39 additions & 3 deletions vllm/core/scheduler.py
@@ -1,4 +1,4 @@
from collections import deque
from collections import defaultdict, deque
import enum
import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set
@@ -11,6 +11,7 @@
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.prefix import PrefixPool
from vllm.utils import SeqToSlotMapper, coalesce_blocks_by_id

logger = init_logger(__name__)

@@ -38,6 +39,7 @@ def __init__(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
blocks_to_nw: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
) -> None:
self.scheduled_seq_groups = scheduled_seq_groups
@@ -46,6 +48,7 @@ def __init__(
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
self.blocks_to_copy = blocks_to_copy
self.blocks_to_nw = blocks_to_nw
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
@@ -57,7 +60,8 @@ def __init__(
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
and not self.blocks_to_swap_out and not self.blocks_to_copy
and not self.blocks_to_nw)

def _sort_by_lora_ids(self) -> bool:
self.scheduled_seq_groups = sorted(
@@ -77,13 +81,18 @@ def __init__(
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
track_prompt_blocks: bool = False,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
# simple and NOT fair. It can lead to starvation of some
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
self.track_prompt_blocks = track_prompt_blocks
self.seq_to_slot_mapper: Optional[SeqToSlotMapper] = None
if track_prompt_blocks:
self.seq_to_slot_mapper = SeqToSlotMapper()

self.prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
@@ -158,10 +167,11 @@ def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)

def _schedule(self) -> SchedulerOutputs:
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped, copied, or networked before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
blocks_to_nw: Dict[int, List[int]] = defaultdict(list)

# Fix the current time.
now = time.monotonic()
@@ -252,6 +262,15 @@ def _schedule(self) -> SchedulerOutputs:
self.running.append(seq_group)
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)
if self.track_prompt_blocks:
for seq in seq_group.get_seqs():
# Populate blocks_to_nw for the sequences in prompt phase
# and first step of generation phase
if seq.get_output_len() <= 1:
block_ids = self.block_manager.get_block_table(seq)
slot_id = self.seq_to_slot_mapper.get_slot_id(
seq.seq_id)
blocks_to_nw[slot_id].extend(block_ids)

self.waiting.extendleft(leftover_waiting_sequences)

@@ -264,6 +283,7 @@
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
blocks_to_nw=coalesce_blocks_by_id(blocks_to_nw),
ignored_seq_groups=ignored_seq_groups,
)
return scheduler_outputs
@@ -349,13 +369,25 @@ def _schedule(self) -> SchedulerOutputs:
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)

if self.track_prompt_blocks:
for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Populate blocks_to_nw for the sequences in prompt phase
# and first step of generation phase
if seq.get_output_len() <= 1:
block_ids = self.block_manager.get_block_table(seq)
slot_id = self.seq_to_slot_mapper.get_slot_id(
seq.seq_id)
blocks_to_nw[slot_id].extend(block_ids)

scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running,
prompt_run=False,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
blocks_to_nw=coalesce_blocks_by_id(blocks_to_nw),
ignored_seq_groups=[],
)
return scheduler_outputs
@@ -393,6 +425,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:

def free_seq(self, seq: Sequence) -> None:
self.block_manager.free(seq)
if self.seq_to_slot_mapper is not None:
self.seq_to_slot_mapper.free_seq(seq.seq_id)

def free_finished_seq_groups(self) -> None:
self.running = deque(seq_group for seq_group in self.running
@@ -402,6 +436,8 @@ def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
if self.seq_to_slot_mapper is not None:
self.seq_to_slot_mapper.set_seq(seq.seq_id)

def _append_slot(
self,
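The new blocks_to_nw structure maps a per-sequence slot id to the physical KV block ids that must be transferred from the prompt worker to the token worker (sequences in the prompt phase or in their first generation step). The slot ids come from SeqToSlotMapper in vllm/utils.py, whose implementation is not shown in this diff; a minimal sketch consistent with the set_seq / get_slot_id / free_seq interface used above might look like the following (assumptions: the lowest free slot is reused first, and the slot count is a hypothetical bound).

import heapq
from typing import Dict

class SeqToSlotMapperSketch:
    """Assigns each live sequence a small, reusable slot id (sketch only)."""

    def __init__(self, num_slots: int = 256) -> None:
        self.free_slots = list(range(num_slots))
        heapq.heapify(self.free_slots)
        self.seq_to_slot: Dict[int, int] = {}

    def set_seq(self, seq_id: int) -> None:
        # Called when a sequence is allocated (see _allocate above).
        self.seq_to_slot[seq_id] = heapq.heappop(self.free_slots)

    def get_slot_id(self, seq_id: int) -> int:
        # Called while building blocks_to_nw during scheduling.
        return self.seq_to_slot[seq_id]

    def free_seq(self, seq_id: int) -> None:
        # Called when the sequence finishes (see free_seq above).
        heapq.heappush(self.free_slots, self.seq_to_slot.pop(seq_id))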
8 changes: 7 additions & 1 deletion vllm/engine/arg_utils.py
@@ -24,6 +24,7 @@ class EngineArgs:
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
sep_prompt_token: bool = False
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
@@ -159,6 +160,10 @@
help='load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
'parallel and large models')
parser.add_argument(
'--sep-prompt-token',
action='store_true',
help='separate the prompt processing and token sampling')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
@@ -294,7 +299,8 @@ def create_engine_configs(
self.tensor_parallel_size,
self.worker_use_ray,
self.max_parallel_loading_workers,
self.disable_custom_all_reduce)
self.disable_custom_all_reduce,
self.sep_prompt_token)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
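For completeness, the same option can also be enabled programmatically through EngineArgs, mirroring what the test above does via the CLI parser (a sketch; the model and tensor-parallel size are placeholders taken from the test):

from vllm import EngineArgs, LLMEngine

engine_args = EngineArgs(
    model="meta-llama/Llama-2-70b-hf",
    tensor_parallel_size=2,
    sep_prompt_token=True,  # new field added by this PR
)
engine = LLMEngine.from_engine_args(engine_args)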