[FlashInfer] Upgrade to 0.2.0 #11194

Merged · 55 commits merged on Jan 27, 2025

Changes from 10 commits

Commits (55)
269f965
[misc] remove deprecated call to `end_forward` in flashinfer backend
abmfy Dec 14, 2024
8c375a3
[flashinfer] upgrade to flashinfer 0.2.0
abmfy Dec 20, 2024
a62b854
[style] fix yapf check
abmfy Dec 20, 2024
b37ff55
[FlashInfer] Pass inferred global hyperparameters to `plan`
abmfy Dec 31, 2024
72bdf7e
[FlashInfer] Cache inferred global hyperparameters
abmfy Dec 31, 2024
97dcedc
[Misc] Use `typing.Optional` for Python 3.9 compatibility
abmfy Dec 31, 2024
56798c5
[Style] Fix lint errors
abmfy Dec 31, 2024
706a6f6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 22, 2025
dacb6af
[FlashInfer] Cache global hyperparameters in AttentionMetadataBuilder…
abmfy Jan 22, 2025
06fa7cc
[Style] Fix ruff
abmfy Jan 22, 2025
bc480b0
[FlashInfer] Get per layer params from vllm config
abmfy Jan 23, 2025
5a70aac
[FlashInfer] Store vllm config in attention state
abmfy Jan 23, 2025
e0397e9
[CI] Update FlashInfer version
abmfy Jan 23, 2025
ec49257
format
youkaichao Jan 23, 2025
500ff5b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 24, 2025
bde6807
[Misc] Add space in assert message
abmfy Jan 24, 2025
69d7c8d
[FlashInfer] Warn on models with interleaved attention
abmfy Jan 24, 2025
d4d63dc
[Test] Change backend to flash_attn for gemma in compile tests
abmfy Jan 24, 2025
6e7e933
fix inconsistent vllm config
youkaichao Jan 25, 2025
0b47067
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 25, 2025
f6e33a7
[Test] Skip tests for Gemma2 with FlashInfer backend
abmfy Jan 25, 2025
847a4d6
[CI] Build FlashInfer from source
abmfy Jan 25, 2025
5b0fe64
[CI] Fix FlashInfer build command
abmfy Jan 25, 2025
69445cd
[CI] Fix Dockerfile
abmfy Jan 25, 2025
963aff7
[CI] Fix FlashInfer AOT build in Dockerfile
abmfy Jan 25, 2025
ae9da66
fix flashinfer docker build
youkaichao Jan 26, 2025
afa377c
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 26, 2025
269e1eb
fix build command
youkaichao Jan 26, 2025
2e50ab8
move command
youkaichao Jan 26, 2025
0fe979d
unify to use setup.py
youkaichao Jan 26, 2025
3dd209c
fix cd
youkaichao Jan 26, 2025
bcd04fd
fix recursive clone
youkaichao Jan 26, 2025
bb44221
comment
youkaichao Jan 26, 2025
5ca67ae
[CI] Use precompiled FlashInfer AOT wheel
abmfy Jan 26, 2025
3c89bfb
[CI] Temporarily switch to CUDA develop image for vllm-base
abmfy Jan 26, 2025
293fdd6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
5d8ad22
also install jit build dependency
youkaichao Jan 26, 2025
4d57ef9
[FlashInfer] Fix type of k_scale and v_scale
abmfy Jan 26, 2025
33ff07b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
ef15977
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 26, 2025
21efc67
fix reshape_and_cache_flash
youkaichao Jan 27, 2025
a6b6fe8
use new flashinfer
youkaichao Jan 27, 2025
1f13235
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 27, 2025
f17dbc3
update v1 tests
youkaichao Jan 27, 2025
506b641
refactor test
youkaichao Jan 27, 2025
2e476a2
revert
youkaichao Jan 27, 2025
95b5493
add comments
youkaichao Jan 27, 2025
55b55d3
only check compile when loading
youkaichao Jan 27, 2025
1f80aee
test in ci?
youkaichao Jan 27, 2025
5be3783
fix one test
youkaichao Jan 27, 2025
071a68e
fix test_flashinfer_prefill_with_paged_kv
youkaichao Jan 27, 2025
0e0f57f
relax test for prefill
youkaichao Jan 27, 2025
2134e77
fix test_flashinfer_prefill_with_paged_fp8_kv
youkaichao Jan 27, 2025
8e42297
relax test for prefill
youkaichao Jan 27, 2025
b4a7992
fix test_flashinfer_decode_with_paged_fp8_kv
youkaichao Jan 27, 2025
173 changes: 152 additions & 21 deletions vllm/attention/backends/flashinfer.py
@@ -1,3 +1,4 @@
import dataclasses
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -13,12 +14,15 @@
from vllm.vllm_flash_attn import flash_attn_varlen_func
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
# Avoid turning these types into variables during type checking
if not TYPE_CHECKING:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch
from torch import nn

import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -30,6 +34,7 @@
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
@@ -99,6 +104,71 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")


@dataclass
class GlobalHyperparameters:
'''
Currently, the FlashInfer backend only supports models in which all layers
share the same values for the following hyperparameters.
'''
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float


def infer_global_hyperparameters(model: nn.Module) -> GlobalHyperparameters:
Reviewer comment (Member): this function could collect all per-layer parameters and only assert that the results are the same. (A sketch of that refactor follows the function below.)

"""
Scan all attention layers in the model and determine some hyperparameters
to use during `plan`.

Currently, the FlashInfer backend only supports models in which all layers
share the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
"""

if getattr(model, "global_hyperparameters", None) is not None:
return model.global_hyperparameters

params_inferred = False
global_window_left: Optional[int] = None
global_logits_soft_cap: Optional[float] = None
global_sm_scale: Optional[float] = None

for module in model.modules():
if isinstance(module, Attention):
impl = module.impl
assert isinstance(impl, FlashInferImpl)

# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale

if params_inferred:
MSG_PREFIX = "All attention layers must share the same "
if global_window_left != window_left:
raise ValueError(MSG_PREFIX + "`window_left`.")
if global_logits_soft_cap != logits_soft_cap:
raise ValueError(MSG_PREFIX + "`logits_soft_cap`.")
if global_sm_scale != sm_scale:
raise ValueError(MSG_PREFIX + "`sm_scale`.")

params_inferred = True
global_window_left = window_left
global_logits_soft_cap = logits_soft_cap
global_sm_scale = sm_scale

assert params_inferred
assert global_window_left is not None
assert global_sm_scale is not None

model.global_hyperparameters = GlobalHyperparameters(
global_window_left, global_logits_soft_cap, global_sm_scale)
return model.global_hyperparameters
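
A minimal sketch of the refactor suggested in the review comment above: gather the per-layer values first, then assert in one place that they all agree. The helper names (`PerLayerParameters`, `collect_per_layer_parameters`) are hypothetical; only `Attention`, `impl.sliding_window`, `impl.logits_soft_cap`, and `impl.scale` come from the code in this diff.

```python
# Hypothetical sketch of the reviewer's suggestion: collect every layer's
# parameters, then assert that they are all identical.
from dataclasses import dataclass
from typing import List, Optional

from torch import nn

from vllm.attention.layer import Attention


@dataclass
class PerLayerParameters:
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


def collect_per_layer_parameters(model: nn.Module) -> List[PerLayerParameters]:
    params = []
    for module in model.modules():
        if isinstance(module, Attention):
            impl = module.impl
            window_size = impl.sliding_window
            params.append(
                PerLayerParameters(
                    window_left=window_size[0]
                    if window_size is not None else -1,
                    logits_soft_cap=impl.logits_soft_cap,
                    sm_scale=impl.scale,
                ))
    return params


def infer_global_hyperparameters_v2(model: nn.Module) -> PerLayerParameters:
    params = collect_per_layer_parameters(model)
    assert params, "No Attention layers found in the model."
    global_params = params[0]
    # FlashInfer currently requires all layers to share these values.
    assert all(p == global_params for p in params), (
        "FlashInfer backend requires all attention layers to share the same "
        "window_left, logits_soft_cap and sm_scale.")
    return global_params
```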


class FlashInferState(AttentionState):

def __init__(self, runner):
@@ -215,6 +285,8 @@ def graph_capture_get_metadata_for_batch(
batch_size + 1,
dtype=torch.int32)

global_params = infer_global_hyperparameters(self.runner.model)

attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
@@ -237,7 +309,9 @@
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
prefill_wrapper=None,
**dataclasses.asdict(global_params),
)
attn_metadata.begin_forward()
return attn_metadata
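
As an aside, the `**dataclasses.asdict(global_params)` splat above expands the dataclass fields into keyword arguments of `make_metadata`. A tiny self-contained illustration of that pattern (the dataclass and target function here are stand-ins, not the vLLM API):

```python
import dataclasses


@dataclasses.dataclass
class Params:
    window_left: int
    logits_soft_cap: float
    sm_scale: float


def make_metadata(**kwargs):
    # Stand-in for the backend's make_metadata; just echoes its kwargs.
    return kwargs


params = Params(window_left=-1, logits_soft_cap=30.0, sm_scale=0.125)
metadata = make_metadata(use_cuda_graph=True, **dataclasses.asdict(params))
# metadata == {'use_cuda_graph': True, 'window_left': -1,
#              'logits_soft_cap': 30.0, 'sm_scale': 0.125}
```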

@@ -324,9 +398,28 @@ class FlashInferMetadata(AttentionMetadata):
data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
# FlashInfer 0.2 encourages passing host tensors
device: torch.device = torch.device("cpu")
is_profile_run: bool = False

# The FlashInfer backend currently supports only models in which all layers
# share the same values for the following hyperparameters:

# The left (inclusive) window size for the attention window. When set
# to `-1`, the window size will be the full length of the sequence.
# Defaults to `-1`.
window_left: int = -1
# The attention logits soft-capping value (used in Gemini, Grok,
# Gemma-2, etc.). If not provided, it will be set to `0`. If greater
# than 0, the logits will be capped according to the formula:
# $$\texttt{logits\_soft\_cap} \times
# \mathrm{tanh}(x / \texttt{logits\_soft\_cap}),$$
# where $x$ is the input logits.
logits_soft_cap: Optional[float] = None
# The scale used in softmax; if not provided, it will be set to
# `1.0 / sqrt(head_dim)`.
sm_scale: Optional[float] = None

def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
@@ -362,14 +455,21 @@ def begin_forward(self):
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.prefill_wrapper.plan(
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.data_type)
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
@@ -385,8 +485,7 @@ def begin_forward(self):
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.decode_wrapper.plan(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
Expand All @@ -396,8 +495,11 @@ def begin_forward(self):
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
# kv-cache data type.
data_type=self.data_type,
kv_data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)
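
For context, the hunks above migrate from FlashInfer 0.1.x's `end_forward()`/`begin_forward()` calls to the 0.2.0 `plan()` API, and push the global hyperparameters (`window_left`, `logits_soft_cap`, `sm_scale`) into `plan()` rather than into each forward call. A minimal sketch of the resulting call pattern, with `plan()`/`run()` arguments mirroring the diff (the function wrapper and dtype choices are illustrative assumptions, not vLLM code):

```python
# Sketch of the FlashInfer 0.2 plan/run pattern used above. The plan()/run()
# arguments mirror the diff; everything else is an illustrative assumption.
import torch
from flashinfer import BatchPrefillWithPagedKVCacheWrapper


def run_prefill(wrapper: BatchPrefillWithPagedKVCacheWrapper,
                query: torch.Tensor,
                kv_cache: torch.Tensor,
                query_start_loc: torch.Tensor,
                paged_kv_indptr: torch.Tensor,
                paged_kv_indices: torch.Tensor,
                paged_kv_last_page_len: torch.Tensor,
                num_qo_heads: int, num_kv_heads: int,
                head_dim: int, page_size: int,
                window_left: int, logits_soft_cap: float, sm_scale: float,
                k_scale: float, v_scale: float) -> torch.Tensor:
    # 0.2.0: the per-batch metadata and the global hyperparameters go into
    # plan(), which replaces end_forward() + begin_forward() from 0.1.x.
    wrapper.plan(
        query_start_loc,
        paged_kv_indptr,
        paged_kv_indices,
        paged_kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        causal=True,
        sm_scale=sm_scale,
        window_left=window_left,
        logits_soft_cap=logits_soft_cap,
        q_data_type=query.dtype,
        kv_data_type=kv_cache.dtype,
    )
    # run() replaces forward(); the per-call arguments shrink to the
    # KV-cache scaling factors.
    return wrapper.run(query, kv_cache, k_scale=k_scale, v_scale=v_scale)
```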

@@ -495,6 +597,8 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
Reviewer comment (Member): you can remember the vllm_config here by calling get_current_vllm_config(). (A sketch of this approach appears after prepare() below.)

self.global_hyperparameters: Optional[GlobalHyperparameters] = None

def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
@@ -527,6 +631,20 @@ def prepare(self):
self.total_blocks = 0
self.is_profile_run: bool = False

if self.global_hyperparameters is None:
# Infer global hyperparameters, since currently we only support
# models in which all layers share the same values for the
# following hyperparameters:
# - `window_left`
# - `logits_soft_cap`
# - `sm_scale`
model = self.runner.model
Reviewer comment (Member): vllm_config.compilation_config.static_forward_context is a dict of layer prefix to attention layer; you can collect the sliding window, etc. from there, with no need to iterate over the model's submodules. (See the sketch after prepare() below.)

inferred_params = infer_global_hyperparameters(model)
self.global_hyperparameters = inferred_params
self.window_left = inferred_params.window_left
self.logits_soft_cap = inferred_params.logits_soft_cap
self.sm_scale = inferred_params.sm_scale
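
A minimal sketch of the approach suggested in the two review comments above (and adopted in the later "[FlashInfer] Get per layer params from vllm config" and "[FlashInfer] Store vllm config in attention state" commits): remember the config via `get_current_vllm_config()` and read per-layer attention parameters from `compilation_config.static_forward_context` instead of walking the model's submodules. The helper name is hypothetical, the exact attribute layout is assumed from the reviewers' description, and `GlobalHyperparameters` is the dataclass defined earlier in this file.

```python
# Hypothetical helper: derive the shared hyperparameters from the vLLM config
# instead of iterating over the model's modules.
from vllm.config import VllmConfig, get_current_vllm_config


def infer_global_hyperparameters_from_config(
        vllm_config: VllmConfig) -> GlobalHyperparameters:
    # Assumed: maps layer prefix -> Attention layer, populated at model build.
    layers = vllm_config.compilation_config.static_forward_context
    params = []
    for layer in layers.values():
        impl = layer.impl
        window_size = impl.sliding_window
        params.append(
            GlobalHyperparameters(
                window_left=window_size[0] if window_size is not None else -1,
                logits_soft_cap=impl.logits_soft_cap,
                sm_scale=impl.scale,
            ))
    assert params and all(p == params[0] for p in params), (
        "FlashInfer requires all attention layers to share the same "
        "window_left, logits_soft_cap and sm_scale.")
    return params[0]


# In the builder's __init__, the config could then be remembered once, e.g.:
#     self.vllm_config = get_current_vllm_config()
```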

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
@@ -754,7 +872,11 @@ def build(self, seq_lens: List[int], query_lens: List[int],
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)
is_profile_run=self.is_profile_run,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
sm_scale=self.sm_scale,
)


class FlashInferImpl(AttentionImpl):
Expand Down Expand Up @@ -883,25 +1005,34 @@ def forward(
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(

assert prefill_meta.prefill_wrapper._causal
assert prefill_meta.prefill_wrapper._window_left == window_left
assert prefill_meta.prefill_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale

prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(

assert decode_meta.decode_wrapper._window_left == window_left
assert decode_meta.decode_wrapper._logits_soft_cap == (
logits_soft_cap or 0.0)
assert decode_meta.decode_wrapper._sm_scale == softmax_scale

decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
)

if prefill_output is None and decode_output is not None:
# Decode only batch.