From 0eb1ab1293aafdabcb6bd555b6b85a767b21faac Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 6 May 2024 16:46:15 +0000 Subject: [PATCH 01/39] flashinfer for prefill --- .buildkite/test-pipeline.yaml | 4 ++ Dockerfile | 3 ++ vllm/attention/backends/flashinfer.py | 38 ++++++++++++----- vllm/worker/model_runner.py | 60 ++++++++++++++++++++++++--- 4 files changed, 88 insertions(+), 17 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e49a5650c44ea..767780f4c8ac8 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -15,6 +15,8 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py @@ -38,6 +40,8 @@ steps: - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py diff --git a/Dockerfile b/Dockerfile index 90be3a30f89b1..563d6edf6caa2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,6 +37,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # see https://github.com/pytorch/pytorch/pull/123243 ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} + +# Manually install flashinfer +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ #################### BASE BUILD IMAGE #################### diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8ab4b1f12ee36..4e26a63221c6a 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -5,10 +5,12 @@ import flashinfer from flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper except ImportError: flashinfer = None flash_attn_varlen_func = None BatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None import torch @@ -64,6 +66,7 @@ class FlashInferMetadata(AttentionMetadataPerStage): use_cuda_graph: bool = False + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None # Metadata for the prefill stage since we still @@ -112,10 +115,20 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") - # When using flashinfer, we are also creating the FlashInferMetadata, - # which 
will also call post_init by default, here we want to skip the - # post_init if it's the prefill phase. - if not self.is_prompt: + if self.is_prompt: + self.prefill_wrapper = \ + flashinfer.BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD") + self.prefill_wrapper.begin_forward( + self.seq_start_loc, + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + ) + else: self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD") self.decode_wrapper.begin_forward( @@ -135,8 +148,9 @@ def asdict_zerocopy(self, ) -> Dict[str, Any]: if skip_fields is None: skip_fields = set() - # We need to skip the decode_wrapper field since it cannot be + # We need to skip the prefill/decode_wrapper field since it cannot be # broadcasted with nccl when TP is enabled. + skip_fields.add('prefill_wrapper') skip_fields.add('decode_wrapper') return super().asdict_zerocopy(skip_fields) @@ -188,9 +202,10 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, attn_metadata.kv_cache_dtype, ) + query = query.contiguous( + ) # Flashinfer requires query to be contiguous if prefill_meta := attn_metadata.prefill_metadata: - assert prefill_meta.block_tables is not None - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if kv_cache is None: output = flash_attn_varlen_func( q=query, k=key, @@ -205,13 +220,14 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, alibi_slopes=self.alibi_slopes, ) else: - raise NotImplementedError( - "Prefix caching is not supported with flashinfer yet.") + assert attn_metadata.prefill_metadata is not None + assert attn_metadata.prefill_metadata.prefill_wrapper \ + is not None + output = attn_metadata.prefill_metadata.prefill_wrapper.forward( + query, kv_cache) else: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None - query = query.contiguous( - ) # Flashinfer requires query to be contiguous output = attn_metadata.decode_metadata.decode_wrapper.forward( query, kv_cache, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab248596490f6..46843364a8a71 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -157,7 +157,8 @@ def __init__( self.graph_block_tables: torch.Tensor # Set after initial profiling. # Set if the backend is flashinfer. - self.flashinfer_workspace_buffer: torch.Tensor + self.flashinfer_prefill_workspace_buffer: torch.Tensor + self.flashinfer_decode_workspace_buffer: torch.Tensor def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -240,6 +241,24 @@ def _prepare_prompt( prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. 
+ paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + paged_kv_last_page_len: List[int] = [] + if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() @@ -319,6 +338,8 @@ def _prepare_prompt( # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, seq_len - sliding_window). @@ -342,6 +363,11 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + max_query_len = max(query_lens) max_seq_len = max(seq_lens) assert max_query_len > 0 @@ -396,12 +422,34 @@ def _prepare_prompt( out=seq_start_loc[1:]) if self.attn_backend is FlashInferBackend: + if not hasattr(self, "flashinfer_prefill_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.flashinfer_prefill_workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) + paged_kv_indptr = torch.tensor(paged_kv_indptr, + dtype=torch.int, + device=self.device) + paged_kv_indices = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + dtype=torch.int, + device=self.device) attn_metadata = self.attn_backend.make_metadata( + workspace_buffer=self.flashinfer_prefill_workspace_buffer, is_prompt=True, use_cuda_graph=False, seq_start_loc=seq_start_loc, - max_seq_len=max_seq_len, - block_tables=block_tables) + paged_kv_indices=paged_kv_indices, + paged_kv_indptr=paged_kv_indptr, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + max_seq_len=max_seq_len) else: attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -557,10 +605,10 @@ def _prepare_decode( ) if self.attn_backend is FlashInferBackend: - if not hasattr(self, "flashinfer_workspace_buffer"): + if not hasattr(self, "flashinfer_decode_workspace_buffer"): # Allocate 16MB workspace buffer # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.flashinfer_workspace_buffer = torch.empty( + self.flashinfer_decode_workspace_buffer = torch.empty( 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) paged_kv_indptr = torch.tensor(paged_kv_indptr, dtype=torch.int, @@ -577,7 +625,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, use_cuda_graph=False, - workspace_buffer=self.flashinfer_workspace_buffer, + workspace_buffer=self.flashinfer_decode_workspace_buffer, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, From 4590b467e15bb674ee57577ba67af5db4b4408cf Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 6 May 2024 16:57:35 +0000 Subject: [PATCH 02/39] minor --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 563d6edf6caa2..83c8d78522f8b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,7 +39,7 
@@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} # Manually install flashinfer -pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ +RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ #################### BASE BUILD IMAGE #################### From 3bfbdf70089958b1f1d73fd1f7bc70016e125d6e Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 7 May 2024 03:59:45 +0000 Subject: [PATCH 03/39] fix docker --- Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 83c8d78522f8b..2c62fc529b012 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,9 +37,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # see https://github.com/pytorch/pytorch/pull/123243 ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} - -# Manually install flashinfer -RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ #################### BASE BUILD IMAGE #################### @@ -129,6 +126,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ --mount=type=cache,target=/root/.cache/pip \ pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir + +# Manually install flashinfer +RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ #################### vLLM installation IMAGE #################### From 993a4aed05f2f67eb7e9cd562e69df88111a9b3c Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 7 May 2024 22:17:10 +0000 Subject: [PATCH 04/39] work for prefix caching --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 46843364a8a71..46c2d2d11fa3d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -440,7 +440,7 @@ def _prepare_prompt( workspace_buffer=self.flashinfer_prefill_workspace_buffer, is_prompt=True, use_cuda_graph=False, - seq_start_loc=seq_start_loc, + seq_start_loc=subquery_start_loc, paged_kv_indices=paged_kv_indices, paged_kv_indptr=paged_kv_indptr, paged_kv_last_page_len=paged_kv_last_page_len, From b4d9daec7cff5fd8baca8219d24c51d768f25cd8 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 7 May 2024 22:18:54 +0000 Subject: [PATCH 05/39] dedup test --- .buildkite/test-pipeline.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 767780f4c8ac8..9926cea8441b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -16,7 +16,6 @@ steps: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py From 5e3d11d1b6db333047d6443e596e2e9e4b703aed Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 29 May 2024 00:02:39 +0000 
Subject: [PATCH 06/39] format --- vllm/attention/backends/flashinfer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index d8902d4d56315..9e71c28bad9b7 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,9 +3,9 @@ try: import flashinfer + from flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - from flash_attn import flash_attn_varlen_func except ImportError: flashinfer = None flash_attn_varlen_func = None @@ -13,6 +13,7 @@ BatchPrefillWithPagedKVCacheWrapper = None import torch + from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) @@ -234,8 +235,10 @@ def forward( query = query.contiguous( ) # Flashinfer requires query to be contiguous if prefill_meta := attn_metadata.prefill_metadata: - # We will use flash attention for prefill when kv_cache is not provided. - # This happens when vllm runs the profiling to determine the number of blocks. + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. if kv_cache is None: output = flash_attn_varlen_func( q=query, From 89f0e2c0c270638914cdd85f85d372dee1e51aaa Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 29 May 2024 02:17:41 +0000 Subject: [PATCH 07/39] fix test --- tests/basic_correctness/test_basic_correctness.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 7d8117447ca0a..7bff01d8a58e8 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -18,7 +18,11 @@ def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" - llm = LLM("facebook/opt-125m") + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + enforce_eager = False + if backend_by_env_var == "FLASHINFER": + enforce_eager = True + llm = LLM("facebook/opt-125m", enforce_eager=enforce_eager) weak_llm = weakref.ref(llm) del llm # If there's any circular reference to vllm, this fails From 72e704bb710bc75a7c180b522f20408480e684ec Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 30 May 2024 00:01:00 +0000 Subject: [PATCH 08/39] remove flashinfer from ci --- .buildkite/test-pipeline.yaml | 1 - Dockerfile | 2 -- 2 files changed, 3 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27ec4c9b0bb96..def8a460e84a7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -18,7 +18,6 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py diff --git a/Dockerfile b/Dockerfile index c4ac525c0dce9..1f001265a994c 100644 --- a/Dockerfile +++ 
b/Dockerfile @@ -106,8 +106,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose -# Manually install flashinfer -RUN pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ #################### vLLM installation IMAGE #################### From f9770ed6e111f4228d19bef7a34804f31f47029a Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 4 Jun 2024 04:14:13 +0000 Subject: [PATCH 09/39] wip, cuda graph for decode --- vllm/attention/backends/flashinfer.py | 61 ++++++++++-- vllm/attention/selector.py | 2 - vllm/worker/model_runner.py | 130 ++++++++++++++++++-------- 3 files changed, 141 insertions(+), 52 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 9e71c28bad9b7..2ed8117463f29 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -4,6 +4,7 @@ try: import flashinfer from flash_attn import flash_attn_varlen_func + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper except ImportError: @@ -71,6 +72,8 @@ class FlashInferMetadata(AttentionMetadata): use_cuda_graph: bool = False prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_cudagraph_wrapper: Optional[ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None # Metadata for the prefill stage @@ -78,10 +81,6 @@ class FlashInferMetadata(AttentionMetadata): query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None - # Metadata for the decode stage - # Workspace buffer required by the kernel, the buffer should not - # be allocated/deacollated by the FalshInfermetadata object. 
- workspace_buffer: Optional[torch.Tensor] = None # An example for paged_kv_indices, paged_kv_indptr: # request 1, page indices [0, 5, 8] # request 2, page indices [1, 6, 7] @@ -107,6 +106,10 @@ class FlashInferMetadata(AttentionMetadata): page_size: Optional[int] = None # The data type of the paged kv cache data_type: torch.dtype = None + device: torch.device = torch.device("cuda") + num_gpu_blocks: Optional[int] = None + max_num_seqs: Optional[int] = None + use_captured_graph: Optional[bool] = None def __post_init__(self): # Refer to @@ -119,16 +122,56 @@ def __post_init__(self): f"received {self.head_dim}.") if self.num_prefill_tokens > 0: + if not hasattr(self, "prefill_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.prefill_workspace_buffer = torch.empty(16 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) self.prefill_wrapper = \ - flashinfer.BatchPrefillWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD") + BatchPrefillWithPagedKVCacheWrapper( + self.prefill_workspace_buffer, "NHD") + assert self.prefill_workspace_buffer is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) self.prefill_wrapper.begin_forward( self.query_start_loc, self.paged_kv_indptr, self.paged_kv_indices, self.paged_kv_last_page_len, - self.num_qo_heads, self.num_kv_heads, self.head_dim) + self.num_qo_heads, self.num_kv_heads, self.head_dim, + self.page_size) else: - self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, "NHD") + if not hasattr(self, "decode_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.decode_workspace_buffer = torch.empty(16 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + + if self.use_captured_graph: + self.indptr_buffer = torch.empty(self.max_num_seqs + 1, + dtype=torch.int32, + device=self.device) + self.indices_buffer = torch.empty(self.num_gpu_blocks, + dtype=torch.int32, + device=self.device) + self.last_page_len_buffer = torch.empty(self.max_num_seqs, + dtype=torch.int32, + device=self.device) + + if self.use_captured_graph: + self.decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self.decode_workspace_buffer, self.indptr_buffer, + self.indices_buffer, self.last_page_len_buffer, "NHD") + else: + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.decode_workspace_buffer, "NHD") + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.decode_wrapper.begin_forward( self.paged_kv_indptr, self.paged_kv_indices, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 9ceda3431b898..0470a27a124d2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -63,8 +63,6 @@ def get_attn_backend( return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is required for the Flashinfer backend. 
" - "Please make sure --enforce-eager is set.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1ab49b1b0cd2d..9876d592e1b2f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -124,8 +124,6 @@ def __init__( # Lazy initialization self.model: nn.Module # Set after load_model - # Set if the backend is flashinfer. - self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None @@ -557,19 +555,14 @@ def _prepare_model_input( device=self.device) if self.attn_backend.get_name() == "flashinfer": - if not hasattr(self, "flashinfer_workspace_buffer"): - # Allocate 16MB workspace buffer - # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.flashinfer_workspace_buffer = torch.empty( - 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, dtype=torch.int, - device=self.device) + device='cpu') paged_kv_indices_tensor = torch.tensor(paged_kv_indices, dtype=torch.int, - device=self.device) + device='cpu') paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, dtype=torch.int, device=self.device) + paged_kv_last_page_len, dtype=torch.int, device='cpu') kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) attn_metadata = self.attn_backend.make_metadata( @@ -580,7 +573,6 @@ def _prepare_model_input( use_cuda_graph=False, max_prefill_seq_len=max_prefill_seq_len, block_tables=block_tables, - workspace_buffer=self.flashinfer_workspace_buffer, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, @@ -589,10 +581,14 @@ def _prepare_model_input( num_kv_heads=self.model_config.get_num_kv_heads( self.parallel_config), head_dim=self.model_config.get_head_size(), - page_size=16, + page_size=self.block_size, + num_gpu_blocks=self.cache_config.num_gpu_blocks, + max_num_seqs=self.scheduler_config.max_num_seqs, seq_start_loc=seq_start_loc, query_start_loc=query_start_loc, - data_type=kv_cache_dtype) + device=self.device, + data_type=kv_cache_dtype, + use_captured_graph=use_captured_graph) else: attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, @@ -886,23 +882,60 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): + kv_cache_dtype = get_kv_cache_torch_dtype( + self.kv_cache_dtype, self.model_config.dtype) + paged_kv_indptr_tensor_host = torch.arange(0, batch_size + + 1).int() + paged_kv_indices_tensor_host = torch.arange(0, + batch_size).int() + paged_kv_last_page_len_tensor_host = torch.full( + (batch_size, ), self.block_size, dtype=torch.int32) + query_start_loc_host = torch.arange(0, batch_size + 1).int() + # Create dummy attn_metadata. 
- attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=seq_lens[:batch_size], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables[:batch_size], - use_cuda_graph=True, - ) + if self.attn_backend.get_name() == "flashinfer": + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=slot_mapping[:batch_size], + num_prefill_tokens=0, + num_decode_tokens=batch_size, + use_cuda_graph=False, + max_prefill_seq_len=0, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len= + paged_kv_last_page_len_tensor_host, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + num_gpu_blocks=self.cache_config.num_gpu_blocks, + max_num_seqs=self.scheduler_config.max_num_seqs, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.device, + data_type=kv_cache_dtype, + use_captured_graph=True) + else: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) if self.lora_config: lora_mapping = LoRAMapping( @@ -911,7 +944,8 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) - graph_runner = CUDAGraphRunner(self.model) + graph_runner = CUDAGraphRunner(self.model, + self.attn_backend.get_name()) graph_runner.capture( input_tokens[:batch_size], input_positions[:batch_size], @@ -935,8 +969,10 @@ def vocab_size(self) -> int: class CUDAGraphRunner: - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, backend_name: str): self.model = model + self.backend_name = backend_name + self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} @@ -983,14 +1019,24 @@ def capture( torch.cuda.synchronize() # Save the input and output buffers. 
- self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } + if self.backend_name == "flashinfer": + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": attn_metadata.slot_mapping, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + else: + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": + attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } self.output_buffers = {"hidden_states": hidden_states} return @@ -1010,8 +1056,10 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + if self.backend_name != "flashinfer": + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, + non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. From f1849f73469c292d4f73b32ba7ff6c6f7824926e Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 4 Jun 2024 04:17:02 +0000 Subject: [PATCH 10/39] wip --- vllm/attention/backends/flashinfer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2ed8117463f29..8ca1e52d1f801 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -108,8 +108,8 @@ class FlashInferMetadata(AttentionMetadata): data_type: torch.dtype = None device: torch.device = torch.device("cuda") num_gpu_blocks: Optional[int] = None - max_num_seqs: Optional[int] = None - use_captured_graph: Optional[bool] = None + max_num_seqs: int = 256 + use_captured_graph: bool = True def __post_init__(self): # Refer to @@ -132,6 +132,9 @@ def __post_init__(self): BatchPrefillWithPagedKVCacheWrapper( self.prefill_workspace_buffer, "NHD") assert self.prefill_workspace_buffer is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( @@ -161,12 +164,16 @@ def __post_init__(self): device=self.device) if self.use_captured_graph: - self.decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self.decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( self.decode_workspace_buffer, self.indptr_buffer, self.indices_buffer, self.last_page_len_buffer, "NHD") else: self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self.decode_workspace_buffer, "NHD") + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = 
self.paged_kv_last_page_len.to( From 88425a34cc937982679239517f19fcc0ff2c35d0 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 4 Jun 2024 05:58:05 +0000 Subject: [PATCH 11/39] pass tests --- tests/basic_correctness/test_basic_correctness.py | 12 +----------- vllm/attention/backends/flashinfer.py | 4 ++-- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 7bff01d8a58e8..1eaad8023ca45 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -2,7 +2,6 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ -import os import weakref import pytest @@ -13,16 +12,11 @@ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] -VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" - backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - enforce_eager = False - if backend_by_env_var == "FLASHINFER": - enforce_eager = True - llm = LLM("facebook/opt-125m", enforce_eager=enforce_eager) + llm = LLM("facebook/opt-125m") weak_llm = weakref.ref(llm) del llm # If there's any circular reference to vllm, this fails @@ -43,10 +37,6 @@ def test_models( max_tokens: int, enforce_eager: bool, ) -> None: - backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - if backend_by_env_var == "FLASHINFER" and enforce_eager is False: - pytest.skip("Skipping non-eager test for FlashInferBackend.") - hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8ca1e52d1f801..5b6394543b6bb 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -4,8 +4,8 @@ try: import flashinfer from flash_attn import flash_attn_varlen_func - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper except ImportError: flashinfer = None @@ -148,7 +148,7 @@ def __post_init__(self): if not hasattr(self, "decode_workspace_buffer"): # Allocate 16MB workspace buffer # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.decode_workspace_buffer = torch.empty(16 * 1024 * 1024, + self.decode_workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=self.device) From 74a8eebaf188b2b1622547947fae3fd6732a3a79 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Wed, 5 Jun 2024 20:41:57 +0000 Subject: [PATCH 12/39] wip --- vllm/attention/backends/flashinfer.py | 56 ++------- vllm/worker/model_runner.py | 172 ++++++++++++++++++++------ 2 files changed, 142 insertions(+), 86 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5b6394543b6bb..2b9a8e199bf55 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -2,7 +2,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type try: - import flashinfer from flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper @@ -69,11 +68,9 @@ class 
FlashInferMetadata(AttentionMetadata): # requests only. max_prefill_seq_len: int - use_cuda_graph: bool = False + use_cuda_graph: bool = True prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_cudagraph_wrapper: Optional[ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None # Metadata for the prefill stage @@ -90,12 +87,12 @@ class FlashInferMetadata(AttentionMetadata): # paged_kv_indptr is used to index into paged_kv_indices: # [0, 3, 6, 8] # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None + paged_kv_indptr: torch.Tensor = None # The page indices of the paged kv cache - paged_kv_indices: Optional[torch.Tensor] = None + paged_kv_indices: torch.Tensor = None # The number of entries in the last page of each request in # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: Optional[torch.Tensor] = None + paged_kv_last_page_len: torch.Tensor = None # The number of query/output heads num_qo_heads: Optional[int] = None # The number of key/value heads @@ -107,9 +104,6 @@ class FlashInferMetadata(AttentionMetadata): # The data type of the paged kv cache data_type: torch.dtype = None device: torch.device = torch.device("cuda") - num_gpu_blocks: Optional[int] = None - max_num_seqs: int = 256 - use_captured_graph: bool = True def __post_init__(self): # Refer to @@ -128,13 +122,12 @@ def __post_init__(self): self.prefill_workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - self.prefill_wrapper = \ - BatchPrefillWithPagedKVCacheWrapper( - self.prefill_workspace_buffer, "NHD") + if self.paged_kv_indices is None: + return + + assert self.prefill_wrapper is not None assert self.prefill_workspace_buffer is not None - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( @@ -145,40 +138,13 @@ def __post_init__(self): self.num_qo_heads, self.num_kv_heads, self.head_dim, self.page_size) else: - if not hasattr(self, "decode_workspace_buffer"): - # Allocate 16MB workspace buffer - # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.decode_workspace_buffer = torch.empty(128 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) - - if self.use_captured_graph: - self.indptr_buffer = torch.empty(self.max_num_seqs + 1, - dtype=torch.int32, - device=self.device) - self.indices_buffer = torch.empty(self.num_gpu_blocks, - dtype=torch.int32, - device=self.device) - self.last_page_len_buffer = torch.empty(self.max_num_seqs, - dtype=torch.int32, - device=self.device) - - if self.use_captured_graph: - self.decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - self.decode_workspace_buffer, self.indptr_buffer, - self.indices_buffer, self.last_page_len_buffer, "NHD") - else: - self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.decode_workspace_buffer, "NHD") - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None + if not self.use_cuda_graph: self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = 
self.paged_kv_last_page_len.to( self.device) + assert self.decode_wrapper is not None self.decode_wrapper.begin_forward( self.paged_kv_indptr, self.paged_kv_indices, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9876d592e1b2f..bbbb64a28ed82 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -24,6 +24,10 @@ from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) +from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper +from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper +from flashinfer import BatchDecodeWithPagedKVCacheWrapper + logger = init_logger(__name__) _PAD_SLOT_ID = -1 @@ -127,6 +131,12 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.paged_kv_indptr_tensor = None + self.paged_kv_indices_tensor = None + self.paged_kv_last_page_len_tensor = None + self.prefill_wrapper = None + self.decode_wrapper = None + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -245,24 +255,39 @@ def _prepare_model_input( num_prefills = 0 num_prefill_tokens = 0 num_decode_tokens = 0 - - # The following fields are only for flashinfer - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - paged_kv_last_page_len: List[int] = [] + flashinfer_batch_idx = 0 + + if self.attn_backend.get_name() == "flashinfer" and \ + self.paged_kv_indptr_tensor is None and \ + self.cache_config.num_gpu_blocks is not None: + # Preallocate pinned memory to avoid repeated allocations. + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indptr_tensor = torch.empty( + self.scheduler_config.max_num_seqs + 1, + dtype=torch.int, + device='cpu') + self.paged_kv_indptr_tensor[0] = 0 + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. 
+ self.paged_kv_indices_tensor = torch.empty( + self.cache_config.num_gpu_blocks, + dtype=torch.int, + device='cpu') + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len_tensor = torch.empty( + self.scheduler_config.max_num_seqs, + dtype=torch.int, + device='cpu') if len(seq_group_metadata_list) == 0: return ModelInput.empty(self.device) @@ -451,13 +476,22 @@ def _prepare_model_input( slot_mapping.append(slot) if self.attn_backend.get_name() == "flashinfer": - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + - len(block_table)) + + start_idx = self.paged_kv_indptr_tensor[ + flashinfer_batch_idx] + # end_idx = start_idx + len(block_table) + for idx, block in enumerate(block_table): + self.paged_kv_indices_tensor[start_idx + idx] = block + # self.paged_kv_indices_tensor[start_idx:end_idx] = block_table + self.paged_kv_indptr_tensor[ + flashinfer_batch_idx + + 1] = start_idx + len(block_table) last_page_len = seq_len % self.block_size if last_page_len == 0: last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) + self.paged_kv_last_page_len_tensor[ + flashinfer_batch_idx] = last_page_len + flashinfer_batch_idx += 1 batch_size = len(input_tokens) max_query_len = max(query_lens) @@ -555,22 +589,62 @@ def _prepare_model_input( device=self.device) if self.attn_backend.get_name() == "flashinfer": - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device='cpu') - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - dtype=torch.int, - device='cpu') - paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, dtype=torch.int, device='cpu') + if self.paged_kv_indices_tensor is not None: + end_idx = self.paged_kv_indptr_tensor[ + flashinfer_batch_idx].item() + paged_kv_indices_tensor = self.paged_kv_indices_tensor[: + end_idx] + paged_kv_indptr_tensor = self.paged_kv_indptr_tensor[: + flashinfer_batch_idx + + 1] + paged_kv_last_page_len_tensor = self.paged_kv_last_page_len_tensor[: + flashinfer_batch_idx] + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + + if num_decode_tokens and self.decode_wrapper is None: + self.decode_workspace_buffer = torch.empty(128 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + + if use_captured_graph: + self.indptr_buffer = torch.empty( + self.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device) + self.indices_buffer = torch.empty( + self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) + self.last_page_len_buffer = torch.empty( + self.scheduler_config.max_num_seqs, + dtype=torch.int32, + device=self.device) + self.decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self.decode_workspace_buffer, self.indptr_buffer, + self.indices_buffer, self.last_page_len_buffer, "NHD") + else: + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.decode_workspace_buffer, "NHD") + + if num_prefill_tokens and self.prefill_wrapper is None: + self.prefill_workspace_buffer = torch.empty(128 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self.prefill_workspace_buffer, "NHD") + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) + attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, 
num_decode_tokens=num_decode_tokens, - use_cuda_graph=False, max_prefill_seq_len=max_prefill_seq_len, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor, @@ -582,13 +656,15 @@ def _prepare_model_input( self.parallel_config), head_dim=self.model_config.get_head_size(), page_size=self.block_size, - num_gpu_blocks=self.cache_config.num_gpu_blocks, - max_num_seqs=self.scheduler_config.max_num_seqs, seq_start_loc=seq_start_loc, query_start_loc=query_start_loc, device=self.device, data_type=kv_cache_dtype, - use_captured_graph=use_captured_graph) + use_cuda_graph=use_captured_graph, + decode_wrapper=self.decode_wrapper, + prefill_wrapper=self.prefill_wrapper, + ) + else: attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, @@ -878,6 +954,22 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] + indptr_buffer = torch.empty(self.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device) + indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) + last_page_len_buffer = torch.empty(self.scheduler_config.max_num_seqs, + dtype=torch.int32, + device=self.device) + decode_workspace_buffer = torch.empty(128 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + decode_workspace_buffer, indptr_buffer, indices_buffer, + last_page_len_buffer, "NHD") + with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. @@ -899,7 +991,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping=slot_mapping[:batch_size], num_prefill_tokens=0, num_decode_tokens=batch_size, - use_cuda_graph=False, max_prefill_seq_len=0, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor_host, @@ -912,13 +1003,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: self.parallel_config), head_dim=self.model_config.get_head_size(), page_size=self.block_size, - num_gpu_blocks=self.cache_config.num_gpu_blocks, - max_num_seqs=self.scheduler_config.max_num_seqs, seq_start_loc=None, query_start_loc=query_start_loc_host, device=self.device, data_type=kv_cache_dtype, - use_captured_graph=True) + use_cuda_graph=True, + decode_wrapper=decode_wrapper, + prefill_wrapper=None) else: attn_metadata = self.attn_backend.make_metadata( num_prefills=0, @@ -1025,7 +1116,6 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "block_tables": attn_metadata.decode_metadata.block_tables, } else: self.input_buffers = { @@ -1060,8 +1150,8 @@ def forward( self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - self.input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. 
self.graph.replay() From dcbbfd6c0d3214a92f0d589cc838f14c0d3cb5bf Mon Sep 17 00:00:00 2001 From: lilyliu Date: Mon, 10 Jun 2024 23:27:19 -0700 Subject: [PATCH 13/39] pass simple tests, need more fix for correctness --- vllm/worker/model_runner.py | 94 ++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bbbb64a28ed82..dc688462ac821 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -131,6 +131,7 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + # Flashinfer fields self.paged_kv_indptr_tensor = None self.paged_kv_indices_tensor = None self.paged_kv_last_page_len_tensor = None @@ -436,7 +437,8 @@ def _prepare_model_input( multi_modal_input_list.append( seq_group_metadata.multi_modal_data.data) - if _is_block_tables_empty(seq_group_metadata.block_tables): + is_profile_run = _is_block_tables_empty(seq_group_metadata.block_tables) + if is_profile_run: # During memory profiling, the block tables are not # initialized yet. In this case, we just use a dummy # slot mapping. @@ -475,17 +477,20 @@ def _prepare_model_input( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - if self.attn_backend.get_name() == "flashinfer": - + if self.attn_backend.get_name() == "flashinfer": start_idx = self.paged_kv_indptr_tensor[ flashinfer_batch_idx] - # end_idx = start_idx + len(block_table) - for idx, block in enumerate(block_table): + block_table_bound = seq_data.get_len( + ) // self.block_size + 1 if seq_data.get_len( + ) % self.block_size != 0 else seq_data.get_len( + ) // self.block_size + # end_idx = start_idx + block_table_bound + for idx, block in enumerate(block_table[:block_table_bound]): self.paged_kv_indices_tensor[start_idx + idx] = block # self.paged_kv_indices_tensor[start_idx:end_idx] = block_table self.paged_kv_indptr_tensor[ flashinfer_batch_idx + - 1] = start_idx + len(block_table) + 1] = start_idx + block_table_bound last_page_len = seq_len % self.block_size if last_page_len == 0: last_page_len = self.block_size @@ -604,30 +609,16 @@ def _prepare_model_input( paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - if num_decode_tokens and self.decode_wrapper is None: - self.decode_workspace_buffer = torch.empty(128 * 1024 * 1024, + + if num_decode_tokens: + if use_captured_graph and not is_profile_run: + self.decode_wrapper = self.graph_runners[batch_size].decode_wrapper + else: + if self.decode_wrapper is None: + self.decode_workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=self.device) - - if use_captured_graph: - self.indptr_buffer = torch.empty( - self.scheduler_config.max_num_seqs + 1, - dtype=torch.int32, - device=self.device) - self.indices_buffer = torch.empty( - self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) - self.last_page_len_buffer = torch.empty( - self.scheduler_config.max_num_seqs, - dtype=torch.int32, - device=self.device) - self.decode_wrapper = \ - CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - self.decode_workspace_buffer, self.indptr_buffer, - self.indices_buffer, self.last_page_len_buffer, "NHD") - else: - self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self.decode_workspace_buffer, "NHD") if num_prefill_tokens and self.prefill_wrapper is None: @@ -954,26 +945,27 @@ def capture_model(self, kv_caches: 
List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - indptr_buffer = torch.empty(self.scheduler_config.max_num_seqs + 1, - dtype=torch.int32, - device=self.device) - indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) - last_page_len_buffer = torch.empty(self.scheduler_config.max_num_seqs, - dtype=torch.int32, - device=self.device) - decode_workspace_buffer = torch.empty(128 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) - decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, indptr_buffer, indices_buffer, - last_page_len_buffer, "NHD") - with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): + print("Capture-----------------", batch_size) + indptr_buffer = torch.empty(batch_size + 1, + dtype=torch.int32, + device=self.device) + indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) + last_page_len_buffer = torch.empty(batch_size, + dtype=torch.int32, + device=self.device) + decode_workspace_buffer = torch.empty(128 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + decode_workspace_buffer, indptr_buffer, indices_buffer, + last_page_len_buffer, "NHD") + kv_cache_dtype = get_kv_cache_torch_dtype( self.kv_cache_dtype, self.model_config.dtype) paged_kv_indptr_tensor_host = torch.arange(0, batch_size + @@ -1037,6 +1029,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: graph_runner = CUDAGraphRunner(self.model, self.attn_backend.get_name()) + + graph_runner.indptr_buffer = indptr_buffer + graph_runner.indices_buffer = indices_buffer + graph_runner.last_page_len_buffer = last_page_len_buffer + graph_runner.decode_workspace_buffer = decode_workspace_buffer + graph_runner.decode_wrapper = decode_wrapper + graph_runner.capture( input_tokens[:batch_size], input_positions[:batch_size], @@ -1069,6 +1068,13 @@ def __init__(self, model: nn.Module, backend_name: str): self._graph: Optional[torch.cuda.CUDAGraph] = None + # Flashinfer fields + self.decode_workspace_buffer = None + self.indptr_buffer = None + self.indices_buffer = None + self.last_page_len_buffer = None + self.decode_wrapper = None + @property def graph(self): assert self._graph is not None From 430284846149ad8f182c41ad9c5571c366760bd6 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 13 Jun 2024 20:46:22 +0000 Subject: [PATCH 14/39] optimizer prepare input --- vllm/attention/backends/flashinfer.py | 1 - vllm/worker/model_runner.py | 149 +++++++++++--------------- 2 files changed, 61 insertions(+), 89 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2b9a8e199bf55..1e8e3affbb2cb 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -4,7 +4,6 @@ try: from flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper except ImportError: flashinfer = None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index dc688462ac821..1e7a9d7c449e9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ 
-5,6 +5,9 @@ import numpy as np import torch import torch.nn as nn +from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper +from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, @@ -24,10 +27,6 @@ from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) -from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper -from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper -from flashinfer import BatchDecodeWithPagedKVCacheWrapper - logger = init_logger(__name__) _PAD_SLOT_ID = -1 @@ -36,7 +35,7 @@ # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 10) ] @@ -256,39 +255,24 @@ def _prepare_model_input( num_prefills = 0 num_prefill_tokens = 0 num_decode_tokens = 0 - flashinfer_batch_idx = 0 - - if self.attn_backend.get_name() == "flashinfer" and \ - self.paged_kv_indptr_tensor is None and \ - self.cache_config.num_gpu_blocks is not None: - # Preallocate pinned memory to avoid repeated allocations. - # The following fields are only for flashinfer - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - self.paged_kv_indptr_tensor = torch.empty( - self.scheduler_config.max_num_seqs + 1, - dtype=torch.int, - device='cpu') - self.paged_kv_indptr_tensor[0] = 0 - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. - self.paged_kv_indices_tensor = torch.empty( - self.cache_config.num_gpu_blocks, - dtype=torch.int, - device='cpu') - # paged_kv_last_page_len is the length of the last page of each request - self.paged_kv_last_page_len_tensor = torch.empty( - self.scheduler_config.max_num_seqs, - dtype=torch.int, - device='cpu') + + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. 
+ paged_kv_indptr = [0] + paged_kv_indices = [] + # paged_kv_last_page_len is the length of the last page of each request + paged_kv_last_page_len = [] if len(seq_group_metadata_list) == 0: return ModelInput.empty(self.device) @@ -437,7 +421,8 @@ def _prepare_model_input( multi_modal_input_list.append( seq_group_metadata.multi_modal_data.data) - is_profile_run = _is_block_tables_empty(seq_group_metadata.block_tables) + is_profile_run = _is_block_tables_empty( + seq_group_metadata.block_tables) if is_profile_run: # During memory profiling, the block tables are not # initialized yet. In this case, we just use a dummy @@ -477,26 +462,20 @@ def _prepare_model_input( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - if self.attn_backend.get_name() == "flashinfer": - start_idx = self.paged_kv_indptr_tensor[ - flashinfer_batch_idx] - block_table_bound = seq_data.get_len( - ) // self.block_size + 1 if seq_data.get_len( - ) % self.block_size != 0 else seq_data.get_len( - ) // self.block_size - # end_idx = start_idx + block_table_bound - for idx, block in enumerate(block_table[:block_table_bound]): - self.paged_kv_indices_tensor[start_idx + idx] = block - # self.paged_kv_indices_tensor[start_idx:end_idx] = block_table - self.paged_kv_indptr_tensor[ - flashinfer_batch_idx + - 1] = start_idx + block_table_bound + if self.attn_backend.get_name() == "flashinfer": + seq_len = seq_data.get_len() + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + + paged_kv_indices.extend(block_table[:block_table_bound]) + paged_kv_indptr.append(paged_kv_indptr[-1] + + block_table_bound) + last_page_len = seq_len % self.block_size if last_page_len == 0: last_page_len = self.block_size - self.paged_kv_last_page_len_tensor[ - flashinfer_batch_idx] = last_page_len - flashinfer_batch_idx += 1 + paged_kv_last_page_len.append(last_page_len) batch_size = len(input_tokens) max_query_len = max(query_lens) @@ -565,10 +544,6 @@ def _prepare_model_input( query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) @@ -594,31 +569,30 @@ def _prepare_model_input( device=self.device) if self.attn_backend.get_name() == "flashinfer": - if self.paged_kv_indices_tensor is not None: - end_idx = self.paged_kv_indptr_tensor[ - flashinfer_batch_idx].item() - paged_kv_indices_tensor = self.paged_kv_indices_tensor[: - end_idx] - paged_kv_indptr_tensor = self.paged_kv_indptr_tensor[: - flashinfer_batch_idx - + 1] - paged_kv_last_page_len_tensor = self.paged_kv_last_page_len_tensor[: - flashinfer_batch_idx] + if len(paged_kv_indptr) > 0: + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + device='cpu', + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, + device='cpu', + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, device='cpu', dtype=torch.int) else: paged_kv_indices_tensor = None paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - if num_decode_tokens: if use_captured_graph and not is_profile_run: - self.decode_wrapper = self.graph_runners[batch_size].decode_wrapper - else: - if self.decode_wrapper is None: - self.decode_workspace_buffer = torch.empty(128 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) - 
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.decode_wrapper = self.graph_runners[ + batch_size].decode_wrapper + elif self.decode_wrapper is None: + self.decode_workspace_buffer = torch.empty( + 128 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self.decode_workspace_buffer, "NHD") if num_prefill_tokens and self.prefill_wrapper is None: @@ -949,23 +923,22 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): - print("Capture-----------------", batch_size) indptr_buffer = torch.empty(batch_size + 1, - dtype=torch.int32, - device=self.device) - indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, dtype=torch.int32, device=self.device) + indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) last_page_len_buffer = torch.empty(batch_size, - dtype=torch.int32, - device=self.device) + dtype=torch.int32, + device=self.device) decode_workspace_buffer = torch.empty(128 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) + dtype=torch.uint8, + device=self.device) decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( decode_workspace_buffer, indptr_buffer, indices_buffer, last_page_len_buffer, "NHD") - + kv_cache_dtype = get_kv_cache_torch_dtype( self.kv_cache_dtype, self.model_config.dtype) paged_kv_indptr_tensor_host = torch.arange(0, batch_size + @@ -1029,7 +1002,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: graph_runner = CUDAGraphRunner(self.model, self.attn_backend.get_name()) - + graph_runner.indptr_buffer = indptr_buffer graph_runner.indices_buffer = indices_buffer graph_runner.last_page_len_buffer = last_page_len_buffer From d7393127821c37b78db2e3c412cb47e855b728b5 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 13 Jun 2024 21:12:43 +0000 Subject: [PATCH 15/39] padding --- vllm/worker/model_runner.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1e7a9d7c449e9..d7c25eaca788e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -499,6 +499,12 @@ def _prepare_model_input( seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) + + last_paged_kv_indptr = paged_kv_indptr[-1] + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indptr.append(last_paged_kv_indptr) + paged_kv_last_page_len.append(0) + batch_size = graph_batch_size num_decode_tokens = batch_size From 5ad175ae15c5e50c309f5ac76c7efa93a5b5716f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Thu, 13 Jun 2024 21:47:00 +0000 Subject: [PATCH 16/39] style --- Dockerfile | 1 - tests/basic_correctness/test_basic_correctness.py | 2 +- vllm/worker/model_runner.py | 9 ++++----- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4b36266ee9c4c..eb96bf3c1db2b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -99,7 +99,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose - #################### vLLM installation IMAGE #################### diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 3621d6ce9a516..6f44030feebb0 
100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -10,7 +10,7 @@ MODELS = [ "facebook/opt-125m", - # "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", ] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bd60d5ead5d75..b43eb9af1277e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -38,7 +38,7 @@ # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 10) + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] _NUM_WARMUP_ITERS = 2 @@ -292,12 +292,12 @@ def _prepare_model_input( # [0, 5, 8, 1, 6, 7, 3, 4] # paged_kv_indptr is used to index into paged_kv_indices: # [0, 3, 6, 8] + paged_kv_indices: List[int] = [] # 0 at the beginning of paged_kv_indptr indicates the start of the # first request’s page indices in the paged_kv_indices list. - paged_kv_indptr = [0] - paged_kv_indices = [] + paged_kv_indptr: List[int] = [0] # paged_kv_last_page_len is the length of the last page of each request - paged_kv_last_page_len = [] + paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: return ModelInput.empty(self.device) @@ -408,7 +408,6 @@ def _prepare_model_input( else: # Prefill without chunked prefill or memory profiling. block_table = [] - block_tables.append(block_table) seq_lens.append(sliding_seq_len) From 543dc3b4e4b361e77c0cbae37cbe0ee493775306 Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Fri, 14 Jun 2024 00:35:02 +0000 Subject: [PATCH 17/39] share workspace buffer to reduce cudagraph extra memory cost --- vllm/worker/model_runner.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b43eb9af1277e..812a87e32689b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -581,6 +581,10 @@ def _prepare_model_input( dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, @@ -655,21 +659,6 @@ def _prepare_model_input( ) else: - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=slot_mapping_tensor, @@ -974,6 +963,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] + # For flashinfer, different batch sizes will share the same workspace buffer + decode_workspace_buffer = torch.empty(128 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
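To spell out the intent of this change: the 128 MB workspace is now allocated once, outside the capture loop, and handed to every per-batch-size CUDAGraphBatchDecodeWithPagedKVCacheWrapper (the next hunk removes the old per-batch allocation). Below is a minimal sketch of the sharing pattern, assuming flashinfer is installed; num_gpu_blocks and the batch sizes are illustrative, and the wrapper arguments mirror the calls in this patch.

    import torch
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper

    device = "cuda"
    num_gpu_blocks = 1024  # illustrative; vLLM derives this from memory profiling

    # One workspace buffer, shared by every captured batch size.
    shared_workspace = torch.empty(128 * 1024 * 1024,
                                   dtype=torch.uint8,
                                   device=device)

    decode_wrappers = {}
    for batch_size in (8, 4, 2, 1):
        indptr = torch.empty(batch_size + 1, dtype=torch.int32, device=device)
        indices = torch.empty(num_gpu_blocks, dtype=torch.int32, device=device)
        last_page_len = torch.empty(batch_size, dtype=torch.int32, device=device)
        decode_wrappers[batch_size] = CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
            shared_workspace, indptr, indices, last_page_len, "NHD")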
@@ -987,9 +980,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: last_page_len_buffer = torch.empty(batch_size, dtype=torch.int32, device=self.device) - decode_workspace_buffer = torch.empty(128 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( decode_workspace_buffer, indptr_buffer, indices_buffer, last_page_len_buffer, "NHD") From 11b7347be82ef15b7dac67764c606a64df2c790f Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 17 Jun 2024 02:47:46 +0000 Subject: [PATCH 18/39] address comments --- .../test_basic_distributed_correctness.py | 5 - vllm/attention/backends/flashinfer.py | 20 ++-- vllm/worker/model_runner.py | 109 +++++++++++------- 3 files changed, 75 insertions(+), 59 deletions(-) diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index eb423aef230cb..b8ae5b4c44f8d 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -21,7 +21,6 @@ os.environ["TEST_DIST_MODEL"], ] DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" -VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -39,16 +38,12 @@ def test_models( ) -> None: distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) - backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - enforce_eager = backend_by_env_var == "FLASHINFER" - with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, dtype=dtype, tensor_parallel_size=2, - enforce_eager=enforce_eager, distributed_executor_backend=distributed_executor_backend ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 1e8e3affbb2cb..542c85ed36d3e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -86,12 +86,12 @@ class FlashInferMetadata(AttentionMetadata): # paged_kv_indptr is used to index into paged_kv_indices: # [0, 3, 6, 8] # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: torch.Tensor = None + paged_kv_indptr: Optional[torch.Tensor] = None # The page indices of the paged kv cache - paged_kv_indices: torch.Tensor = None + paged_kv_indices: Optional[torch.Tensor] = None # The number of entries in the last page of each request in # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: torch.Tensor = None + paged_kv_last_page_len: Optional[torch.Tensor] = None # The number of query/output heads num_qo_heads: Optional[int] = None # The number of key/value heads @@ -115,18 +115,13 @@ def __post_init__(self): f"received {self.head_dim}.") if self.num_prefill_tokens > 0: - if not hasattr(self, "prefill_workspace_buffer"): - # Allocate 16MB workspace buffer - # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html - self.prefill_workspace_buffer = torch.empty(16 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) if self.paged_kv_indices is None: return assert self.prefill_wrapper is not None - assert self.prefill_workspace_buffer is not None - + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None self.paged_kv_indices = self.paged_kv_indices.to(self.device) 
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( @@ -138,6 +133,9 @@ def __post_init__(self): self.page_size) else: if not self.use_cuda_graph: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 812a87e32689b..f9d1f8967ba4e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -7,9 +7,17 @@ import numpy as np import torch import torch.nn as nn -from flashinfer import BatchDecodeWithPagedKVCacheWrapper -from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper -from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 +except ImportError: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, @@ -145,11 +153,11 @@ def __init__( self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None # Flashinfer fields - self.paged_kv_indptr_tensor = None - self.paged_kv_indices_tensor = None - self.paged_kv_last_page_len_tensor = None - self.prefill_wrapper = None - self.decode_wrapper = None + self.flashinfer_prefill_wrapper: Optional[ + BatchPrefillWithPagedKVCacheWrapper] = None + self.flashinfer_decode_wrapper: Optional[ + Union[BatchDecodeWithPagedKVCacheWrapper, + CUDAGraphBatchDecodeWithPagedKVCacheWrapper]] = None def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -494,8 +502,14 @@ def _prepare_model_input( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + # Prepare input tensors for flashinfer if self.attn_backend.get_name() == "flashinfer": seq_len = seq_data.get_len() + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. 
block_table_bound = seq_len // self.block_size + 1 \ if seq_len % self.block_size != 0 \ else seq_len // self.block_size @@ -532,8 +546,8 @@ def _prepare_model_input( block_tables.append([]) lora_index_mapping.append(0) - last_paged_kv_indptr = paged_kv_indptr[-1] if self.attn_backend.get_name() == "flashinfer": + last_paged_kv_indptr = paged_kv_indptr[-1] paged_kv_indptr.append(last_paged_kv_indptr) paged_kv_last_page_len.append(0) @@ -582,9 +596,9 @@ def _prepare_model_input( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) input_tokens_tensor = torch.tensor(input_tokens, dtype=torch.long, @@ -613,22 +627,25 @@ def _prepare_model_input( if num_decode_tokens: if use_captured_graph and not is_profile_run: - self.decode_wrapper = self.graph_runners[ - batch_size].decode_wrapper - elif self.decode_wrapper is None: - self.decode_workspace_buffer = torch.empty( - 128 * 1024 * 1024, + self.flashinfer_decode_wrapper = self.graph_runners[ + batch_size].flashinfer_decode_wrapper + elif self.flashinfer_decode_wrapper is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device) - self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.decode_workspace_buffer, "NHD") - - if num_prefill_tokens and self.prefill_wrapper is None: - self.prefill_workspace_buffer = torch.empty(128 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) - self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self.prefill_workspace_buffer, "NHD") + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + + if num_prefill_tokens and self.flashinfer_prefill_wrapper is None: + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) @@ -654,8 +671,8 @@ def _prepare_model_input( device=self.device, data_type=kv_cache_dtype, use_cuda_graph=use_captured_graph, - decode_wrapper=self.decode_wrapper, - prefill_wrapper=self.prefill_wrapper, + decode_wrapper=self.flashinfer_decode_wrapper, + prefill_wrapper=self.flashinfer_prefill_wrapper, ) else: @@ -963,10 +980,12 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - # For flashinfer, different batch sizes will share the same workspace buffer - decode_workspace_buffer = torch.empty(128 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) + # For flashinfer, different batch sizes will share the + # same workspace buffer. + decode_workspace_buffer = \ + torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
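The page arithmetic introduced earlier in this patch (block_table_bound and last_page_len in _prepare_model_input) is easy to get wrong at page boundaries, so here is a small stand-alone sketch of it. The helper names are hypothetical; the logic mirrors the expressions in the diff.

    def num_valid_blocks(seq_len: int, block_size: int) -> int:
        # Ceil division: a partially filled last page still occupies a block.
        return seq_len // block_size + (1 if seq_len % block_size != 0 else 0)

    def last_page_len(seq_len: int, block_size: int) -> int:
        # Occupied slots in the last page; a full page counts as block_size, not 0.
        remainder = seq_len % block_size
        return remainder if remainder != 0 else block_size

    assert num_valid_blocks(16, 16) == 1 and last_page_len(16, 16) == 16
    assert num_valid_blocks(15, 16) == 1 and last_page_len(15, 16) == 15
    assert num_valid_blocks(17, 16) == 2 and last_page_len(17, 16) == 1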
@@ -1048,11 +1067,15 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: graph_runner = CUDAGraphRunner(self.model, self.attn_backend.get_name()) - graph_runner.indptr_buffer = indptr_buffer - graph_runner.indices_buffer = indices_buffer - graph_runner.last_page_len_buffer = last_page_len_buffer - graph_runner.decode_workspace_buffer = decode_workspace_buffer - graph_runner.decode_wrapper = decode_wrapper + if self.attn_backend.get_name() == "flashinfer": + graph_runner.flashinfer_indptr_buffer = indptr_buffer + graph_runner.flashinfer_indices_buffer = indices_buffer + graph_runner.flashinfer_last_page_len_buffer = \ + last_page_len_buffer + graph_runner.flashinfer_decode_workspace_buffer = \ + decode_workspace_buffer + graph_runner.flashinfer_decode_wrapper = \ + decode_wrapper graph_runner.capture( input_tokens[:batch_size], @@ -1088,12 +1111,12 @@ def __init__(self, model: nn.Module, backend_name: str): self._graph: Optional[torch.cuda.CUDAGraph] = None - # Flashinfer fields - self.decode_workspace_buffer = None - self.indptr_buffer = None - self.indices_buffer = None - self.last_page_len_buffer = None - self.decode_wrapper = None + self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None + self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None + self.flashinfer_indices_buffer: Optional[torch.Tensor] = None + self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None + self.flashinfer_decode_wrapper: Optional[ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None @property def graph(self): From b5db4be1a7414678540304075df3e0bec374bb0a Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Mon, 17 Jun 2024 04:26:50 +0000 Subject: [PATCH 19/39] fix --- vllm/worker/model_runner.py | 59 ++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f9d1f8967ba4e..43149d3db74af 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -980,38 +980,43 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - # For flashinfer, different batch sizes will share the - # same workspace buffer. - decode_workspace_buffer = \ - torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, + if self.attn_backend.get_name() == "flashinfer": + # For flashinfer, different batch sizes will share the + # same workspace buffer. + decode_workspace_buffer = \ + torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, device=self.device) with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
for batch_size in reversed(batch_size_capture_list): - indptr_buffer = torch.empty(batch_size + 1, - dtype=torch.int32, - device=self.device) - indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) - last_page_len_buffer = torch.empty(batch_size, - dtype=torch.int32, - device=self.device) - decode_wrapper = CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, indptr_buffer, indices_buffer, - last_page_len_buffer, "NHD") - - kv_cache_dtype = get_kv_cache_torch_dtype( - self.kv_cache_dtype, self.model_config.dtype) - paged_kv_indptr_tensor_host = torch.arange(0, batch_size + - 1).int() - paged_kv_indices_tensor_host = torch.arange(0, - batch_size).int() - paged_kv_last_page_len_tensor_host = torch.full( - (batch_size, ), self.block_size, dtype=torch.int32) - query_start_loc_host = torch.arange(0, batch_size + 1).int() + if self.attn_backend.get_name() == "flashinfer": + indptr_buffer = torch.empty(batch_size + 1, + dtype=torch.int32, + device=self.device) + indices_buffer = torch.empty( + self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) + last_page_len_buffer = torch.empty(batch_size, + dtype=torch.int32, + device=self.device) + decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + decode_workspace_buffer, indptr_buffer, indices_buffer, + last_page_len_buffer, "NHD") + kv_cache_dtype = get_kv_cache_torch_dtype( + self.kv_cache_dtype, self.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange( + 0, batch_size + 1).int() + paged_kv_indices_tensor_host = torch.arange( + 0, batch_size).int() + paged_kv_last_page_len_tensor_host = torch.full( + (batch_size, ), self.block_size, dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1).int() # Create dummy attn_metadata. if self.attn_backend.get_name() == "flashinfer": From f53d03eb80d9781a0974b66686faaa11e6e3e4ee Mon Sep 17 00:00:00 2001 From: LiuXiaoxuanPKU Date: Tue, 18 Jun 2024 08:05:47 +0000 Subject: [PATCH 20/39] fix comments --- vllm/worker/model_runner.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 43149d3db74af..81576facf780f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1010,16 +1010,15 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: self.kv_cache_dtype, self.model_config.dtype) paged_kv_indptr_tensor_host = torch.arange( - 0, batch_size + 1).int() + 0, batch_size + 1, dtype=torch.int32) paged_kv_indices_tensor_host = torch.arange( - 0, batch_size).int() + 0, batch_size, dtype=torch.int32) paged_kv_last_page_len_tensor_host = torch.full( (batch_size, ), self.block_size, dtype=torch.int32) query_start_loc_host = torch.arange(0, - batch_size + 1).int() + batch_size + 1, + dtype=torch.int32) - # Create dummy attn_metadata. 
- if self.attn_backend.get_name() == "flashinfer": attn_metadata = self.attn_backend.make_metadata( num_prefills=0, slot_mapping=slot_mapping[:batch_size], From e05ff792ef276a62f5e347324402ed6c8d77d0b3 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 19 Jun 2024 16:12:12 -0700 Subject: [PATCH 21/39] support TP > 1 --- vllm/attention/backends/flashinfer.py | 1 + vllm/worker/embedding_model_runner.py | 18 +-- vllm/worker/model_runner.py | 158 +++++++++++++++----------- 3 files changed, 95 insertions(+), 82 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 542c85ed36d3e..75caa98c84d75 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -114,6 +114,7 @@ def __post_init__(self): f"Only {supported_head_sizes} are supported for head_dim,", f"received {self.head_dim}.") + def begin_forward(self): if self.num_prefill_tokens > 0: if self.paged_kv_indices is None: return diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 465130d10e2f9..44d6b8c398266 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -94,20 +94,10 @@ def prepare_input_tensors( if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. - ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - _, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, seq_lens, _, + lora_mapping, lora_requests, multi_modal_kwargs, slot_mapping, + num_prefill_tokens, num_decode_tokens, num_prefills, + _) = self._prepare_model_input(seq_group_metadata_list) # Prepare PoolingMetadata pooling_metadata = self._prepare_pooling(seq_group_metadata_list, seq_lens) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f89ddf4f7d61c..2427b3aca5592 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -64,23 +64,23 @@ class ModelInput(NamedTuple): num_prefill_tokens: int num_decode_tokens: int num_prefills: int + is_profile_run: bool @classmethod def empty(cls, device): - return ModelInput( - input_tokens=torch.empty(0, device=device), - input_positions=torch.empty(0, device=device), - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_mapping=None, - lora_requests=set(), - multi_modal_kwargs={}, - slot_mapping=torch.empty(0, device=device), - num_prefill_tokens=0, - num_decode_tokens=0, - num_prefills=0, - ) + return ModelInput(input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), + attn_metadata=None, + seq_lens=[], + query_lens=[], + lora_mapping=None, + lora_requests=set(), + multi_modal_kwargs={}, + slot_mapping=torch.empty(0, device=device), + num_prefill_tokens=0, + num_decode_tokens=0, + num_prefills=0, + is_profile_run=False) class ModelRunner: @@ -625,28 +625,6 @@ def _prepare_model_input( paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - if num_decode_tokens: - if use_captured_graph and not is_profile_run: - self.flashinfer_decode_wrapper = self.graph_runners[ - batch_size].flashinfer_decode_wrapper - elif self.flashinfer_decode_wrapper is None: - self.flashinfer_decode_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_decode_wrapper = \ - 
BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_decode_workspace_buffer, "NHD") - - if num_prefill_tokens and self.flashinfer_prefill_wrapper is None: - self.flashinfer_prefill_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_prefill_wrapper = \ - BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_prefill_workspace_buffer, "NHD") - kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) @@ -670,10 +648,7 @@ def _prepare_model_input( query_start_loc=query_start_loc, device=self.device, data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph, - decode_wrapper=self.flashinfer_decode_wrapper, - prefill_wrapper=self.flashinfer_prefill_wrapper, - ) + use_cuda_graph=use_captured_graph) else: attn_metadata = self.attn_backend.make_metadata( @@ -706,20 +681,19 @@ def _prepare_model_input( for k, v in multi_modal_kwargs_list.items() } - return ModelInput( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - ) + return ModelInput(input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + is_profile_run=is_profile_run) def prepare_input_tensors( self, @@ -729,24 +703,17 @@ def prepare_input_tensors( if self.is_driver_worker: assert seq_group_metadata_list is not None # Prepare input tensors. 
- ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - query_lens, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, seq_lens, + query_lens, lora_mapping, lora_requests, multi_modal_kwargs, + slot_mapping, num_prefill_tokens, num_decode_tokens, num_prefills, + is_profile_run + ) = self._prepare_model_input(seq_group_metadata_list) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) + use_cuda_graph = attn_metadata.use_cuda_graph \ + if attn_metadata else False metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -759,10 +726,23 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, + "use_cuda_graph": use_cuda_graph, + "is_profile_run": is_profile_run } if attn_metadata: metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) + + if self.attn_backend.get_name() == "flashinfer": + self._create_flashinfer_wrapper(metadata_dict, + attn_metadata.use_cuda_graph, + is_profile_run, + input_tokens.shape[0]) + attn_metadata.prefill_wrapper = metadata_dict.pop( + 'prefill_wrapper', None) + attn_metadata.decode_wrapper = metadata_dict.pop( + 'decode_wrapper', None) + attn_metadata.begin_forward() else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") @@ -772,9 +752,19 @@ def prepare_input_tensors( lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") + + use_cuda_graph = metadata_dict.pop('use_cuda_graph') + is_profile_run = metadata_dict.pop('is_profile_run') + if self.attn_backend.get_name() == "flashinfer": + self._create_flashinfer_wrapper(metadata_dict, use_cuda_graph, + is_profile_run, + input_tokens.shape[0]) + if metadata_dict: attn_metadata = self.attn_backend.make_metadata( **metadata_dict) + if self.attn_backend.get_name() == "flashinfer": + attn_metadata.begin_forward() else: attn_metadata = None sampling_metadata = SamplingMetadata( @@ -788,6 +778,37 @@ def prepare_input_tensors( sampling_metadata, lora_requests, lora_mapping, multi_modal_kwargs) + def _create_flashinfer_wrapper(self, metadata_dict: Optional[Dict], + use_cuda_graph: bool, is_profile_run: bool, + batch_size: int): + if metadata_dict is None: + return + + if metadata_dict['num_decode_tokens']: + if use_cuda_graph and not is_profile_run: + self.flashinfer_decode_wrapper = self.graph_runners[ + batch_size].flashinfer_decode_wrapper + elif self.flashinfer_decode_wrapper is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + metadata_dict['decode_wrapper'] = self.flashinfer_decode_wrapper + + if metadata_dict['num_prefill_tokens']: + if self.flashinfer_prefill_wrapper is None: + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") + 
metadata_dict['prefill_wrapper'] = self.flashinfer_prefill_wrapper + @torch.inference_mode() def execute_model( self, @@ -1043,6 +1064,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: use_cuda_graph=True, decode_wrapper=decode_wrapper, prefill_wrapper=None) + attn_metadata.begin_forward() else: attn_metadata = self.attn_backend.make_metadata( num_prefills=0, From 8f685ddb6b03c309134ea92b904d1d9d5f63ff74 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 13:00:24 -0700 Subject: [PATCH 22/39] try CI --- .buildkite/test-pipeline.yaml | 3 +++ requirements-test.txt | 3 +++ vllm/worker/model_runner.py | 17 ++++++++++++----- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 5afe3730210e8..3e30345a58e7f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -18,6 +18,7 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py @@ -40,6 +41,8 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py diff --git a/requirements-test.txt b/requirements-test.txt index fef0ede7be0ff..69d314d78440b 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -21,3 +21,6 @@ aiohttp # quantization bitsandbytes==0.42.0 + +# Flashinfer +https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2427b3aca5592..4e0b1ce4b2e12 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1023,10 +1023,19 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: last_page_len_buffer = torch.empty(batch_size, dtype=torch.int32, device=self.device) + + num_qo_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + num_kv_heads = self.model_config.get_num_kv_heads( + self.parallel_config) + if num_qo_heads // num_kv_heads >= 4: + 
use_tensor_cores = False + else: + use_tensor_cores = False decode_wrapper = \ CUDAGraphBatchDecodeWithPagedKVCacheWrapper( decode_workspace_buffer, indptr_buffer, indices_buffer, - last_page_len_buffer, "NHD") + last_page_len_buffer, "NHD", use_tensor_cores) kv_cache_dtype = get_kv_cache_torch_dtype( self.kv_cache_dtype, self.model_config.dtype) @@ -1051,10 +1060,8 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: paged_kv_indices=paged_kv_indices_tensor_host, paged_kv_last_page_len= paged_kv_last_page_len_tensor_host, - num_qo_heads=self.model_config.get_num_attention_heads( - self.parallel_config), - num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, head_dim=self.model_config.get_head_size(), page_size=self.block_size, seq_start_loc=None, From 0f8e7a19d39d31747c9dd25e7002f19da98edcb2 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 13:00:46 -0700 Subject: [PATCH 23/39] minor --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4e0b1ce4b2e12..5317ad6f3d4c4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1029,7 +1029,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: num_kv_heads = self.model_config.get_num_kv_heads( self.parallel_config) if num_qo_heads // num_kv_heads >= 4: - use_tensor_cores = False + use_tensor_cores = True else: use_tensor_cores = False decode_wrapper = \ From cf275a144f49a69d80d7eaa84db82bf0bc65bfde Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 13:03:07 -0700 Subject: [PATCH 24/39] format --- vllm/worker/model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5317ad6f3d4c4..4ee2bdcc09c2e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1023,13 +1023,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: last_page_len_buffer = torch.empty(batch_size, dtype=torch.int32, device=self.device) - + num_qo_heads = self.model_config.get_num_attention_heads( - self.parallel_config) + self.parallel_config) num_kv_heads = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config) if num_qo_heads // num_kv_heads >= 4: - use_tensor_cores = True + use_tensor_cores = False else: use_tensor_cores = False decode_wrapper = \ From 0ab32ee959dfcb0f8e172a04b85c7aeca004c0b9 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 13:04:10 -0700 Subject: [PATCH 25/39] minor --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4ee2bdcc09c2e..db36ceabb0c32 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1029,7 +1029,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: num_kv_heads = self.model_config.get_num_kv_heads( self.parallel_config) if num_qo_heads // num_kv_heads >= 4: - use_tensor_cores = False + use_tensor_cores = True else: use_tensor_cores = False decode_wrapper = \ From c421f1fd6b81a40a1fad791b7672e8950c788649 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 13:23:58 -0700 Subject: [PATCH 26/39] try CI --- .buildkite/test-pipeline.yaml | 2 ++ requirements-test.txt | 5 +---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git 
a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3e30345a58e7f..8e995ec1ab4e9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -18,6 +18,7 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py @@ -41,6 +42,7 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py diff --git a/requirements-test.txt b/requirements-test.txt index 69d314d78440b..29df45de7a14a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -20,7 +20,4 @@ torchvision # required for the image processor of phi3v aiohttp # quantization -bitsandbytes==0.42.0 - -# Flashinfer -https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl \ No newline at end of file +bitsandbytes==0.42.0 \ No newline at end of file From 815efc2ffc9a81337b8502f3eb8bdebad49eab9c Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 14:42:00 -0700 Subject: [PATCH 27/39] flash attention dependency --- .buildkite/test-pipeline.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8e995ec1ab4e9..9d8a4e7e3425b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -18,6 +18,7 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - pip install vllm-flash-attn == 2.5.9 - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py @@ -42,6 +43,7 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s 
distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - pip install vllm-flash-attn == 2.5.9 - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py From 901b369a9b9a11f10d84d68f080dc80147eb78d6 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 14:58:25 -0700 Subject: [PATCH 28/39] minor --- .buildkite/test-pipeline.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9d8a4e7e3425b..5db47c67d4f46 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -18,7 +18,7 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - pip install vllm-flash-attn == 2.5.9 + - pip install vllm-flash-attn==2.5.9 - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py @@ -43,7 +43,7 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install vllm-flash-attn == 2.5.9 + - pip install vllm-flash-attn==2.5.9 - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py From b2d9895cb70bfb06822c6bc694298a1ece4b2a80 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 17:21:10 -0700 Subject: [PATCH 29/39] flash attn --- vllm/attention/backends/flashinfer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 75caa98c84d75..c2ec389e7aa36 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -2,11 +2,10 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type try: - from flash_attn import flash_attn_varlen_func + from vllm_flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper except ImportError: - flashinfer = None 
flash_attn_varlen_func = None BatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None From df16a6bf5df80027b1cb7413eb208dd21a179649 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 20 Jun 2024 17:29:52 -0700 Subject: [PATCH 30/39] format --- vllm/attention/backends/flashinfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index c2ec389e7aa36..08e2748a5ac24 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -2,9 +2,9 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type try: - from vllm_flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + from vllm_flash_attn import flash_attn_varlen_func except ImportError: flash_attn_varlen_func = None BatchDecodeWithPagedKVCacheWrapper = None From dc4e7efc15ed11cf45bf5c821931436f4f839182 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Fri, 21 Jun 2024 18:17:46 -0700 Subject: [PATCH 31/39] fix ci --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a1642e8203077..50b46ebbb21aa 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -757,7 +757,7 @@ def prepare_input_tensors( lora_requests = metadata_dict.pop("lora_requests") multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - use_cuda_graph = metadata_dict.pop('use_cuda_graph') + use_cuda_graph = metadata_dict['use_cuda_graph'] is_profile_run = metadata_dict.pop('is_profile_run') if self.attn_backend.get_name() == "flashinfer": self._create_flashinfer_wrapper(metadata_dict, use_cuda_graph, From aeb0df6507be7419dd55ec4a60426bff52efa44e Mon Sep 17 00:00:00 2001 From: lilyliu Date: Tue, 25 Jun 2024 14:43:57 -0700 Subject: [PATCH 32/39] use llama3-8b in test and add warning --- .buildkite/test-pipeline.yaml | 2 +- vllm/attention/selector.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b221efc64493b..15b8621546439 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -51,7 +51,7 @@ steps: - pip install vllm-flash-attn==2.5.9 - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 
618627aa10f62..562130036fccb 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -72,6 +72,9 @@ def get_attn_backend( return IpexAttnBackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") + logger.warning("Flashinfer will be stuck on llma-2-7b,", + "please avoid using Flashinfer as the", + "backend when running on llma-2-7b.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend elif backend == _Backend.PALLAS: From 8a72dcf770bdd71a42a55f5f2b8673beb88bd683 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Tue, 25 Jun 2024 15:01:56 -0700 Subject: [PATCH 33/39] fix --- vllm/attention/selector.py | 6 +++--- vllm/worker/model_runner.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 562130036fccb..7a989c6e1d3a5 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -72,9 +72,9 @@ def get_attn_backend( return IpexAttnBackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Flashinfer will be stuck on llma-2-7b,", - "please avoid using Flashinfer as the", - "backend when running on llma-2-7b.") + logger.warning(("Flashinfer will be stuck on llma-2-7b," + " please avoid using Flashinfer as the" + "backend when running on llma-2-7b.")) from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend elif backend == _Backend.PALLAS: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 866516d8387b5..fa31cba42c0d0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1025,6 +1025,9 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device) + indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
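This hunk, together with the one that follows, hoists indices_buffer out of the capture loop: of the CUDA-graph index buffers, only its size (num_gpu_blocks entries) is independent of the batch size being captured, while indptr_buffer and last_page_len_buffer still have to be allocated per batch size. A sketch of the resulting allocation pattern, with illustrative sizes:

    import torch

    device = "cuda"
    num_gpu_blocks = 1024                  # illustrative value
    batch_sizes_to_capture = (8, 4, 2, 1)

    # Allocated once: shape depends only on the total number of GPU blocks.
    indices_buffer = torch.empty(num_gpu_blocks, dtype=torch.int32, device=device)

    # Allocated per captured batch size: shapes depend on batch_size.
    per_batch_buffers = {
        bs: (torch.empty(bs + 1, dtype=torch.int32, device=device),  # indptr
             torch.empty(bs, dtype=torch.int32, device=device))      # last_page_len
        for bs in batch_sizes_to_capture
    }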
@@ -1033,10 +1036,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: indptr_buffer = torch.empty(batch_size + 1, dtype=torch.int32, device=self.device) - indices_buffer = torch.empty( - self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) last_page_len_buffer = torch.empty(batch_size, dtype=torch.int32, device=self.device) From aaddbad192cdcef95bed36714d424219794547a5 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Tue, 25 Jun 2024 20:21:32 -0700 Subject: [PATCH 34/39] remove amd tests --- .buildkite/test-pipeline.yaml | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 15b8621546439..257fb577117c9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -18,9 +18,6 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - pip install vllm-flash-attn==2.5.9 - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py @@ -48,10 +45,6 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install vllm-flash-attn==2.5.9 - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py @@ -212,3 +205,7 @@ steps: - pytest -v -s distributed/test_custom_all_reduce.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - pip install vllm-flash-attn==2.5.9 + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s 
distributed/test_basic_distributed_correctness.py + - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py \ No newline at end of file From e61bd38d5c64a778d561b73acc58f7b8d338d056 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 27 Jun 2024 01:03:33 -0700 Subject: [PATCH 35/39] fix --- vllm/worker/model_runner.py | 79 ++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 45 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 33ed44f2e1b5e..ba25dce7e3e9e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -216,12 +216,10 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - # Flashinfer fields - self.flashinfer_prefill_wrapper: Optional[ - BatchPrefillWithPagedKVCacheWrapper] = None - self.flashinfer_decode_wrapper: Optional[ - Union[BatchDecodeWithPagedKVCacheWrapper, - CUDAGraphBatchDecodeWithPagedKVCacheWrapper]] = None + self.flashinfer_decode_workspace_buffer = None + self.flashinfer_decode_wrapper = None + self.flashinfer_prefill_workspace_buffer = None + self.flashinfer_prefill_wrapper = None def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -758,30 +756,6 @@ def _prepare_model_input_tensors( multi_modal_kwargs=multi_modal_kwargs, ) - def _create_flashinfer_wrapper(self, use_cuda_graph: bool, - is_profile_run: bool, batch_size: int): - - if use_cuda_graph and not is_profile_run: - self.flashinfer_decode_wrapper = self.graph_runners[ - batch_size].flashinfer_decode_wrapper - elif self.flashinfer_decode_wrapper is None: - self.flashinfer_decode_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_decode_wrapper = \ - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_decode_workspace_buffer, "NHD") - - if self.flashinfer_prefill_wrapper is None: - self.flashinfer_prefill_workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=self.device) - self.flashinfer_prefill_wrapper = \ - BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_prefill_workspace_buffer, "NHD") - @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. 
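The hunk above replaces the eagerly typed wrapper fields with four attributes that start as `None`; the wrappers and their workspace buffers are then built lazily where they are first needed (see the `execute_model` hunk that follows), so nothing wrapper-shaped has to be constructed during input preparation or broadcast across tensor-parallel workers. A condensed sketch of that lazy-initialization pattern is shown below, reusing the constructor calls that appear elsewhere in this series; the class name, method name, and the 128 MiB workspace size are illustrative assumptions, not vLLM API.

```python
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

# Assumed to mirror the runner's constant (128 MiB workspace per wrapper).
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class FlashInferWrapperCache:
    """Illustrative holder for the four lazily created FlashInfer objects."""

    def __init__(self, device: torch.device):
        self.device = device
        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

    def ensure_created(self) -> None:
        # Allocate the workspace buffers and build both wrappers on first
        # use; later steps reuse the same objects.
        if self.flashinfer_decode_workspace_buffer is None:
            self.flashinfer_decode_workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8, device=self.device)
            self.flashinfer_decode_wrapper = \
                BatchDecodeWithPagedKVCacheWrapper(
                    self.flashinfer_decode_workspace_buffer, "NHD")
            self.flashinfer_prefill_workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8, device=self.device)
            self.flashinfer_prefill_wrapper = \
                BatchPrefillWithPagedKVCacheWrapper(
                    self.flashinfer_prefill_workspace_buffer, "NHD")
```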
@@ -1082,21 +1056,6 @@ def make_model_input_from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) - print("==================Create tensor==================") - if self.attn_backend.get_name() == "flashinfer": - attn_metadata = tensor_dict["attn_metadata"] - assert model_input.input_tokens is not None - assert model_input.attn_metadata is not None - batch_size = model_input.input_tokens.shape[0] - is_profile_run = _is_block_tables_empty( - model_input.attn_metadata.block_tables) - self._create_flashinfer_wrapper(attn_metadata.use_cuda_graph, - is_profile_run, batch_size) - print("===============Create wrapper===================") - attn_metadata.prefill_wrapper = self.flashinfer_prefill_wrapper - attn_metadata.decode_wrapper = self.flashinfer_decode_wrapper - attn_metadata.begin_forward() - return model_input def prepare_model_input( @@ -1141,6 +1100,36 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.attn_backend.get_name() == "flashinfer": + assert model_input.attn_metadata is not None + assert model_input.input_tokens is not None + if self.flashinfer_decode_workspace_buffer is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") + + model_input.attn_metadata.prefill_wrapper = \ + self.flashinfer_prefill_wrapper + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + model_input.attn_metadata.decode_wrapper = self.graph_runners[ + batch_size].flashinfer_decode_wrapper + else: + model_input.attn_metadata.decode_wrapper = \ + self.flashinfer_decode_wrapper + model_input.attn_metadata.begin_forward() + # Currently cuda graph is only supported by the decode phase. 
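With the wrappers now wired up inside `execute_model`, a quick end-to-end check is to select the backend through the same environment variable the CI commands use and run a short generation: the first step exercises `prefill_wrapper`, and the remaining steps exercise `decode_wrapper` (or the graph runner's wrapper when CUDA graphs are enabled). This sketch assumes a supported GPU with the flashinfer wheel installed; the model and prompt simply mirror the small correctness tests.

```python
import os

# Pick the backend the same way the CI commands in this series do.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# facebook/opt-125m is the small model used by the correctness tests.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```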
assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata From b2484dfc575690643967ae53f28c6c6ebeba0a0b Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 27 Jun 2024 14:30:10 -0700 Subject: [PATCH 36/39] minor --- .buildkite/test-pipeline.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fcacbde0b2777..88c38d16c7526 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -209,7 +209,6 @@ steps: - pytest -v -s distributed/test_custom_all_reduce.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install vllm-flash-attn==2.5.9 - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py \ No newline at end of file From 0f4f7966661083aa3da104fdfc20137f2db79865 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 27 Jun 2024 19:18:28 -0700 Subject: [PATCH 37/39] fix --- vllm/worker/model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ba25dce7e3e9e..f709462c3a52b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -916,9 +916,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device) - indices_buffer = torch.empty(self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. 
@@ -930,6 +927,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: last_page_len_buffer = torch.empty(batch_size, dtype=torch.int32, device=self.device) + indices_buffer = torch.empty( + batch_size * self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) num_qo_heads = self.model_config.get_num_attention_heads( self.parallel_config) From 3dca2f011c171bdbd85643ace247c7325acd62b0 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 27 Jun 2024 19:38:08 -0700 Subject: [PATCH 38/39] change buffer init --- vllm/worker/model_runner.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f709462c3a52b..c27c05fb96acc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -916,21 +916,24 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device) + indices_buffer = torch.empty(max_batch_size * + self.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.device) + indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.device) + last_page_len_buffer = torch.empty(max_batch_size, + dtype=torch.int32, + device=self.device) + with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): if self.attn_backend.get_name() == "flashinfer": - indptr_buffer = torch.empty(batch_size + 1, - dtype=torch.int32, - device=self.device) - last_page_len_buffer = torch.empty(batch_size, - dtype=torch.int32, - device=self.device) - indices_buffer = torch.empty( - batch_size * self.cache_config.num_gpu_blocks, - dtype=torch.int32, - device=self.device) + indptr_buffer = indptr_buffer[:batch_size + 1] + last_page_len_buffer = last_page_len_buffer[:batch_size] num_qo_heads = self.model_config.get_num_attention_heads( self.parallel_config) From 7853235c3953162520579ecbee42188f3b970f97 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 27 Jun 2024 23:40:29 -0700 Subject: [PATCH 39/39] fix ci --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 05969cfa5d65f..4af335d96d50f 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -6,7 +6,7 @@ numpy < 2.0.0 requests tqdm py-cpuinfo -transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. +transformers >= 4.40.0,<4.42.1 # Required for StarCoder2 & Llava, Llama 3. tokenizers >= 0.19.1 # Required for Llama 3. fastapi aiohttp
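Setting the CI and requirements fixes aside, the buffer handling that the "change buffer init" patch above settles on is worth spelling out: the indptr, indices, and last-page-length buffers are allocated once at the maximum capture batch size, and each captured graph receives views sliced to its own batch size, so all graphs share one persistent set of buffers rather than each batch size allocating its own `num_gpu_blocks`-sized set. A self-contained sketch of that scheme follows; the concrete sizes and the batch-size list are assumed placeholders.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
max_batch_size = 256      # assumed: largest entry of the capture list
num_gpu_blocks = 4096     # assumed: cache_config.num_gpu_blocks
batch_sizes_to_capture = [1, 2, 4, 8, 16, 32, 64, 128, 256]  # assumed

# One persistent set of FlashInfer paged-KV buffers, sized for the
# largest batch that will be captured.
indptr_buffer = torch.empty(max_batch_size + 1,
                            dtype=torch.int32, device=device)
indices_buffer = torch.empty(max_batch_size * num_gpu_blocks,
                             dtype=torch.int32, device=device)
last_page_len_buffer = torch.empty(max_batch_size,
                                   dtype=torch.int32, device=device)

for batch_size in reversed(batch_sizes_to_capture):
    # Each captured graph sees views into the shared buffers; a graph for
    # batch_size B only ever touches the first B (+1) entries.
    indptr = indptr_buffer[:batch_size + 1]
    last_page_len = last_page_len_buffer[:batch_size]
    print(batch_size, indptr.shape, last_page_len.shape,
          indices_buffer.shape)
```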