[FlashInfer] Upgrade to 0.2.0 #11194
Merged
+257 −75
Commits (55)
269f965 [misc] remove deprecated call to `end_forward` in flashinfer backend (abmfy)
8c375a3 [flashinfer] upgrade to flashinfer 0.2.0 (abmfy)
a62b854 [style] fix yapf check (abmfy)
b37ff55 [FlashInfer] Pass infered global hyperparameters to `plan` (abmfy)
72bdf7e [FlashInfer] Cache inferred global hyperparameters (abmfy)
97dcedc [Misc] Use `typing.Optional` for Python 3.9 compatability (abmfy)
56798c5 [Style] Fix lint errors (abmfy)
706a6f6 Merge branch 'main' into flashinfer-0.2 (abmfy)
dacb6af [FlashInfer] Cache global hyperparameters in AttentionMetadataBuilder… (abmfy)
06fa7cc [Style] Fix ruff (abmfy)
bc480b0 [FlashInfer] Get per layer params from vllm config (abmfy)
5a70aac [FlashInfer] Store vllm config in attention state (abmfy)
e0397e9 [CI] Update FlashInfer version (abmfy)
ec49257 format (youkaichao)
500ff5b Merge branch 'main' into flashinfer-0.2 (abmfy)
bde6807 [Misc] Add space in assert message (abmfy)
69d7c8d [FlashInfer] Warn on models with interleaved attention (abmfy)
d4d63dc [Test] Change backend to flash_attn for gemma in compile tests (abmfy)
6e7e933 fix inconsistent vllm config (youkaichao)
0b47067 Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int… (abmfy)
f6e33a7 [Test] Skip tests for Gemma2 with FlashInfer backend (abmfy)
847a4d6 [CI] Build FlashInfer from source (abmfy)
5b0fe64 [CI] Fix FlashInfer build command (abmfy)
69445cd [CI] Fix Dockerfile (abmfy)
963aff7 [CI] Fix FlashInfer AOT build in Dockerfile (abmfy)
ae9da66 fix flashinfer docker build (youkaichao)
afa377c Merge branch 'main' into flashinfer-0.2 (youkaichao)
269e1eb fix build command (youkaichao)
2e50ab8 move command (youkaichao)
0fe979d unify to use setup.py (youkaichao)
3dd209c fix cd (youkaichao)
bcd04fd fix recursive clone (youkaichao)
bb44221 comment (youkaichao)
5ca67ae [CI] Use precompiled FlashInfer AOT wheel (abmfy)
3c89bfb [CI] Temporarily switch to CUDA develop image for vllm-base (abmfy)
293fdd6 Merge branch 'main' into flashinfer-0.2 (abmfy)
5d8ad22 also install jit build dependency (youkaichao)
4d57ef9 [FlashInfer] Fix type of k_scale and v_scale (abmfy)
33ff07b Merge branch 'main' into flashinfer-0.2 (abmfy)
ef15977 Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int… (abmfy)
21efc67 fix reshape_and_cache_flash (youkaichao)
a6b6fe8 use new flashinfer (youkaichao)
1f13235 Merge branch 'main' into flashinfer-0.2 (youkaichao)
f17dbc3 update v1 tests (youkaichao)
506b641 refactor test (youkaichao)
2e476a2 revert (youkaichao)
95b5493 add comments (youkaichao)
55b55d3 only check compile when loading (youkaichao)
1f80aee test in ci? (youkaichao)
5be3783 fix one test (youkaichao)
071a68e fix test_flashinfer_prefill_with_paged_kv (youkaichao)
0e0f57f relax test for prefill (youkaichao)
2134e77 fix test_flashinfer_prefill_with_paged_fp8_kv (youkaichao)
8e42297 relax test for prefill (youkaichao)
b4a7992 fix test_flashinfer_decode_with_paged_fp8_kv (youkaichao)
Changes from 1 commit:

First file:

```diff
@@ -22,7 +22,6 @@
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

 import torch
-from torch import nn

 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -36,6 +35,7 @@
     is_block_tables_empty)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -105,68 +105,69 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:


 @dataclass
-class GlobalHyperparameters:
-    '''
+class PerLayerParameters:
+    """
     Currently, FlashInfer backend only support models in which all layers share
     the same values for the following hyperparameters.
-    '''
+    """

     window_left: int
     logits_soft_cap: Optional[float]
     sm_scale: float


-def infer_global_hyperparameters(model: nn.Module) -> GlobalHyperparameters:
+def get_per_layer_parameters(
+        vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
     """
-    Scan all attention layers in the model and determine some hyperparameters
+    Scan all attention layers and determine some hyperparameters
     to use during `plan`.
     """

-    if getattr(model, "global_hyperparameters", None) is not None:
-        return model.global_hyperparameters
-
-    params_inferred = False
-    global_window_left: Optional[int] = None
-    global_logits_soft_cap: Optional[float] = None
-    global_sm_scale: Optional[float] = None
-
-    for module in model.modules():
-        if isinstance(module, Attention):
-            impl = module.impl
-            assert isinstance(impl, FlashInferImpl)
-
-            # Infer hyperparameters from the attention layer
-            window_size = impl.sliding_window
-            window_left = window_size[0] if window_size is not None else -1
-            logits_soft_cap = impl.logits_soft_cap
-            sm_scale = impl.scale
-
-            if params_inferred:
-                MSG_PREFIX = "All attention layers must share the same "
-                if global_window_left != window_left:
-                    raise ValueError(MSG_PREFIX + "`window_left`.")
-                if global_logits_soft_cap != logits_soft_cap:
-                    raise ValueError(MSG_PREFIX + "`logits_soft_cap`.")
-                if global_sm_scale != sm_scale:
-                    raise ValueError(MSG_PREFIX + "`sm_scale`.")
-
-            params_inferred = True
-            global_window_left = window_left
-            global_logits_soft_cap = logits_soft_cap
-            global_sm_scale = sm_scale
-
-    assert params_inferred
-    assert global_window_left is not None
-    assert global_sm_scale is not None
-
-    model.global_hyperparameters = GlobalHyperparameters(
-        global_window_left, global_logits_soft_cap, global_sm_scale)
-    return model.global_hyperparameters
+    layers = vllm_config.compilation_config.static_forward_context
+    per_layer_params: Dict[str, PerLayerParameters] = {}
+
+    for key, layer in layers.items():
+        assert isinstance(layer, Attention)
+
+        impl = layer.impl
+        assert isinstance(impl, FlashInferImpl)
+
+        # Infer hyperparameters from the attention layer
+        window_size = impl.sliding_window
+        window_left = window_size[0] if window_size is not None else -1
+        logits_soft_cap = impl.logits_soft_cap
+        sm_scale = impl.scale
+
+        per_layer_params[key] = PerLayerParameters(window_left,
+                                                   logits_soft_cap, sm_scale)
+
+    return per_layer_params
+
+
+def infer_global_hyperparameters(
+        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
+    """
+    Currently, FlashInfer backend only support models in which all layers share
+    the same values for the following hyperparameters:
+    - `window_left`
+    - `logits_soft_cap`
+    - `sm_scale`
+
+    So this function asserts that all layers share the same values for these
+    hyperparameters and returns the global values.
+    """
+
+    assert len(per_layer_params) > 0, "No attention layers found in the model."
+
+    param_sets = list(per_layer_params.values())
+    global_params = param_sets[0]
+    for params in param_sets:
+        assert params == global_params, (
+            "FlashInfer backend currently only supports models in which all"
+            "layers share the same values for the following hyperparameters:"
+            "`window_left`, `logits_soft_cap`, `sm_scale`.")
+
+    return global_params


 class FlashInferState(AttentionState):
```
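To make the behavior of this refactor concrete, here is a minimal, self-contained sketch (stand-in code, not vLLM's actual module): per-layer parameters are collected into a dict keyed by layer name, and dataclass equality is what enforces that every attention layer reports identical `window_left`, `logits_soft_cap`, and `sm_scale`, replacing the earlier field-by-field `ValueError` checks with a single assertion.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class PerLayerParameters:
    # Mirrors the dataclass introduced in this diff.
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


def infer_global_hyperparameters(
        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
    # Same logic as the diff: dataclass equality checks that all layers
    # share identical (window_left, logits_soft_cap, sm_scale).
    assert len(per_layer_params) > 0, "No attention layers found in the model."
    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]
    for params in param_sets:
        assert params == global_params, (
            "FlashInfer backend currently only supports models in which all "
            "layers share the same values for the following hyperparameters: "
            "`window_left`, `logits_soft_cap`, `sm_scale`.")
    return global_params


# Uniform layers collapse to one global parameter set.
uniform = {
    "layers.0.attn": PerLayerParameters(-1, None, 0.125),
    "layers.1.attn": PerLayerParameters(-1, None, 0.125),
}
print(infer_global_hyperparameters(uniform))

# A model whose layers disagree (e.g. interleaved sliding-window attention
# with a different `window_left` per layer) trips the assertion instead.
mixed = dict(uniform, **{"layers.2.attn": PerLayerParameters(4096, None, 0.125)})
try:
    infer_global_hyperparameters(mixed)
except AssertionError as exc:
    print("rejected:", exc)
```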
```diff
@@ -178,6 +179,9 @@ def __init__(self, runner):
         self._decode_wrapper = None
         self._prefill_wrapper = None

+        # Global hyperparameters shared by all attention layers
+        self.global_hyperparameters: Optional[PerLayerParameters] = None
+
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             self._workspace_buffer = torch.empty(
@@ -285,7 +289,8 @@ def graph_capture_get_metadata_for_batch(
             batch_size + 1,
             dtype=torch.int32)

-        global_params = infer_global_hyperparameters(self.runner.model)
+        global_params = infer_global_hyperparameters(
+            get_per_layer_parameters(get_current_vllm_config()))

         attn_metadata = self.runner.attn_backend.make_metadata(
             num_prefills=0,
```
```diff
@@ -597,7 +602,10 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.sliding_window = input_builder.sliding_window
         self.block_size = input_builder.block_size

-        self.global_hyperparameters: Optional[GlobalHyperparameters] = None
+        # Global hyperparameters shared by all attention layers
+        self.global_hyperparameters: Optional[PerLayerParameters] = None
+
+        self.vllm_config = get_current_vllm_config()

     def prepare(self):
         self.slot_mapping: List[int] = []
```

Review comment: you can remember the vllm_config here by calling …
```diff
@@ -638,8 +646,8 @@ def prepare(self):
         # - `window_left`
         # - `logits_soft_cap`
         # - `sm_scale`
-        model = self.runner.model
-        inferred_params = infer_global_hyperparameters(model)
+        inferred_params = infer_global_hyperparameters(
+            get_per_layer_parameters(self.vllm_config))
         self.global_hyperparameters = inferred_params
         self.window_left = inferred_params.window_left
         self.logits_soft_cap = inferred_params.logits_soft_cap
```
Second file:

```diff
@@ -20,7 +20,7 @@
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
```
```diff
@@ -1498,11 +1498,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                 ) if get_tensor_model_parallel_rank() == 0 else
                 self.vllm_config.compilation_config.capture_sizes)
             for batch_size in capture_sizes:
-                attn_metadata = (
-                    self.attn_state.graph_capture_get_metadata_for_batch(
-                        batch_size,
-                        is_encoder_decoder_model=self.model_config.
-                        is_encoder_decoder))
+                with set_current_vllm_config(self.vllm_config):
+                    # To make vLLM config available during
+                    # worker initialization
+                    attn_metadata = (self.attn_state.
+                                     graph_capture_get_metadata_for_batch(
+                                         batch_size,
+                                         is_encoder_decoder_model=self.
+                                         model_config.is_encoder_decoder,
+                                     ))

                 if self.lora_config:
                     lora_mapping = LoRAMapping(
```

Review comment: then we don't need this change.
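The `with set_current_vllm_config(self.vllm_config)` wrapper exists so that `get_current_vllm_config()` calls made further down the stack (for example in `get_per_layer_parameters` during graph capture) can see the runner's config without it being threaded through every signature. Below is a rough, self-contained sketch of that context-manager pattern; the stub names (`VllmConfigStub` and the simplified `set_current_vllm_config` / `get_current_vllm_config`) are illustrative assumptions, not vLLM's real `vllm.config` implementation.

```python
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class VllmConfigStub:
    """Stand-in for vllm.config.VllmConfig (assumption, not the real class)."""
    static_forward_context: Dict[str, object] = field(default_factory=dict)


_current_config: Optional[VllmConfigStub] = None


@contextmanager
def set_current_vllm_config(config: VllmConfigStub):
    # Temporarily install `config` as the ambient config, then restore it.
    global _current_config
    previous, _current_config = _current_config, config
    try:
        yield
    finally:
        _current_config = previous


def get_current_vllm_config() -> VllmConfigStub:
    assert _current_config is not None, "no vLLM config is currently set"
    return _current_config


def graph_capture_get_metadata_for_batch(batch_size: int):
    # Deep inside the call stack, the config is fetched from ambient state
    # instead of being passed explicitly, mirroring the diff above.
    config = get_current_vllm_config()
    return batch_size, len(config.static_forward_context)


if __name__ == "__main__":
    cfg = VllmConfigStub(static_forward_context={"layers.0.attn": object()})
    with set_current_vllm_config(cfg):
        print(graph_capture_get_metadata_for_batch(8))  # -> (8, 1)
```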
Review comment: remember the vllm_config here?