Add Splitwise implementation to vLLM #2809
Changes from all commits: 0fef346, 8e23ffb, 5f7b9d5, 2fada96, e910eb7, 4811486
@@ -0,0 +1,29 @@
#include <mscclpp/proxy_channel_device.hpp>

// Receive-side kernel: a single thread blocks on the proxy channel until
// the sender signals that the KV cache transfer has completed.
extern "C" __global__ void __launch_bounds__(1024, 1)
    nw_cache_in_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel) {
  int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
  if (globalIndex == 0) {
    proxyChannel[0].wait(100000000);
  }
}

// Send-side kernel: a single thread issues a put of dataSize bytes from
// src_mem to dst_mem at kv_block_offset, optionally flushing the channel.
extern "C" __global__ void __launch_bounds__(1024, 1)
    nw_cache_out_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel, int dst_mem, int src_mem,
                        int kv_block_offset, int dataSize, int flush) {
  int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
  if (globalIndex == 0) {
    proxyChannel[0].put(dst_mem, kv_block_offset, src_mem, kv_block_offset, dataSize);
    if (flush) {
      proxyChannel[0].flush();
    }
  }
}

// Send-side kernel: signals the receiver that all puts have been issued and
// flushes any outstanding proxy requests.
extern "C" __global__ void __launch_bounds__(1024, 1)
    nw_cache_out_signal_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel) {
  int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
  if (globalIndex == 0) {
    proxyChannel[0].signal();
    proxyChannel[0].flush();
  }
}
@@ -0,0 +1,16 @@
.. _getting_started:

Getting Started with Splitwise
==============================

`Splitwise <https://www.microsoft.com/en-us/research/publication/splitwise-efficient-generative-llm-inference-using-phase-splitting/>`_ is a technique that splits the two phases of an LLM inference request - prompt processing and token generation - onto separate machines for efficient inference.

Installing MSCCL++
------------------

Please follow the :ref:`MSCCL++ installation instructions <installing_mscclpp>` to install the MSCCL++ communication library, which is used to transfer KV caches from prompt workers to token workers.

Running inference with Splitwise
--------------------------------

Simply add the ``--sep-prompt-token`` flag to the vLLM command to use Splitwise.
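For example, to start the vLLM API server with Splitwise enabled (a minimal sketch; the entrypoint, model name, and parallel size below are illustrative):

.. code-block:: console

    $ python -m vllm.entrypoints.api_server --model meta-llama/Llama-2-70b-hf --tensor-parallel-size 2 --sep-prompt-token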
@@ -0,0 +1,18 @@
.. _installing_mscclpp:

Installing MSCCL++
==================

`MSCCL++ <https://github.com/microsoft/mscclpp>`_ is a GPU-driven communication stack for scalable AI applications.
It is used to implement KV cache communication in Splitwise.

To install MSCCL++, please follow the instructions at `MSCCL++ Quickstart <https://github.com/microsoft/mscclpp/blob/main/docs/quickstart.md>`_ or follow the steps below to install it from source.
MSCCL++ requires libnuma, which can be installed with ``apt install libnuma-dev`` on Debian-based systems.

.. code-block:: console

    $ git clone https://github.com/microsoft/mscclpp
    $ mkdir mscclpp/build; cd mscclpp/build; cmake -DCMAKE_BUILD_TYPE=Release ..; make -j
    $ conda install -c conda-forge mpi4py
    $ cd ../python; pip install -r requirements_c12.txt
    $ cd ..; pip install -e .
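As an optional sanity check (an assumption about the resulting install, not part of the official steps), the Python bindings should be importable afterwards:

.. code-block:: console

    $ python -c "import mscclpp"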
@@ -12,3 +12,4 @@ pydantic >= 2.0  # Required for OpenAI server.
aioprometheus[starlette]
pynvml == 11.5.0
triton >= 2.1.0
gputil
@@ -0,0 +1,42 @@
"""Test the KV cache communication operators.

Run `python test_kvcache_comm.py`.
"""
import argparse
import ray

from vllm import EngineArgs, LLMEngine


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
    """Initialize the LLMEngine from the command line arguments."""
    engine_args = EngineArgs.from_cli_args(args)
    return LLMEngine.from_engine_args(engine_args)


def run_all_workers(engine: LLMEngine, method: str, *args):
    """Run `method` on the driver worker and on all Ray workers."""
    ray_worker_outputs = [
        worker.execute_method.remote(method, *args)
        for worker in engine.workers
    ]
    _ = getattr(engine.driver_worker, method)(*args)
    ray.get(ray_worker_outputs)


if __name__ == '__main__':
    # Test the KV cache communication.
    parser = argparse.ArgumentParser(
        description='Demo on using the LLMEngine class directly')
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    args.model = "meta-llama/Llama-2-70b-hf"
    args.tensor_parallel_size = 2
    args.sep_prompt_token = True
    engine = initialize_engine(args)

    # Fill the GPU KV caches, exchange them between prompt and token workers,
    # and verify the received contents.
    run_all_workers(engine, "set_gpucache")
    run_all_workers(engine, "send_recv_kvcache_all")
    run_all_workers(engine, "check_gpucache")

    engine.destroy_kvcache_comm()
@@ -369,14 +369,22 @@ def __init__(
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        sep_prompt_token: bool = False,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.sep_prompt_token = sep_prompt_token

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if sep_prompt_token:
            # Half of the workers are prompt workers and the other half are
            # token workers.
            self.num_prompt_workers = self.world_size
            self.num_token_workers = self.world_size
            self.world_size = self.num_prompt_workers + self.num_token_workers

        if self.world_size > 1:
            self.worker_use_ray = True
        self._verify_args()

Review comment (on `sep_prompt_token`): Modification on config looks good to me. I guess there may be a future extension to pass in the number of prompt / token workers, and I think the abstraction looks good so far.

Review comment (on the prompt/token split): Should it be a fixed split? If the workers use the exact same GPUs, it seems the prompt phase may need more workers?
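As a worked example of the sizing logic above (a standalone Python sketch with illustrative values, not the actual vLLM class):

# Mirrors the ParallelConfig sizing above: enabling sep_prompt_token doubles
# the world size, giving one full copy of the parallel group to each phase.
pipeline_parallel_size = 1
tensor_parallel_size = 2
sep_prompt_token = True

world_size = pipeline_parallel_size * tensor_parallel_size  # 2

if sep_prompt_token:
    num_prompt_workers = world_size   # 2 prompt workers
    num_token_workers = world_size    # 2 token workers
    world_size = num_prompt_workers + num_token_workers

print(world_size)  # 4 workers (GPUs) in total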
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a heads that that the system may need to
apt install libnuma-dev
(libnuma1
) priormake
(I hit this error at installation).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the heads up. I will add this to the installation instructions both here and in the MSCCL++ repo. Also, feel free to reach out to me if you have any other problems with MSCCL++ setup.