
[FlashInfer] Upgrade to 0.2.0 #11194

Status: Merged. 55 commits merged into main from the flashinfer-0.2 branch on Jan 27, 2025.

Commits (55)
269f965
[misc] remove deprecated call to `end_forward` in flashinfer backend
abmfy Dec 14, 2024
8c375a3
[flashinfer] upgrade to flashinfer 0.2.0
abmfy Dec 20, 2024
a62b854
[style] fix yapf check
abmfy Dec 20, 2024
b37ff55
[FlashInfer] Pass inferred global hyperparameters to `plan`
abmfy Dec 31, 2024
72bdf7e
[FlashInfer] Cache inferred global hyperparameters
abmfy Dec 31, 2024
97dcedc
[Misc] Use `typing.Optional` for Python 3.9 compatibility
abmfy Dec 31, 2024
56798c5
[Style] Fix lint errors
abmfy Dec 31, 2024
706a6f6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 22, 2025
dacb6af
[FlashInfer] Cache global hyperparameters in AttentionMetadataBuilder…
abmfy Jan 22, 2025
06fa7cc
[Style] Fix ruff
abmfy Jan 22, 2025
bc480b0
[FlashInfer] Get per layer params from vllm config
abmfy Jan 23, 2025
5a70aac
[FlashInfer] Store vllm config in attention state
abmfy Jan 23, 2025
e0397e9
[CI] Update FlashInfer version
abmfy Jan 23, 2025
ec49257
format
youkaichao Jan 23, 2025
500ff5b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 24, 2025
bde6807
[Misc] Add space in assert message
abmfy Jan 24, 2025
69d7c8d
[FlashInfer] Warn on models with interleaved attention
abmfy Jan 24, 2025
d4d63dc
[Test] Change backend to flash_attn for gemma in compile tests
abmfy Jan 24, 2025
6e7e933
fix inconsistent vllm config
youkaichao Jan 25, 2025
0b47067
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 25, 2025
f6e33a7
[Test] Skip tests for Gemma2 with FlashInfer backend
abmfy Jan 25, 2025
847a4d6
[CI] Build FlashInfer from source
abmfy Jan 25, 2025
5b0fe64
[CI] Fix FlashInfer build command
abmfy Jan 25, 2025
69445cd
[CI] Fix Dockerfile
abmfy Jan 25, 2025
963aff7
[CI] Fix FlashInfer AOT build in Dockerfile
abmfy Jan 25, 2025
ae9da66
fix flashinfer docker build
youkaichao Jan 26, 2025
afa377c
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 26, 2025
269e1eb
fix build command
youkaichao Jan 26, 2025
2e50ab8
move command
youkaichao Jan 26, 2025
0fe979d
unify to use setup.py
youkaichao Jan 26, 2025
3dd209c
fix cd
youkaichao Jan 26, 2025
bcd04fd
fix recursive clone
youkaichao Jan 26, 2025
bb44221
comment
youkaichao Jan 26, 2025
5ca67ae
[CI] Use precompiled FlashInfer AOT wheel
abmfy Jan 26, 2025
3c89bfb
[CI] Temporarily switch to CUDA develop image for vllm-base
abmfy Jan 26, 2025
293fdd6
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
5d8ad22
also install jit build dependency
youkaichao Jan 26, 2025
4d57ef9
[FlashInfer] Fix type of k_scale and v_scale
abmfy Jan 26, 2025
33ff07b
Merge branch 'main' into flashinfer-0.2
abmfy Jan 26, 2025
ef15977
Merge branch 'flashinfer-0.2' of github.com:abmfy/vllm-flashinfer int…
abmfy Jan 26, 2025
21efc67
fix reshape_and_cache_flash
youkaichao Jan 27, 2025
a6b6fe8
use new flashinfer
youkaichao Jan 27, 2025
1f13235
Merge branch 'main' into flashinfer-0.2
youkaichao Jan 27, 2025
f17dbc3
update v1 tests
youkaichao Jan 27, 2025
506b641
refactor test
youkaichao Jan 27, 2025
2e476a2
revert
youkaichao Jan 27, 2025
95b5493
add comments
youkaichao Jan 27, 2025
55b55d3
only check compile when loading
youkaichao Jan 27, 2025
1f80aee
test in ci?
youkaichao Jan 27, 2025
5be3783
fix one test
youkaichao Jan 27, 2025
071a68e
fix test_flashinfer_prefill_with_paged_kv
youkaichao Jan 27, 2025
0e0f57f
relax test for prefill
youkaichao Jan 27, 2025
2134e77
fix test_flashinfer_prefill_with_paged_fp8_kv
youkaichao Jan 27, 2025
8e42297
relax test for prefill
youkaichao Jan 27, 2025
b4a7992
fix test_flashinfer_decode_with_paged_fp8_kv
youkaichao Jan 27, 2025
104 changes: 56 additions & 48 deletions vllm/attention/backends/flashinfer.py
@@ -22,7 +22,6 @@
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch
from torch import nn

import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -36,6 +35,7 @@
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)

@@ -105,68 +105,69 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:


@dataclass
class GlobalHyperparameters:
'''
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
'''
"""

window_left: int
logits_soft_cap: Optional[float]
sm_scale: float


def infer_global_hyperparameters(model: nn.Module) -> GlobalHyperparameters:
def get_per_layer_parameters(
vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]:
"""
Scan all attention layers in the model and determine some hyperparameters
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""

layers = vllm_config.compilation_config.static_forward_context
per_layer_params: Dict[str, PerLayerParameters] = {}

for key, layer in layers.items():
assert isinstance(layer, Attention)

impl = layer.impl
assert isinstance(impl, FlashInferImpl)

# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale

per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)

return per_layer_params


def infer_global_hyperparameters(
per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
"""

if getattr(model, "global_hyperparameters", None) is not None:
return model.global_hyperparameters

params_inferred = False
global_window_left: Optional[int] = None
global_logits_soft_cap: Optional[float] = None
global_sm_scale: Optional[float] = None

for module in model.modules():
if isinstance(module, Attention):
impl = module.impl
assert isinstance(impl, FlashInferImpl)

# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale

if params_inferred:
MSG_PREFIX = "All attention layers must share the same "
if global_window_left != window_left:
raise ValueError(MSG_PREFIX + "`window_left`.")
if global_logits_soft_cap != logits_soft_cap:
raise ValueError(MSG_PREFIX + "`logits_soft_cap`.")
if global_sm_scale != sm_scale:
raise ValueError(MSG_PREFIX + "`sm_scale`.")
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""

params_inferred = True
global_window_left = window_left
global_logits_soft_cap = logits_soft_cap
global_sm_scale = sm_scale
assert len(per_layer_params) > 0, "No attention layers found in the model."

assert params_inferred
assert global_window_left is not None
assert global_sm_scale is not None
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all"
"layers share the same values for the following hyperparameters:"
"`window_left`, `logits_soft_cap`, `sm_scale`.")

model.global_hyperparameters = GlobalHyperparameters(
global_window_left, global_logits_soft_cap, global_sm_scale)
return model.global_hyperparameters
return global_params


class FlashInferState(AttentionState):
@@ -178,6 +179,9 @@ def __init__(self, runner):
self._decode_wrapper = None
self._prefill_wrapper = None

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
Review comment: remember the vllm_config here?


def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
@@ -285,7 +289,8 @@ def graph_capture_get_metadata_for_batch(
batch_size + 1,
dtype=torch.int32)

global_params = infer_global_hyperparameters(self.runner.model)
global_params = infer_global_hyperparameters(
get_per_layer_parameters(get_current_vllm_config()))

attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
@@ -597,7 +602,10 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
Review comment: you can remember the vllm_config here by calling get_current_vllm_config()


self.global_hyperparameters: Optional[GlobalHyperparameters] = None
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = get_current_vllm_config()

def prepare(self):
self.slot_mapping: List[int] = []
@@ -638,8 +646,8 @@ def prepare(self):
# - `window_left`
# - `logits_soft_cap`
# - `sm_scale`
model = self.runner.model
inferred_params = infer_global_hyperparameters(model)
inferred_params = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
self.global_hyperparameters = inferred_params
self.window_left = inferred_params.window_left
self.logits_soft_cap = inferred_params.logits_soft_cap
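The central change in this file is that hyperparameter discovery no longer walks `model.modules()` and caches the result on the model; instead, `get_per_layer_parameters` reads the attention layers registered in `vllm_config.compilation_config.static_forward_context`, and `infer_global_hyperparameters` asserts that every layer agrees before the shared values are handed to FlashInfer's `plan`. A minimal, self-contained sketch of that flow follows; `FakeAttentionImpl` and the hard-coded layer dict are hypothetical stand-ins for vLLM's `Attention`/`FlashInferImpl` objects, not the real API.

from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class PerLayerParameters:
    """Hyperparameters FlashInfer's `plan` needs from each attention layer."""
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float


@dataclass
class FakeAttentionImpl:
    # Hypothetical stand-in for vLLM's FlashInferImpl.
    sliding_window: Optional[Tuple[int, int]]
    logits_soft_cap: Optional[float]
    scale: float


def get_per_layer_parameters(
        layers: Dict[str, FakeAttentionImpl]) -> Dict[str, PerLayerParameters]:
    """Collect the relevant hyperparameters from every attention layer."""
    per_layer_params: Dict[str, PerLayerParameters] = {}
    for key, impl in layers.items():
        window_size = impl.sliding_window
        window_left = window_size[0] if window_size is not None else -1
        per_layer_params[key] = PerLayerParameters(window_left,
                                                   impl.logits_soft_cap,
                                                   impl.scale)
    return per_layer_params


def infer_global_hyperparameters(
        per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters:
    """Assert that all layers share the same values and return them."""
    assert len(per_layer_params) > 0, "No attention layers found."
    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]
    for params in param_sets:
        assert params == global_params, (
            "All layers must share `window_left`, `logits_soft_cap`, "
            "and `sm_scale`.")
    return global_params


if __name__ == "__main__":
    layers = {
        "model.layers.0.attn": FakeAttentionImpl(None, None, 0.125),
        "model.layers.1.attn": FakeAttentionImpl(None, None, 0.125),
    }
    print(infer_global_hyperparameters(get_per_layer_parameters(layers)))

Compared with the previous approach, this avoids stashing a `global_hyperparameters` attribute on the model object and keeps the per-layer values available for the equality check.
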
16 changes: 10 additions & 6 deletions vllm/worker/model_runner.py
@@ -20,7 +20,7 @@
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
@@ -1498,11 +1498,15 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
) if get_tensor_model_parallel_rank() == 0 else
self.vllm_config.compilation_config.capture_sizes)
for batch_size in capture_sizes:
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,
is_encoder_decoder_model=self.model_config.
is_encoder_decoder))
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during
# worker initialization
attn_metadata = (self.attn_state.
graph_capture_get_metadata_for_batch(
batch_size,
is_encoder_decoder_model=self.
model_config.is_encoder_decoder,
))
Review comment: then we don't need this change.


if self.lora_config:
lora_mapping = LoRAMapping(
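The graph-capture loop is now wrapped in `set_current_vllm_config(self.vllm_config)` so that code reached during capture, such as `get_per_layer_parameters(get_current_vllm_config())` in the FlashInfer backend above, can locate the config without it being threaded through every call signature. Below is a rough, self-contained sketch of how such a "current config" helper pair can be built with `contextvars`; it illustrates the pattern only, and `FakeVllmConfig`, `set_current_config`, and `get_current_config` are assumed names, not vLLM's actual implementation.

import contextlib
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Dict, Iterator, Optional


@dataclass
class FakeVllmConfig:
    # Hypothetical stand-in for vllm.config.VllmConfig.
    static_forward_context: Dict[str, object] = field(default_factory=dict)


_current_config: ContextVar[Optional[FakeVllmConfig]] = ContextVar(
    "current_config", default=None)


@contextlib.contextmanager
def set_current_config(config: FakeVllmConfig) -> Iterator[None]:
    """Make `config` visible to any code running inside the `with` block."""
    token = _current_config.set(config)
    try:
        yield
    finally:
        _current_config.reset(token)


def get_current_config() -> FakeVllmConfig:
    config = _current_config.get()
    assert config is not None, "No config has been set in this context."
    return config


if __name__ == "__main__":
    cfg = FakeVllmConfig()
    with set_current_config(cfg):
        # Deep inside graph capture, code can recover the config without
        # it being passed down explicitly.
        assert get_current_config() is cfg

Using a `ContextVar` rather than a plain module-level global keeps the value scoped to the current thread or async context and restores the previous value on exit.
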
8 changes: 5 additions & 3 deletions vllm/worker/worker_base.py
@@ -8,7 +8,7 @@
import torch
import torch.nn as nn

from vllm.config import ObservabilityConfig, VllmConfig
from vllm.config import ObservabilityConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -546,8 +546,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
self.worker = worker_class(**kwargs)
assert self.worker is not None
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None

def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
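`worker_base.py` applies the same idea at worker construction: because `worker_class(**kwargs)` now runs inside `set_current_vllm_config(self.vllm_config)`, anything built during the worker's `__init__` (including the attention state changed above) can call `get_current_vllm_config()`. The compact sketch below shows why the wrapping matters at construction time; `_cfg`, `set_current_config`, and `FakeWorker` are illustrative names following the same hypothetical contextvar pattern as the previous sketch, not vLLM's own classes.

import contextlib
from contextvars import ContextVar
from typing import Any, Iterator, Optional

# Hypothetical ambient-config slot, mirroring the pattern sketched above.
_cfg: ContextVar[Optional[Any]] = ContextVar("cfg", default=None)


@contextlib.contextmanager
def set_current_config(config: Any) -> Iterator[None]:
    token = _cfg.set(config)
    try:
        yield
    finally:
        _cfg.reset(token)


class FakeWorker:
    """Hypothetical worker whose __init__ reads the ambient config."""

    def __init__(self) -> None:
        self.config = _cfg.get()
        assert self.config is not None, (
            "worker must be constructed inside set_current_config()")


if __name__ == "__main__":
    config = {"model": "dummy"}
    with set_current_config(config):
        worker = FakeWorker()  # __init__ can see the config here
    assert worker.config is config

If the worker were constructed outside the `with` block, the lookup in `__init__` would fail, which is exactly the situation the added wrapping avoids.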