diff --git a/csrc/kv_comm_kernels.cu b/csrc/kv_comm_kernels.cu
new file mode 100644
index 0000000000000..da9b69d3d2b2e
--- /dev/null
+++ b/csrc/kv_comm_kernels.cu
@@ -0,0 +1,29 @@
+#include <mscclpp/proxy_channel_device.hpp>
+
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    nw_cache_in_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel) {
+  int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  if (globalIndex == 0) {
+    proxyChannel[0].wait(100000000);
+  }
+}
+
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    nw_cache_out_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel, int dst_mem, int src_mem, int kv_block_offset, int dataSize, int flush) {
+  int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  if (globalIndex == 0) {
+    proxyChannel[0].put(dst_mem, kv_block_offset, src_mem, kv_block_offset, dataSize);
+    if (flush) {
+      proxyChannel[0].flush();
+    }
+  }
+}
+
+extern "C" __global__ void __launch_bounds__(1024, 1)
+    nw_cache_out_signal_kernel(mscclpp::ProxyChannelDeviceHandle* proxyChannel) {
+  int globalIndex = blockIdx.x * blockDim.x + threadIdx.x;
+  if (globalIndex == 0) {
+    proxyChannel[0].signal();
+    proxyChannel[0].flush();
+  }
+}
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 3e2331907f0f2..683a4b3a2c344 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -89,6 +89,13 @@ Documentation
 
    quantization/auto_awq
 
+.. toctree::
+   :maxdepth: 1
+   :caption: Splitwise
+
+   splitwise/getting_started
+   splitwise/installing_mscclpp
+
 .. toctree::
    :maxdepth: 2
    :caption: Developer Documentation
diff --git a/docs/source/splitwise/getting_started.rst b/docs/source/splitwise/getting_started.rst
new file mode 100644
index 0000000000000..9de34b3ed3271
--- /dev/null
+++ b/docs/source/splitwise/getting_started.rst
@@ -0,0 +1,16 @@
+.. _getting_started:
+
+Getting Started with Splitwise
+==============================
+
+`Splitwise `_ is a technique that splits the two phases of an LLM inference request - prompt processing and token generation - onto separate machines for efficient inference.
+
+Installing MSCCL++
+-------------------------
+
+Please follow the :ref:`MSCCL++ installation instructions <installing_mscclpp>` to install the MSCCL++ communication library, which is used to transfer KV caches from prompt workers to token workers.
+
+Running inference with Splitwise
+--------------------------------
+
+Simply add the ``--sep-prompt-token`` flag to the vLLM command to use Splitwise.
\ No newline at end of file
diff --git a/docs/source/splitwise/installing_mscclpp.rst b/docs/source/splitwise/installing_mscclpp.rst
new file mode 100644
index 0000000000000..bc329af96c869
--- /dev/null
+++ b/docs/source/splitwise/installing_mscclpp.rst
@@ -0,0 +1,18 @@
+.. _installing_mscclpp:
+
+Installing MSCCL++
+============================
+
+`MSCCL++ `_ is a GPU-driven communication stack for scalable AI applications.
+It is used to implement KV cache communication in Splitwise.
+
+To install MSCCL++, please follow the instructions at the `MSCCL++ Quickstart `_ or use the steps below to install it from source.
+MSCCL++ requires libnuma, which can be installed with ``apt install libnuma-dev`` on Debian-based systems.
+
+.. code-block:: console
+
+    $ git clone https://github.com/microsoft/mscclpp;
+    $ mkdir mscclpp/build; cd mscclpp/build; cmake -DCMAKE_BUILD_TYPE=Release ..; make -j;
+    $ conda install -c conda-forge mpi4py
+    $ cd ../python; pip install -r requirements_c12.txt;
+    $ cd ..; pip install -e .
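The ``--sep-prompt-token`` flag described in the docs above maps to ``EngineArgs.sep_prompt_token``, which this change adds. Below is a minimal sketch of driving the engine directly with Splitwise enabled, mirroring ``tests/distributed/test_kvcache_comm.py`` from this change; the model name, tensor-parallel size, and prompt are placeholders, not values prescribed by the patch.

# Minimal sketch (not part of the patch): run vLLM with Splitwise enabled.
# The model name, tensor_parallel_size, and prompt below are placeholders.
from vllm import EngineArgs, LLMEngine, SamplingParams

engine_args = EngineArgs(
    model="meta-llama/Llama-2-70b-hf",  # placeholder model
    tensor_parallel_size=2,             # doubled internally: 2 prompt + 2 token workers
    sep_prompt_token=True,              # equivalent to passing --sep-prompt-token
)
engine = LLMEngine.from_engine_args(engine_args)

engine.add_request("0", "Hello, my name is", SamplingParams(max_tokens=16))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)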
diff --git a/requirements.txt b/requirements.txt index 5684b2c29634d..62f9d167457a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server. aioprometheus[starlette] pynvml == 11.5.0 triton >= 2.1.0 +gputil diff --git a/tests/distributed/test_kvcache_comm.py b/tests/distributed/test_kvcache_comm.py new file mode 100644 index 0000000000000..085b279857870 --- /dev/null +++ b/tests/distributed/test_kvcache_comm.py @@ -0,0 +1,42 @@ +"""Test the KV cache communication operators. + +Run `python test_kvcache_comm.py`. +""" +import argparse +import ray + +from vllm import EngineArgs, LLMEngine + + +def initialize_engine(args: argparse.Namespace) -> LLMEngine: + """Initialize the LLMEngine from the command line arguments.""" + engine_args = EngineArgs.from_cli_args(args) + return LLMEngine.from_engine_args(engine_args) + + +def run_all_workers(engine: LLMEngine, method: str, *args): + """Run all the workers.""" + ray_worker_outputs = [ + worker.execute_method.remote(method, *args) + for worker in engine.workers + ] + _ = getattr(engine.driver_worker, method)(*args) + ray.get(ray_worker_outputs) + + +"""Test the kv cache communication.""" +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Demo on using the LLMEngine class directly') + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + args.model = "meta-llama/Llama-2-70b-hf" + args.tensor_parallel_size = 2 + args.sep_prompt_token = True + engine = initialize_engine(args) + + run_all_workers(engine, "set_gpucache") + run_all_workers(engine, "send_recv_kvcache_all") + run_all_workers(engine, "check_gpucache") + + engine.destroy_kvcache_comm() diff --git a/vllm/config.py b/vllm/config.py index 1dfc0d63c8813..a3058df434df2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -369,14 +369,22 @@ def __init__( worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, + sep_prompt_token: bool = False, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce + self.sep_prompt_token = sep_prompt_token self.world_size = pipeline_parallel_size * tensor_parallel_size + if sep_prompt_token: + # Half of the workers are prompt workers and the other half are token + self.num_prompt_workers = self.world_size + self.num_token_workers = self.world_size + self.world_size = self.num_prompt_workers + self.num_token_workers + if self.world_size > 1: self.worker_use_ray = True self._verify_args() diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4fdf9ec341cfd..4a125acba196c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,4 +1,4 @@ -from collections import deque +from collections import defaultdict, deque import enum import time from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set @@ -11,6 +11,7 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.prefix import PrefixPool +from vllm.utils import SeqToSlotMapper, coalesce_blocks_by_id logger = init_logger(__name__) @@ -38,6 +39,7 @@ def __init__( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], + blocks_to_nw: Dict[int, List[int]], ignored_seq_groups: 
List[SequenceGroup], ) -> None: self.scheduled_seq_groups = scheduled_seq_groups @@ -46,6 +48,7 @@ def __init__( self.blocks_to_swap_in = blocks_to_swap_in self.blocks_to_swap_out = blocks_to_swap_out self.blocks_to_copy = blocks_to_copy + self.blocks_to_nw = blocks_to_nw # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups @@ -57,7 +60,8 @@ def __init__( def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in - and not self.blocks_to_swap_out and not self.blocks_to_copy) + and not self.blocks_to_swap_out and not self.blocks_to_copy + and not self.blocks_to_nw) def _sort_by_lora_ids(self) -> bool: self.scheduled_seq_groups = sorted( @@ -77,6 +81,7 @@ def __init__( scheduler_config: SchedulerConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + track_prompt_blocks: bool = False, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -84,6 +89,10 @@ def __init__( # simple and NOT fair. It can lead to starvation of some # LoRAs. This should be improved in the future. self.lora_config = lora_config + self.track_prompt_blocks = track_prompt_blocks + self.seq_to_slot_mapper: Optional[SeqToSlotMapper] = None + if track_prompt_blocks: + self.seq_to_slot_mapper = SeqToSlotMapper() self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -158,10 +167,11 @@ def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) def _schedule(self) -> SchedulerOutputs: - # Blocks that need to be swaped or copied before model execution. + # Blocks that need to be swaped, copied, or networked before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_nw: Dict[int, List[int]] = defaultdict(list) # Fix the current time. 
now = time.monotonic() @@ -252,6 +262,15 @@ def _schedule(self) -> SchedulerOutputs: self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append(seq_group) + if self.track_prompt_blocks: + for seq in seq_group.get_seqs(): + # Populate blocks_to_nw for the sequences in prompt phase + # and first step of generation phase + if seq.get_output_len() <= 1: + block_ids = self.block_manager.get_block_table(seq) + slot_id = self.seq_to_slot_mapper.get_slot_id( + seq.seq_id) + blocks_to_nw[slot_id].extend(block_ids) self.waiting.extendleft(leftover_waiting_sequences) @@ -264,6 +283,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + blocks_to_nw=coalesce_blocks_by_id(blocks_to_nw), ignored_seq_groups=ignored_seq_groups, ) return scheduler_outputs @@ -349,6 +369,17 @@ def _schedule(self) -> SchedulerOutputs: seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running) + if self.track_prompt_blocks: + for seq_group in self.running: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + # Populate blocks_to_nw for the sequences in prompt phase + # and first step of generation phase + if seq.get_output_len() <= 1: + block_ids = self.block_manager.get_block_table(seq) + slot_id = self.seq_to_slot_mapper.get_slot_id( + seq.seq_id) + blocks_to_nw[slot_id].extend(block_ids) + scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=self.running, prompt_run=False, @@ -356,6 +387,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + blocks_to_nw=coalesce_blocks_by_id(blocks_to_nw), ignored_seq_groups=[], ) return scheduler_outputs @@ -393,6 +425,8 @@ def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) + if self.seq_to_slot_mapper is not None: + self.seq_to_slot_mapper.free_seq(seq.seq_id) def free_finished_seq_groups(self) -> None: self.running = deque(seq_group for seq_group in self.running @@ -402,6 +436,8 @@ def _allocate(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING + if self.seq_to_slot_mapper is not None: + self.seq_to_slot_mapper.set_seq(seq.seq_id) def _append_slot( self, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d5e63e25d6e85..6d0f733f97ffb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -24,6 +24,7 @@ class EngineArgs: pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None + sep_prompt_token: bool = False block_size: int = 16 swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 @@ -159,6 +160,10 @@ def add_cli_args( help='load model sequentially in multiple batches, ' 'to avoid RAM OOM when using tensor ' 'parallel and large models') + parser.add_argument( + '--sep-prompt-token', + action='store_true', + help='separate the prompt processing and token sampling') # KV cache arguments parser.add_argument('--block-size', type=int, @@ -294,7 +299,8 @@ def create_engine_configs( self.tensor_parallel_size, self.worker_use_ray, self.max_parallel_loading_workers, - self.disable_custom_all_reduce) + self.disable_custom_all_reduce, + self.sep_prompt_token) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, 
self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 02c673c96fd9a..578a3707037c3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -116,8 +116,13 @@ def __init__( # Profile the memory usage and initialize the cache. self._init_cache() + if self.parallel_config.sep_prompt_token: + # Setup the MSCCL++ communication required for KV cache transfer + self._setup_kvcache_comm() + # Create the scheduler. - self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config, + self.parallel_config.sep_prompt_token) # Metric Logging. if self.log_stats: @@ -210,6 +215,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_node_and_gpu_ids = ray.get( [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + driver_node_workers = [] + other_node_workers = [] + driver_worker_node_and_gpu_ids = [] + other_worker_node_and_gpu_ids = [] + + for worker, (node_id, gpu_ids) in zip(self.workers, + worker_node_and_gpu_ids): + if node_id == driver_node_id: + driver_node_workers.append(worker) + driver_worker_node_and_gpu_ids.append((node_id, gpu_ids)) + else: + other_node_workers.append(worker) + other_worker_node_and_gpu_ids.append((node_id, gpu_ids)) + self.workers = driver_node_workers + other_node_workers + worker_node_and_gpu_ids = driver_worker_node_and_gpu_ids + other_worker_node_and_gpu_ids + node_workers = defaultdict(list) node_gpus = defaultdict(list) @@ -229,6 +250,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) + mscclpp_init_method = f"eth0:{driver_ip}:{get_open_port()}" if self.parallel_config.sep_prompt_token else None # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -256,6 +278,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, + mscclpp_init_method=mscclpp_init_method, )) driver_rank = 0 @@ -270,6 +293,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, + mscclpp_init_method=mscclpp_init_method, is_driver_worker=True, ) @@ -349,6 +373,14 @@ def _init_cache(self) -> None: # if enforce_eager is False. self._run_workers("warm_up_model") + def _setup_kvcache_comm(self) -> None: + """Setup MSCCL++ communication connections for KV cache transfer.""" + self._run_workers("setup_kvcache_comm") + + def destroy_kvcache_comm(self) -> None: + """Stop MSCCL++ communication connections for KV cache transfer.""" + self._run_workers("destroy_kvcache_comm") + @classmethod def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" @@ -797,7 +829,25 @@ def step(self) -> List[RequestOutput]: """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - if not scheduler_outputs.is_empty(): + if scheduler_outputs.is_empty(): + output = [] + elif self.parallel_config.sep_prompt_token: + # TODO: This will only schedule one set of workers at a time and will not be + # able to take advantage of parallely running prompt and token workers. 
+ all_outputs = self._run_stage_workers( + "execute_model", + prompt_stage=seq_group_metadata_list[0].is_prompt, + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + "blocks_to_nw": scheduler_outputs.blocks_to_nw, + }) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + else: # Execute the model. all_outputs = self._run_workers( "execute_model", @@ -806,12 +856,11 @@ def step(self) -> List[RequestOutput]: "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, "blocks_to_copy": scheduler_outputs.blocks_to_copy, + "blocks_to_nw": {}, }) # Only the driver worker returns the sampling results. output = all_outputs[0] - else: - output = [] return self._process_model_outputs(output, scheduler_outputs) @@ -994,3 +1043,60 @@ def _run_workers( ray_worker_outputs = ray.get(ray_worker_outputs) return [driver_worker_output] + ray_worker_outputs + + def _run_stage_workers( + self, + method: str, + prompt_stage: bool, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on prompt workers or token workers.""" + + assert self.parallel_config.sep_prompt_token + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + if prompt_stage: + # Prompt workers include 1 driver worker and num_prompt_workers-1 ray workers. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in + self.workers[:self.parallel_config.num_prompt_workers - 1] + ] + + # Start the driver worker after all the ray workers. + driver_worker_output = getattr(self.driver_worker, + method)(*driver_args, + **driver_kwargs) + + else: + # Token workers use worker[num_prompt_workers-1] as driver worker. + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in + self.workers[self.parallel_config.num_prompt_workers:] + ] + + # Start the token driver worker after all the ray workers. + driver_worker = self.workers[ + self.parallel_config.num_prompt_workers - 1] + driver_worker_output = driver_worker.execute_method.remote( + method, *driver_args, **driver_kwargs) + driver_worker_output = ray.get(driver_worker_output) + + # Get the results of the ray workers. 
+ if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e1aac20b038b4..6bdc249eff7db 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -794,6 +794,7 @@ def __init__( self.hidden_size = hidden_size self.dtype = dtype self.device = device + self.dst_rank = 0 @property def vocab_size(self): @@ -897,7 +898,7 @@ def _get_logits( logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_gather(logits) + logits = tensor_model_parallel_gather(logits, dst=self.dst_rank) if logits is None: return None @@ -938,6 +939,9 @@ def _get_logits( return logits + def set_dst_rank(self, dst_rank: int) -> None: + self.dst_rank = dst_rank + def forward(self, *args, **kwargs): return type(self.base_layer).forward(self, *args, **kwargs) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2ce9d60f08d80..38914504b4e37 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -58,6 +58,11 @@ def __init__( raise ValueError(f"head_size ({self.head_size}) is not supported. " f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + self.kvcache_comm_manager = None + + def set_kvcache_comm_manager(self, kvcache_comm_manager): + self.kvcache_comm_manager = kvcache_comm_manager + def forward( self, query: torch.Tensor, @@ -101,6 +106,12 @@ def forward( input_metadata.kv_cache_dtype, ) + if input_metadata.is_prompt and len(input_metadata.blocks_to_nw): + assert self.kvcache_comm_manager is not None + for semid in input_metadata.blocks_to_nw: + for block_start, num_blocks in input_metadata.blocks_to_nw[semid]: + self.kvcache_comm_manager.put(semid, self.layer_id, block_start, num_blocks) + if input_metadata.is_prompt: # Prompt run. if self.num_kv_heads != self.num_heads: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index bc86a916b5bbf..598dbc6b95c27 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -34,19 +34,26 @@ def __init__(self, self.vocab_size = vocab_size # original vocabulary size (without LoRA). self.org_vocab_size = org_vocab_size or vocab_size + self.dst_rank = 0 - def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + def _get_logits(self, + hidden_states: torch.Tensor, + embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor], + dst_rank: int = 0) -> torch.Tensor: # Get the logits for the next tokens. logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_gather(logits) + logits = tensor_model_parallel_gather(logits, dst=dst_rank) # Remove paddings in vocab (if any). if logits is not None: logits = logits[:, :self.org_vocab_size] return logits + def set_dst_rank(self, dst_rank: int) -> None: + self.dst_rank = dst_rank + def forward( self, embedding: torch.Tensor, @@ -58,7 +65,10 @@ def forward( hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, embedding, embedding_bias) + logits = self._get_logits(hidden_states, + embedding, + embedding_bias, + dst_rank=self.dst_rank) # Only perform sampling in the driver worker. 
# Note: `_get_logits` is still distributed across TP workers because @@ -352,6 +362,8 @@ def _multinomial( probs = probs[:, None, :].expand(probs.shape[0], num_samples, probs.shape[1]).contiguous().view( -1, probs.shape[1]) + # torch.Tensor.exponential_ gives different results on different devices + # resulting in sampled outputs being different q = torch.empty_like(probs).exponential_(1) return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 65671994f3309..0808f8d84b68b 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -6,7 +6,6 @@ import torch from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tensor_model_parallel_group, ) @@ -82,7 +81,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor, # Convert negative dim to positive. dim += input_.dim() # Allocate output tensor. - if get_tensor_model_parallel_rank() == dst: + if torch.distributed.get_rank() == dst: gather_list = [torch.empty_like(input_) for _ in range(world_size)] else: gather_list = None @@ -91,7 +90,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor, gather_list, dst=dst, group=get_tensor_model_parallel_group()) - if get_tensor_model_parallel_rank() == dst: + if torch.distributed.get_rank() == dst: output_tensor = torch.cat(gather_list, dim=dim) else: output_tensor = None @@ -171,7 +170,7 @@ def broadcast_tensor_dict( for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src) + torch.distributed.broadcast(tensor, src=src, group=group) else: recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py index 628c151761fb2..fde72d9411ffe 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -6,7 +6,8 @@ from vllm.logger import init_logger from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) + get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, + get_tensor_model_parallel_group) try: from vllm._C import custom_ar @@ -180,7 +181,9 @@ def _get_ipc_meta(self, inp: torch.Tensor): def _gather_ipc_meta(self, shard_data): all_data = [None] * self.world_size - dist.all_gather_object(all_data, shard_data) + dist.all_gather_object(all_data, + shard_data, + group=get_tensor_model_parallel_group()) handles = [] offsets = [] diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index 59cc196538571..f44ed5cf2272e 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -10,6 +10,8 @@ _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None +# Stage parallel group that the current rank belongs to. +_STAGE_PARALLEL_GROUP = None # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. 
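To make the stage-group split above concrete: when ``sep_prompt_token`` is set, ``ParallelConfig`` doubles the world size, the first half of the ranks become prompt workers and the second half token workers, and each stage gets its own driver. The helper below is purely illustrative (it is not part of the patch) and assumes ``pipeline_parallel_size=1``; the formulas mirror ``ParallelConfig.__init__`` and ``Worker.init_kvcache_comm`` elsewhere in this diff.

# Illustration only (not part of the patch): how ranks split into prompt and
# token stages when sep_prompt_token doubles the world size, assuming
# pipeline_parallel_size=1.
def describe_rank(rank: int, tensor_parallel_size: int):
    num_prompt_workers = tensor_parallel_size      # first half of the ranks
    world_size = 2 * num_prompt_workers            # doubled by ParallelConfig
    worker_type = "PROMPT" if rank < num_prompt_workers else "TOKEN"
    # Same formula as Worker.init_kvcache_comm: each stage's driver is the
    # first rank of that stage (0 for prompt, num_prompt_workers for token).
    driver_rank = (rank // num_prompt_workers) * num_prompt_workers
    # Corresponding worker on the other stage, used for KV cache transfer.
    corr_worker_rank = (rank + num_prompt_workers) % world_size
    return worker_type, driver_rank, corr_worker_rank

for r in range(4):
    print(r, describe_rank(r, tensor_parallel_size=2))
# 0 ('PROMPT', 0, 2)
# 1 ('PROMPT', 0, 3)
# 2 ('TOKEN', 2, 0)
# 3 ('TOKEN', 2, 1)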
@@ -19,6 +21,7 @@ def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + sep_prompt_token: bool = False, ) -> None: """ Initialize model parallel groups. @@ -45,13 +48,16 @@ def initialize_model_parallel( # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() + # World size is scaled by two in case of separate prompt token machines. + scale_factor: int = 2 if sep_prompt_token else 1 - if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size): + if (world_size != tensor_model_parallel_size * + pipeline_model_parallel_size * scale_factor): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + f"pipeline_model_parallel_size ({pipeline_model_parallel_size}) x " + f"scale_factor ({scale_factor})") num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size) @@ -82,10 +88,24 @@ def initialize_model_parallel( _PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_GLOBAL_RANKS = ranks + if sep_prompt_token: + global _STAGE_PARALLEL_GROUP + if num_tensor_model_parallel_groups == 2: + _STAGE_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP + else: + prompt_group = torch.distributed.new_group(range(world_size // 2)) + token_group = torch.distributed.new_group( + range(world_size // 2, world_size)) + if rank < world_size // 2: + _STAGE_PARALLEL_GROUP = prompt_group + else: + _STAGE_PARALLEL_GROUP = token_group + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + sep_prompt_token: bool = False, ) -> None: """Helper to initialize model parallel groups if they are not initialized, or ensure tensor-parallel and pipeline-parallel sizes are equal to expected @@ -93,7 +113,8 @@ def ensure_model_parallel_initialized( """ if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size) + pipeline_model_parallel_size, + sep_prompt_token) return assert ( @@ -128,6 +149,12 @@ def get_pipeline_model_parallel_group(): return _PIPELINE_MODEL_PARALLEL_GROUP +def get_stage_parallel_group(): + """Get the stage parallel group the caller rank belongs to.""" + # _STAGE_PARALLEL_GROUP can be None (indicating no stage parallelism) + return _STAGE_PARALLEL_GROUP + + def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" return torch.distributed.get_world_size( @@ -140,6 +167,11 @@ def get_pipeline_model_parallel_world_size(): group=get_pipeline_model_parallel_group()) +def get_stage_parallel_world_size(): + """Return world size for the stage parallel group.""" + return torch.distributed.get_world_size(group=get_stage_parallel_group()) + + def get_tensor_model_parallel_rank(): """Return my rank for the tensor model parallel group.""" return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) @@ -151,6 +183,11 @@ def get_pipeline_model_parallel_rank(): group=get_pipeline_model_parallel_group()) +def get_stage_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + return torch.distributed.get_rank(group=get_stage_parallel_group()) + + def get_tensor_model_parallel_src_rank(): """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" @@ -206,3 +243,7 @@ 
def destroy_model_parallel(): _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None + global _STAGE_PARALLEL_GROUP + if _STAGE_PARALLEL_GROUP: + torch.distributed.destroy_process_group(_STAGE_PARALLEL_GROUP) + _STAGE_PARALLEL_GROUP = None diff --git a/vllm/utils.py b/vllm/utils.py index 9e9126a2d6377..feb407f0f44ef 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,9 +4,10 @@ import subprocess import uuid from platform import uname -from typing import List, Tuple, Union +from typing import Dict, List, Tuple, Union from packaging.version import parse, Version +import GPUtil import psutil import torch import asyncio @@ -37,6 +38,12 @@ class Device(enum.Enum): CPU = enum.auto() +class WorkerType(enum.Enum): + PROMPT = enum.auto() + TOKEN = enum.auto() + MIXED = enum.auto() + + class Counter: def __init__(self, start: int = 0) -> None: @@ -114,6 +121,44 @@ def clear(self): self.cache.clear() +# Maximum number of sequences that can be in the system at a time +# Only used when sep_prompt_token is set in Parallel Config +# It determines the number of semaphores that will be used for KV cache transfer using MSCCL++ +# This can be changed based on need +MAX_SLOT_IDS = 256 + + +class SeqToSlotMapper: + """ SeqToSlotMapper maps sequence ids to slot ids that range from 0 to MAX_SLOT_IDS-1. + Occupied slots indicate that their associated sequence is currently being processed in + the system. A slot is used to index into the list of semaphores used to synchronize the + KV cache transfer using MSCCL++. Each sequence is mapped to a different slot/semaphore + in order to allow fine-grained synchronization. + + Slots are freed once their sequence finishes, allowing the same semaphores to be + used for new sequences as they arrive. + """ + + def __init__(self): + self.available_slotids = list(range(MAX_SLOT_IDS)) + self.seq_to_slot = {} + + def set_seq(self, seq_id): + try: + slot_id = self.available_slotids.pop(0) + except IndexError: + raise RuntimeError( + "No more slots available. Increase MAX_SLOT_IDS.") + self.seq_to_slot[seq_id] = slot_id + + def free_seq(self, seq_id): + slot_id = self.seq_to_slot.pop(seq_id) + self.available_slotids.insert(0, slot_id) + + def get_slot_id(self, seq_id): + return self.seq_to_slot[seq_id] + + def is_hip() -> bool: return torch.version.hip is not None @@ -276,3 +321,33 @@ def create_kv_caches_with_random( _generate_random_fp8_e5m2(value_cache, -scale, scale) value_caches.append(value_cache) return key_caches, value_caches + + +def get_total_num_gpus() -> int: + return len(GPUtil.getGPUs()) + + +def coalesce_blocks(block_list: List[int]): + '''Coalesce of list of blocks to exploit contiguous chunks. 
+ ''' + if not block_list: + return [] + sorted_block_list = sorted(block_list) + ret = [] + current_block_start = sorted_block_list[0] + current_block_length = 1 + for i in range(1, len(sorted_block_list)): + if sorted_block_list[i] == sorted_block_list[i - 1] + 1: + current_block_length += 1 + else: + ret.append((current_block_start, current_block_length)) + current_block_start = sorted_block_list[i] + current_block_length = 1 + ret.append((current_block_start, current_block_length)) + return ret + + +def coalesce_blocks_by_id(blocks_to_nw_dict: Dict[int, List[int]]): + for cur_id in blocks_to_nw_dict: + blocks_to_nw_dict[cur_id] = coalesce_blocks(blocks_to_nw_dict[cur_id]) + return blocks_to_nw_dict diff --git a/vllm/worker/comm_utils.py b/vllm/worker/comm_utils.py new file mode 100644 index 0000000000000..87fa766277720 --- /dev/null +++ b/vllm/worker/comm_utils.py @@ -0,0 +1,222 @@ +import cupy as cp +import os + +from vllm.utils import get_total_num_gpus, MAX_SLOT_IDS + +try: + import mscclpp.comm as mscclpp_comm + from mscclpp.utils import KernelBuilder, pack +except ImportError: + raise ImportError( + "MSCCL++ is not installed. Please install MSCCL++ to use this feature." + ) + +# Flush MSCCL++ fifo every 128 operations +FLUSH_COUNT = 128 + +HEAD_TYPES = [0, 1] # 0 for keys, 1 for values + +KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/../../csrc" + + +class SendKVKernel: + """ SendKVKernel is a wrapper around a CUDA kernel that uses + MSCCL++ proxy channels to asynchronously send key-value cache + """ + + def __init__(self): + self._kernel = KernelBuilder( + file="kv_comm_kernels.cu", + kernel_name="nw_cache_out_kernel", + file_dir=KERNEL_DIR).get_compiled_kernel() + self.nblocks = 1 + self.nthreads = 1 + + # nw_cache_out_kernel takes device handles, memory offset, memory size, + # and flush flag as parameters + def __call__(self, params): + return self._kernel.launch_kernel(params, + self.nblocks, + self.nthreads, + shared=0, + stream=None) + + +class SignalKVKernel: + """ SignalKVKernel is a wrapper around a CUDA kernel that signals + the semaphore associated with the MSCCL++ proxy channel + """ + + def __init__(self): + self._kernel = KernelBuilder( + file="kv_comm_kernels.cu", + kernel_name="nw_cache_out_signal_kernel", + file_dir=KERNEL_DIR).get_compiled_kernel() + self.nblocks = 1 + self.nthreads = 1 + + # nw_cache_out_signal_kernel takes device handles of proxy channels + # as parameters + def __call__(self, params): + return self._kernel.launch_kernel(params, + self.nblocks, + self.nthreads, + shared=0, + stream=None) + + +class WaitKVKernel: + """ WaitKVKernel is a wrapper around a CUDA kernel that waits on + the semaphore associated with the MSCCL++ proxy channel + """ + + def __init__(self): + self._kernel = KernelBuilder( + file="kv_comm_kernels.cu", + kernel_name="nw_cache_in_kernel", + file_dir=KERNEL_DIR).get_compiled_kernel() + self.nblocks = 1 + self.nthreads = 1 + + # nw_cache_in_kernel takes device handles of proxy channels as parameters + def __call__(self, params): + return self._kernel.launch_kernel(params, + self.nblocks, + self.nthreads, + shared=0, + stream=None) + + +class KVCacheCommunicator: + """ KVCacheCommunicator provides an interface to communicate the KV cache + between prompt and token workers using MSCCL++ proxy channels. 
+ + block_size: int - size of a single KV cache block + device_handles: dict - device handles of MSCCL++ proxy channels + flush_counter: int - counter to keep track of number of operations + memory_ids: dict - memory ids of KV cache on prompt and token workers + my_rank: int - rank of the prompt worker + remote_rank: int - rank of the token worker + + SendKVKernel and SignalKVKernel put KV cache data and signal semaphores on the prompt side + WaitKVKernel waits on semaphores on the token side. + """ + + def __init__(self, block_size, device_handles, memory_ids, my_rank, + remote_rank): + self.block_size = block_size + self.device_handles = device_handles + self.memory_ids = memory_ids + self.my_rank = my_rank + self.remote_rank = remote_rank + self.flush_counter = 0 + self.send_kernel = SendKVKernel() + self.signal_kernel = SignalKVKernel() + self.wait_kernel = WaitKVKernel() + + def get_device_handles(self, sem_ids): + device_handles = [self.device_handles[sem_id] for sem_id in sem_ids] + return cp.asarray(memoryview(b"".join(device_handles)), dtype=cp.uint8) + + def wait(self, sem_id): + dh = self.get_device_handles([sem_id]) + params = pack(dh) + self.wait_kernel(params) + + def signal_and_flush(self, sem_id): + dh = self.get_device_handles([sem_id]) + params = pack(dh) + self.signal_kernel(params) + self.flush_counter = 0 + + def put(self, sem_id, layer_id, block_start, num_blocks): + block_size = self.block_size + remote_rank = self.remote_rank + my_rank = self.my_rank + for head_type in HEAD_TYPES: + block_offset = block_start * block_size + dh = self.get_device_handles([sem_id]) + self.flush_counter += 1 + flush = self.flush_counter >= FLUSH_COUNT + if flush: + self.flush_counter = 0 + params = pack(dh, + self.memory_ids[layer_id][head_type][remote_rank], + self.memory_ids[layer_id][head_type][my_rank], + block_offset, block_size * num_blocks, flush) + self.send_kernel(params) + + +class KVCacheCommManager: + + def __init__(self, rank, world_size, num_prompt_workers, + mscclpp_init_method) -> None: + self.kvcache_comm = None + self.proxy_service = None + + # Initialize the MSCCL++ group. + self.mscclpp_group = mscclpp_comm.CommGroup( + rank=rank, + size=world_size, + interfaceIpPortTrio=mscclpp_init_method, + ) + + # Setup up connections. + self.corr_worker_rank = (rank + num_prompt_workers) % world_size + transport = self.mscclpp_group.my_ib_device(rank % + get_total_num_gpus()) + self.mscclpp_conns = self.mscclpp_group.make_connection( + [self.corr_worker_rank], transport) + + def setup_comm(self, num_layers, kv_cache) -> None: + # Set up proxy service and proxy channels for KV cache communication. 
+ self.proxy_service = mscclpp_comm.ProxyService() + self.proxy_service.start_proxy() + + # register KV cache memory with MSCCL++ proxy channel + memory_ids = [[None, None] for _ in range(num_layers)] + for layer_id in range(num_layers): + for head_type in HEAD_TYPES: + memory_ids[layer_id][ + head_type] = self.mscclpp_group.register_memory_with_proxy( + self.proxy_service, + kv_cache[layer_id][head_type], + self.mscclpp_conns, + ) + + # register semaphores with MSCCL++ proxy channel + # one for each sequence + proxy_channels = [None for _ in range(MAX_SLOT_IDS)] + device_handles = [None for _ in range(MAX_SLOT_IDS)] + for sem_id in range(MAX_SLOT_IDS): + proxy_channels[ + sem_id] = self.mscclpp_group.register_semaphore_with_proxy( + self.proxy_service, + self.mscclpp_conns, + )[self.corr_worker_rank] + device_handles[sem_id] = proxy_channels[sem_id].device_handle().raw + + all_blocks_size = (kv_cache[0][0].numel() * + kv_cache[0][0].element_size()) + block_size = all_blocks_size // kv_cache[0][0].size(0) + + # Set up KV cache communicator. + self.kvcache_comm = KVCacheCommunicator(block_size, device_handles, + memory_ids, + self.mscclpp_group.my_rank, + self.corr_worker_rank) + + def destroy_comm(self) -> None: + self.proxy_service.stop_proxy() + del self.proxy_service + del self.kvcache_comm + del self.mscclpp_group + + def wait(self, sem_id): + self.kvcache_comm.wait(sem_id) + + def signal_and_flush(self, sem_id): + self.kvcache_comm.signal_and_flush(sem_id) + + def put(self, sem_id, layer_id, block_start, num_blocks): + self.kvcache_comm.put(sem_id, layer_id, block_start, num_blocks) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fce0009e3097d..dcac770ef582a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -11,6 +11,7 @@ from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils import custom_all_reduce +from vllm.model_executor.parallel_utils.parallel_state import get_stage_parallel_group from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager @@ -45,6 +46,7 @@ def __init__( self.scheduler_config = scheduler_config self.lora_config = lora_config self.is_driver_worker = is_driver_worker + self.driver_rank = 0 # model_config can be None in tests/samplers/test_sampler.py. # FIXME(woosuk): This is a hack to make the tests work. Refactor this. @@ -446,8 +448,10 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + blocks_to_nw: Optional[Dict[int, List[int]]], ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, Set[int], LoRAMapping]: + stage_group = get_stage_parallel_group() if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
@@ -497,9 +501,12 @@ def prepare_input_tensors( "lora_requests": lora_requests, "lora_mapping": lora_mapping, } - broadcast_tensor_dict(metadata_dict, src=0) + broadcast_tensor_dict(metadata_dict, + src=self.driver_rank, + group=stage_group) else: - metadata_dict = broadcast_tensor_dict(src=0) + metadata_dict = broadcast_tensor_dict(src=self.driver_rank, + group=stage_group) input_tokens = metadata_dict["input_tokens"] input_positions = metadata_dict["input_positions"] lora_mapping = metadata_dict["lora_mapping"] @@ -525,6 +532,7 @@ def prepare_input_tensors( perform_sampling=False, ) + input_metadata.blocks_to_nw = blocks_to_nw return (input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping) @@ -533,10 +541,12 @@ def execute_model( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + blocks_to_nw: Dict[int, List[int]] = {}, ) -> Optional[SamplerOutput]: (input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, - lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) + lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list, + blocks_to_nw) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c97e82a55a1ee..5fe2986389cae 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -13,8 +13,9 @@ broadcast_tensor_dict) from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized) + ensure_model_parallel_initialized, get_stage_parallel_group) from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import WorkerType from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner from vllm.lora.request import LoRARequest @@ -39,6 +40,8 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, kv_cache_dtype: Optional[str] = "auto", + mscclpp_init_method: str = None, + worker_type: WorkerType = WorkerType.MIXED, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -49,6 +52,8 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.mscclpp_init_method = mscclpp_init_method + self.worker_type = worker_type self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -67,6 +72,17 @@ def __init__( self.cache_events = None self.gpu_cache = None + self.kvcache_comm_manager = None + + def is_prompt_worker(self) -> bool: + return self.worker_type == WorkerType.PROMPT + + def is_token_worker(self) -> bool: + return self.worker_type == WorkerType.TOKEN + + def is_mixed_worker(self) -> bool: + return self.worker_type == WorkerType.MIXED + def init_model(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until @@ -89,13 +105,71 @@ def init_model(self) -> None: # Initialize the distributed environment. init_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method) + self.init_kvcache_comm(self.mscclpp_init_method) if not self.parallel_config.disable_custom_all_reduce: init_custom_ar() + # Initialize the model. 
set_random_seed(self.model_config.seed) def load_model(self): self.model_runner.load_model() + if self.parallel_config.sep_prompt_token: + # Populate Sampler with dst_rank as driver worker's rank. + self.model_runner.model.sampler.set_dst_rank( + self.model_runner.driver_rank) + + def init_kvcache_comm(self, + mscclpp_init_method: Optional[str] = None) -> None: + if mscclpp_init_method is not None: + from vllm.worker.comm_utils import KVCacheCommManager + self.kvcache_comm_manager = KVCacheCommManager( + self.rank, self.parallel_config.world_size, + self.parallel_config.num_prompt_workers, mscclpp_init_method) + + self.worker_type = (WorkerType.PROMPT if self.rank < + self.parallel_config.num_prompt_workers else + WorkerType.TOKEN) + + # Set the driver worker rank for prompt and token workers. + self.model_runner.driver_rank = ( + self.rank // self.parallel_config.num_prompt_workers + ) * self.parallel_config.num_prompt_workers + if self.rank == self.model_runner.driver_rank: + self.is_driver_worker = True + self.model_runner.is_driver_worker = True + + def setup_kvcache_comm(self) -> None: + # Setup the communication for the KV cache. + if self.kvcache_comm_manager is not None: + num_layers = self.model_config.get_num_layers(self.parallel_config) + self.kvcache_comm_manager.setup_comm(num_layers, self.gpu_cache) + + # Populate the attention modules with the KV cache communicator. + self.set_comm_for_attention_modules() + + def destroy_kvcache_comm(self) -> None: + if self.kvcache_comm_manager is not None: + self.kvcache_comm_manager.destroy_comm() + self.unset_comm_for_attention_modules() + + def set_comm_for_attention_modules(self) -> None: + attention_modules = list( + filter( + lambda module: "PagedAttention" in module.__class__.__name__, + self.model_runner.model.modules())) + for i, attention_module in enumerate(attention_modules): + attention_module.set_kvcache_comm_manager( + self.kvcache_comm_manager) + attention_module.layer_id = i + + def unset_comm_for_attention_modules(self) -> None: + attention_modules = list( + filter( + lambda module: "PagedAttention" in module.__class__.__name__, + self.model_runner.model.modules())) + for attention_module in attention_modules: + del attention_module.kvcache_comm_manager @torch.inference_mode() def profile_num_available_blocks( @@ -189,26 +263,38 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, + blocks_to_nw: Optional[Dict[int, List[int]]] = None, ) -> Optional[SamplerOutput]: + is_prompt = False if self.is_driver_worker: assert seq_group_metadata_list is not None + is_prompt = seq_group_metadata_list[0].is_prompt + if self.is_driver_worker and self.should_execute(is_prompt): num_seq_groups = len(seq_group_metadata_list) assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None assert blocks_to_copy is not None + assert blocks_to_nw is not None data = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_out": blocks_to_swap_out, "blocks_to_copy": blocks_to_copy, + "blocks_to_nw": blocks_to_nw, + "is_prompt": is_prompt, } - broadcast_tensor_dict(data, src=0) + broadcast_tensor_dict(data, + src=self.model_runner.driver_rank, + group=get_stage_parallel_group()) else: - data = broadcast_tensor_dict(src=0) + data = broadcast_tensor_dict(src=self.model_runner.driver_rank, + group=get_stage_parallel_group()) num_seq_groups = data["num_seq_groups"] 
blocks_to_swap_in = data["blocks_to_swap_in"] blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] + blocks_to_nw = data["blocks_to_nw"] + is_prompt = data["is_prompt"] self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) @@ -216,8 +302,16 @@ def execute_model( if num_seq_groups == 0: return {} + if len(blocks_to_nw) and self.is_token_worker() and not is_prompt: + for sem_id in blocks_to_nw: + self.kvcache_comm_manager.wait(sem_id) + output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache) + self.gpu_cache, blocks_to_nw) + + if len(blocks_to_nw) and self.is_prompt_worker() and is_prompt: + for sem_id in blocks_to_nw: + self.kvcache_comm_manager.signal_and_flush(sem_id) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -229,6 +323,49 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.model_runner.list_loras() + def should_execute(self, is_prompt: bool) -> bool: + return self.is_mixed_worker() or ( + self.is_prompt_worker() and is_prompt) or (self.is_token_worker() + and not is_prompt) + + def set_gpucache(self): + from vllm.worker.comm_utils import HEAD_TYPES + num_layers = self.model_config.get_num_layers(self.parallel_config) + for layer_id in range(num_layers): + for head_type in HEAD_TYPES: + self.gpu_cache[layer_id][head_type][:] = self.rank * ( + num_layers * + len(HEAD_TYPES)) + layer_id * len(HEAD_TYPES) + head_type + torch.cuda.synchronize() + + def send_recv_kvcache_all(self): + if self.kvcache_comm_manager is not None: + num_gpu_blocks = self.cache_config.num_gpu_blocks + num_layers = self.model_config.get_num_layers(self.parallel_config) + if self.rank < self.parallel_config.num_prompt_workers: + for layer_id in range(num_layers): + self.kvcache_comm_manager.put(0, layer_id, 0, + num_gpu_blocks) + self.kvcache_comm_manager.signal_and_flush(0) + else: + self.kvcache_comm_manager.wait(0) + torch.cuda.synchronize() + + def check_gpucache(self): + if self.kvcache_comm_manager is not None: + from vllm.worker.comm_utils import HEAD_TYPES + num_prompt_workers = self.parallel_config.num_prompt_workers + num_layers = self.model_config.get_num_layers(self.parallel_config) + expected_worker_id = self.rank if self.rank < num_prompt_workers else self.rank - num_prompt_workers + for layer_id in range(num_layers): + for head_type in HEAD_TYPES: + expected_scalar = expected_worker_id * (num_layers * len( + HEAD_TYPES)) + layer_id * len(HEAD_TYPES) + head_type + expected_tensor = torch.ones_like( + self.gpu_cache[layer_id][head_type]) * expected_scalar + assert torch.allclose(self.gpu_cache[layer_id][head_type], + expected_tensor) + def init_distributed_environment( parallel_config: ParallelConfig, @@ -258,7 +395,8 @@ def init_distributed_environment( # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + parallel_config.pipeline_parallel_size, + parallel_config.sep_prompt_token) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
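Tying the pieces together, the sketch below illustrates how a prompt worker's ``blocks_to_nw`` entries drive the KV cache transfer: the scheduler collects block ids per sequence slot, ``coalesce_blocks_by_id`` turns them into contiguous runs, ``PagedAttention.forward`` issues one ``put`` per run and layer, and ``Worker.execute_model`` signals the slot's semaphore, which the token worker waits on before its first decode step. This is an illustration, not part of the patch, and it assumes this branch's ``vllm.utils``.

# Illustration (not part of the patch) of the blocks_to_nw flow, assuming this
# branch's vllm.utils is importable.
from vllm.utils import coalesce_blocks

# Block ids of a sequence that just finished its prompt phase, keyed by its
# slot/semaphore id (here slot 0).
blocks_to_nw = {0: [4, 5, 6, 11, 12]}

# coalesce_blocks_by_id applies this per slot: contiguous (start, length) runs.
runs = coalesce_blocks(blocks_to_nw[0])
print(runs)  # [(4, 3), (11, 2)]

# On the prompt side, PagedAttention.forward then issues, per layer:
#     for block_start, num_blocks in runs:
#         kvcache_comm_manager.put(sem_id, layer_id, block_start, num_blocks)
# and Worker.execute_model calls kvcache_comm_manager.signal_and_flush(sem_id)
# once afterwards. The matching token worker calls
# kvcache_comm_manager.wait(sem_id) before running its first decode step.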