[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode #4628
Changes from 7 commits
First changed file (the FlashInfer attention backend):

@@ -1,10 +1,18 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type

-import flashinfer
+try:
+    import flashinfer
+    from flash_attn import flash_attn_varlen_func
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+except ImportError:
+    flashinfer = None
+    flash_attn_varlen_func = None
+    BatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from vllm_flash_attn import flash_attn_varlen_func

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
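Since flashinfer and the related symbols now degrade to None when the package is missing, a backend that is actually selected would otherwise only fail later with an opaque AttributeError. A minimal, hypothetical guard illustrating the idea (the helper name and message are not part of this PR):

```python
try:
    import flashinfer
except ImportError:
    flashinfer = None  # mirrors the optional import in the diff above


def _require_flashinfer() -> None:
    # Hypothetical helper, not part of this diff: fail fast with an
    # actionable message when the FlashInfer backend is selected but the
    # optional dependency is missing.
    if flashinfer is None:
        raise ImportError(
            "The FlashInfer backend requires the `flashinfer` package; "
            "install a wheel matching your CUDA/torch build, e.g. from "
            "https://flashinfer.ai/whl/cu121/torch2.3/.")
```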
@@ -62,11 +70,12 @@ class FlashInferMetadata(AttentionMetadata):

     use_cuda_graph: bool = False

+    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

-    # Metadata for the prefill stage since we still
-    # use flash attention for prefill.
+    # Metadata for the prefill stage
     seq_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: Optional[torch.Tensor] = None
     block_tables: Optional[torch.Tensor] = None

     # Metadata for the decode stage
@@ -109,11 +118,15 @@ def __post_init__(self):
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")

-        # When using flashinfer, we are also creating the FlashInferMetadata,
-        # which will also call post_init by default, here we want to skip the
-        # post_init if it's the prefill phase.
-        if self.num_prefills == 0:
-            assert self.num_decode_tokens > 0
+        if self.num_prefill_tokens > 0:
+            self.prefill_wrapper = \
+                flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+                    self.workspace_buffer, "NHD")
+            self.prefill_wrapper.begin_forward(
+                self.query_start_loc, self.paged_kv_indptr,
+                self.paged_kv_indices, self.paged_kv_last_page_len,
+                self.num_qo_heads, self.num_kv_heads, self.head_dim)
+        else:
             self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                 self.workspace_buffer, "NHD")
             self.decode_wrapper.begin_forward(
@@ -133,8 +146,9 @@ def asdict_zerocopy(self,
     ) -> Dict[str, Any]:
         if skip_fields is None:
             skip_fields = set()
-        # We need to skip the decode_wrapper field since it cannot be
+        # We need to skip the prefill/decode_wrapper field since it cannot be
         # broadcasted with nccl when TP is enabled.
+        skip_fields.add('prefill_wrapper')
         skip_fields.add('decode_wrapper')
         return super().asdict_zerocopy(skip_fields)
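The skip is needed because the wrapper objects hold CUDA workspace state that cannot be serialized and broadcast over NCCL to the other tensor-parallel ranks; each rank drops them here and rebuilds its own wrappers in __post_init__. A rough standalone sketch of the skip-fields pattern (hypothetical class, not vLLM's actual implementation):

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Set


@dataclass
class _MetadataSketch:
    # Plain fields survive the broadcast to other tensor-parallel ranks...
    num_decode_tokens: int = 0
    # ...but wrapper objects holding CUDA workspace state do not, so they
    # are stripped here and recreated locally on each rank.
    decode_wrapper: Optional[object] = None

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        skip_fields = skip_fields or set()
        skip_fields.add('decode_wrapper')
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name not in skip_fields
        }


print(_MetadataSketch(num_decode_tokens=8).asdict_zerocopy())
# -> {'num_decode_tokens': 8}
```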
@@ -168,6 +182,7 @@ def __init__(
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -217,10 +232,14 @@ def forward(
             self.kv_cache_dtype,
         )

+        query = query.contiguous(
+        )  # Flashinfer requires query to be contiguous
         if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            assert prefill_meta.block_tables is not None
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            # We will use flash attention for prefill
+            # when kv_cache is not provided.
+            # This happens when vllm runs the profiling to
+            # determine the number of blocks.
+            if kv_cache is None:
                 output = flash_attn_varlen_func(
                     q=query,
                     k=key,

Review thread on `query = query.contiguous(`:
Comment: I think we discussed this before, but what is the overhead of this call?
Reply: For llama7b on A100, the shape of query is [256, 12, 64], and this line takes ~0.037 ms. (A timing sketch follows this file's diff.)

Review thread on the `if kv_cache is None:` fallback to flash attention:
Comment: Why is that? Can you add a comment? (Also, is it fundamental?)
Reply: This only happens during the profiling phase, where the cache is initialized (not paged). We use flash attention for the profile run.

@@ -235,13 +254,14 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                 )
             else:
-                raise NotImplementedError(
-                    "Prefix caching is not supported with flashinfer yet.")
+                assert prefill_meta is not None
+                assert prefill_meta.prefill_wrapper is not None
+                output = prefill_meta.prefill_wrapper.forward(query,
+                                                              kv_cache,
+                                                              causal=True)
         else:
             assert attn_metadata.decode_metadata is not None
             assert attn_metadata.decode_metadata.decode_wrapper is not None
-            query = query.contiguous(
-            )  # Flashinfer requires query to be contiguous
             output = attn_metadata.decode_metadata.decode_wrapper.forward(
                 query,
                 kv_cache,
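The timing sketch referenced in the .contiguous() thread above. The shape and dtype match the numbers quoted in the reply; the warm-up and iteration counts are arbitrary, and exact results will vary by GPU and driver:

```python
import torch

# Time .contiguous() on a deliberately non-contiguous [256, 12, 64] query.
q = torch.randn(12, 256, 64, device="cuda", dtype=torch.float16)
q = q.transpose(0, 1)            # view with shape [256, 12, 64], not contiguous
assert not q.is_contiguous()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):              # warm-up
    q.contiguous()
torch.cuda.synchronize()
start.record()
for _ in range(100):
    q.contiguous()               # materializes a contiguous copy each call
end.record()
torch.cuda.synchronize()
print(f"contiguous(): {start.elapsed_time(end) / 100:.4f} ms per call")
```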
Second changed file (the model runner, `_prepare_model_input`):

@@ -369,21 +369,13 @@ def _prepare_model_input(
                     if curr_sliding_window_blocks is not None:
                         block_table = block_table[
                             -curr_sliding_window_blocks:]
-                    if self.attn_backend.get_name() == "flashinfer":
-                        paged_kv_indices.extend(block_table)
-                        paged_kv_indptr.append(paged_kv_indptr[-1] +
-                                               len(block_table))
-                        last_page_len = seq_data.get_len(
-                        ) % self.block_size
-                        if last_page_len == 0:
-                            last_page_len = self.block_size
-                        paged_kv_last_page_len.append(last_page_len)
                 else:
                     # Only happens when memory profiling runs.
                     block_table = []
            else:
                # Prefill without chunked prefill or memory profiling.
                block_table = []

            block_tables.append(block_table)

            seq_lens.append(sliding_seq_len)
@@ -460,6 +452,15 @@ def _prepare_model_input(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

+            if self.attn_backend.get_name() == "flashinfer":
+                paged_kv_indices.extend(block_table)
+                paged_kv_indptr.append(paged_kv_indptr[-1] +
+                                       len(block_table))
+                last_page_len = seq_len % self.block_size
+                if last_page_len == 0:
+                    last_page_len = self.block_size
+                paged_kv_last_page_len.append(last_page_len)
+
         batch_size = len(input_tokens)
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(prefill_seq_lens, default=0)

Review comment on `if self.attn_backend.get_name() == "flashinfer":`: Maybe add a comment saying this is to handle an incomplete block?
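A worked example of the paged KV metadata built above, with made-up block tables and sequence lengths (block_size = 16, matching the page_size=16 passed to the metadata below). paged_kv_indptr is a CSR-style offset array into paged_kv_indices, and paged_kv_last_page_len records how many slots of each sequence's final block are occupied, which is the incomplete-block case the review comment asks to document:

```python
block_size = 16
# Assumed inputs for three sequences: their block tables and total lengths.
block_tables = [[0, 3], [7], [2, 5, 9]]
seq_lens = [20, 16, 33]

paged_kv_indices = []        # flat list of block numbers for all sequences
paged_kv_indptr = [0]        # CSR-style offsets into paged_kv_indices
paged_kv_last_page_len = []  # filled slots in each sequence's last block

for block_table, seq_len in zip(block_tables, seq_lens):
    paged_kv_indices.extend(block_table)
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
    last_page_len = seq_len % block_size
    if last_page_len == 0:
        # The last block is exactly full, not empty.
        last_page_len = block_size
    paged_kv_last_page_len.append(last_page_len)

print(paged_kv_indices)        # [0, 3, 7, 2, 5, 9]
print(paged_kv_indptr)         # [0, 2, 3, 6]
print(paged_kv_last_page_len)  # [4, 16, 1]
```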
@@ -590,6 +591,7 @@ def _prepare_model_input(
                 head_dim=self.model_config.get_head_size(),
                 page_size=16,
                 seq_start_loc=seq_start_loc,
+                query_start_loc=query_start_loc,
                 data_type=kv_cache_dtype)
         else:
             attn_metadata = self.attn_backend.make_metadata(
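query_start_loc, added to the metadata here, and seq_start_loc are both exclusive prefix sums over per-request lengths: request i's tokens occupy [start_loc[i], start_loc[i+1]) in the packed token dimension. A small sketch with made-up lengths (illustrative only, not the model runner's actual code):

```python
import torch

# Illustrative lengths for three requests (values are made up).
query_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
seq_lens = torch.tensor([12, 3, 20], dtype=torch.int32)

# Exclusive prefix sums: entry i marks where request i starts in the
# packed token dimension; the final entry is the total token count.
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
torch.cumsum(query_lens, dim=0, dtype=torch.int32, out=query_start_loc[1:])
seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=torch.int32, out=seq_start_loc[1:])

print(query_start_loc)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(seq_start_loc)    # tensor([ 0, 12, 15, 35], dtype=torch.int32)
```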
Review thread (on a Dockerfile change not shown above):
Comment: Is this a test-specific Dockerfile, or is it provided to users?
Reply: We currently cannot add flashinfer to vLLM's requirements.txt because it doesn't support the extra index URL (https://flashinfer.ai/whl/cu121/torch2.3/). Until FlashInfer is available on PyPI, users have to install it themselves.