diff --git a/csrc/cache.h b/csrc/cache.h index eedad9fafa3c0..55ed30bd8ce48 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -28,6 +28,11 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale); +void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, + torch::Tensor& kv_cache, torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + torch::Tensor& scale); + // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 21a0aec0ececc..23a46b6ed8ad8 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -245,6 +245,51 @@ __global__ void reshape_and_cache_flash_kernel( } } } + +template +__global__ void concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, + int src_stride, int dst_stride, int size, int offset) { + for (int i = threadIdx.x; i < size; i += blockDim.x) { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = block_idx * block_stride + + block_offset * (kv_lora_rank + pe_dim) + i + + offset; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = + fp8::scaled_convert(src[src_idx], *scale); + } + } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + } // namespace vllm // KV_T is the stored data type of kv-cache. @@ -343,6 +388,56 @@ void reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH); } +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, kv_c_stride, \ + k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + +void concat_and_cache_mla( + torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_tokens, pe_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, torch::Tensor& scale) { + // NOTE(woosuk): In vLLM V1, key.size(0) can be different from + // slot_mapping.size(0) because of padding for CUDA graphs. 
+ // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because + // both include padding. + // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0) + // since key includes padding for CUDA graphs, while slot_mapping does not. + // In this case, slot_mapping.size(0) represents the actual number of tokens + // before padding. + // For compatibility with both cases, we use slot_mapping.size(0) as the + // number of tokens. + int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); +} + namespace vllm { template diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ec63170d511f0..1846d9ac29943 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -463,6 +463,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); + // Concat kv_c and k_pe and cache them. + cache_ops.def( + "concat_and_cache_mla(Tensor kv_c, Tensor k_pe," + " Tensor! kv_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " Tensor scale) -> ()"); + cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla); + // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! 
dst_cache, Tensor src_cache, float scale, " diff --git a/tests/kernels/test_triton_decode_attention.py b/tests/kernels/test_triton_decode_attention.py new file mode 100644 index 0000000000000..14f5a3b770b69 --- /dev/null +++ b/tests/kernels/test_triton_decode_attention.py @@ -0,0 +1,89 @@ +import pytest +import torch + +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd + + +def cdiv(a, b): + return (a + b - 1) // b + + +@pytest.mark.parametrize("B", [3, 5]) +@pytest.mark.parametrize("L", [1027, 1025]) +@pytest.mark.parametrize("H_Q", [32]) +@pytest.mark.parametrize("H_KV", [32, 8]) +@pytest.mark.parametrize("D_QK", [128, 192, 576]) +@pytest.mark.parametrize("D_V", [128, 512]) +@pytest.mark.parametrize("CACHE_SIZE", [16384]) +@pytest.mark.parametrize("PAGE_SIZE", [1, 16]) +def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): + assert CACHE_SIZE % PAGE_SIZE == 0 + dtype = torch.bfloat16 + seq_len = L # This represents the number of tokens already in the sequence + sm_scale = 1.0 / (D_QK**0.5) + num_kv_splits = 8 + + num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) + req_to_page = torch.randint(0, + CACHE_SIZE // PAGE_SIZE, + (B, num_pages_per_batch, 1), + device="cuda") + req_to_token = req_to_page * PAGE_SIZE + req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) + req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view( + 1, 1, -1) + req_to_token = req_to_token.view(B, -1) + req_to_token = req_to_token[:, :seq_len].contiguous() + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda") + + # k_buffer and v_buffer represent all previous tokens + # Page size is 1. + k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda") + v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda") + + # o will have the same shape as q + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + + b_seq_len = torch.full((B, ), seq_len, device="cuda") + + attn_logits = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + device="cuda", + ) + + # Call the original implementation. + decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + ) + + # Page size can be larger than 1. 
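+    # The paged view below is only a reshape of the flat buffers: the token at
+    # flat index `page * PAGE_SIZE + offset` becomes `buffer[page, offset]`, so
+    # both calls read identical K/V and their outputs should match.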
+ k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK) + v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) + + o1 = torch.zeros_like(o) + + decode_attention_fwd( + q, + k_buffer, + v_buffer, + o1, + req_to_page, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + PAGE_SIZE, + ) + + assert torch.allclose(o, o1) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index 84721d5971ccf..d7c6bdd707eb7 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -5,5 +5,5 @@ class DummyPlatform(CudaPlatform): device_name = "DummyDevice" def get_attn_backend_cls(self, backend_name, head_size, dtype, - kv_cache_dtype, block_size, use_v1): + kv_cache_dtype, block_size, use_v1, use_mla): return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 272206d4502e9..1b797074096ed 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -20,7 +20,7 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main -compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main +#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90 compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90 awq, casperhansen/mixtral-instruct-awq, main diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh index 693128640e07d..8a899bc154f35 100755 --- a/tests/weight_loading/run_model_weight_loading_test.sh +++ b/tests/weight_loading/run_model_weight_loading_test.sh @@ -3,7 +3,7 @@ SUCCESS=0 while getopts "c:" OPT; do case ${OPT} in - c ) + c ) CONFIG="$OPTARG" ;; \? 
) @@ -18,9 +18,14 @@ IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG" for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" do + if [[ $MODEL_CONFIG == \#* ]]; then + echo "=== SKIPPING MODEL: $MODEL_CONFIG ===" + continue + fi + LOCAL_SUCCESS=0 IFS=', ' read -r -a array <<< "$MODEL_CONFIG" - + echo "=== RUNNING MODEL: $MODEL_CONFIG ===" export QUANTIZATION=${array[0]} diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4a11b0206e003..fd94134de0219 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1002,6 +1002,19 @@ def reshape_and_cache_flash( v_scale) +def concat_and_cache_mla( + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, + slot_mapping, kv_cache_dtype, + scale) + + def copy_blocks(key_caches: List[torch.Tensor], value_caches: List[torch.Tensor], block_mapping: torch.Tensor) -> None: diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 8027a52b82ffc..b9425f659f7c0 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -276,3 +276,19 @@ def forward( output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError + + +class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + hidden_states_or_cq: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/backends/mla/__init__.py b/vllm/attention/backends/mla/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py new file mode 100644 index 0000000000000..c6c8a6034e20f --- /dev/null +++ b/vllm/attention/backends/mla/utils.py @@ -0,0 +1,365 @@ +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generic, List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import (AttentionLayer, + AttentionMetadata, + MLAAttentionImpl, T) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.vllm_flash_attn import flash_attn_varlen_func + + +@dataclass +class MLACommonMetadata(AttentionMetadata): + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + Common class for implementing repeated parts + + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + Deepseek's MLA attention works the following way: + * Use a single latent vector to represent the entire KV cache. + * The attention "simulates" a multi-head attention, while the compute is + similar to multi-query attention. 
* The dataflow is as follows, + + * B: batch/sequence length + * H: hidden size + * N: number of attention heads + * Lq: latent dimension for Q + * Lkv: latent dimension for K/V + * P: nope dimension, P+R is the actual head_dim in common attention. + * R: rope dimension, this slice of the head_dim goes through rope. + * V: V head dim. + * kv_c: latent/compressed KV + * q_c: latent/compressed Q + + # + # Outside the MLA attention backend + # + + 1. The hidden states (B, H) are projected down into q_c (B, Lq) and + kv_c_k_pe (B, Lkv+R). + 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). q_c + and kv_c are normalized. + + # + # Inside the MLA attention backend + # + + * if prefill: + + 3. The q_c is then projected up into the multi-head version. + * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope + (B, N, P) and q_pe (B, N, R). + 4. q_pe, k_pe are then passed through rotary embeddings. + 5. kv_c and k_pe are concatenated and inserted into the cache. + 6. The kv_c is then projected up into the multi-head version. + * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope + dimensions for K and V, which is split into k_nope (B, N, P) + and v (B, N, V). + 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from + q_nope, q_pe, k_nope, k_pe. + 8. Attention is computed with q, k, v. + 9. The attention computation returns (B, N, V), which is projected back + to (B, H) using out projection. + + * if decode: + + 3. Here's the change: we do not perform the full up-projection for + q_c, and there is no up-projection at all for kv_c. This is + achieved by the technique of "weight absorption". The paper says + "Fortunately, due to the associative law of matrix multiplication, + we can absorb WUK into WUQ, and WUV into WO" + * The q up projection turns (B, Lq) into (B, N, (P+R)), we split it + into W_UQ (Lq, N, P) and W_QR (Lq, N, R). + * The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split + it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). + * The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H). + * We can precompute the product of W_UQ and W_UK into + W_UQ_UK (Lq, N, Lkv), which is possible due to the QK^T operation in + attention. + * We can precompute the product of W_UV and W_O into + W_UV_O (N, Lkv, H), which is possible due to V@O as the + "epilogue" of attention. + 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. + 5. q_pe, k_pe are then passed through rotary embeddings. + 6. kv_c and k_pe are concatenated and inserted into the cache. + 7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape + (B, N, Lkv). + 8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, + kv_c, k_pe. v (B, Lkv) is exactly the same vector as kv_c. + 9. The attention is computed with q, k, v. Note that we just performed + an MQA attention with (Lkv+R) as our head dim. + 10. The KV cache is updated using the new entries k (B, N, (Lkv+R)), + which include the v and rope values. + 11. The attention computation returns (B, N, Lkv), which is projected + back to (B, H) using W_UV_O. + + From @tsu-bin's calculation, we only want to use the absorption technique + for decode. The prefill algorithm should still use the up-projected MHA + for fewer FLOPs and less memory usage.
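As a sanity check on decode step 3, the absorption identity can be exercised in isolation. The sketch below uses made-up toy dimensions and mirrors the einsum spellings used by process_weights_after_loading and _q_proj_and_k_up_proj below; it is an illustration, not part of this diff:

    import torch

    # Hypothetical toy dimensions, chosen only for illustration.
    B, Lq, Lkv, N, P = 2, 64, 32, 4, 16
    q_c = torch.randn(B, Lq, dtype=torch.float64)       # latent/compressed q
    W_UQ = torch.randn(Lq, N, P, dtype=torch.float64)   # q up-projection (nope part)
    W_UK = torch.randn(Lkv, N, P, dtype=torch.float64)  # k up-projection (nope part)

    # Unabsorbed path: up-project q_c, then contract with W_UK per head.
    q_nope = torch.einsum("bq,qnp->bnp", q_c, W_UQ)
    ref = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)

    # Absorbed path: fold W_UK into W_UQ once, then apply the fused weight
    # directly to the latent q.
    W_UQ_UK = torch.einsum("qnp,lnp->qnl", W_UQ, W_UK)  # (Lq, N, Lkv)
    out = torch.einsum("bq,qnl->bnl", q_c, W_UQ_UK)

    torch.testing.assert_close(ref, out)

The V/O side folds the same way: W_UV and W_O are contracted once into W_UV_O, so the attention output, still in the latent space, is projected straight back to the hidden size.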
+ + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + rotary_emb: RotaryEmbedding, + # q_proj should be q_b_proj if q_lora_rank is not None, but from an + # attention backend perspective we rely on the layer to pass in the + # correct matrix + q_proj: ColumnParallelLinear, + kv_b_proj: ColumnParallelLinear, + o_proj: RowParallelLinear, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + + self.rotary_emb = rotary_emb + self.q_proj = q_proj + self.kv_b_proj = kv_b_proj + self.o_proj = o_proj + + def _v_up_proj_and_o_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + return self.o_proj_absorbed( + x.reshape(-1, self.num_heads * self.kv_lora_rank))[0] + else: + x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) + return self.o_proj(x.reshape(-1, + self.num_heads * self.v_head_dim))[0] + + def _q_proj_and_k_up_proj(self, x): + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + return torch.matmul(x, self.W_Q_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + else: + x = torch.matmul(x, self.W_Q)\ + .view(-1, self.num_heads, self.qk_nope_head_dim) + return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ + .view(-1, self.num_heads, self.kv_lora_rank) + + def process_weights_after_loading(self): + kv_b_proj_weight = self.kv_b_proj.weight.T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_proj = self.q_proj.weight.T\ + .view(-1, self.num_heads, self.qk_head_dim) + + # can be W_Q or W_UQ depending q_lora_rank, the former if + # q_lora_rank is None, the latter otherwise. From the Attention backend + # perspective though we call these both W_Q and rely on the layer + # to pass in the correct matrix + W_Q = q_proj[..., :self.qk_nope_head_dim] + self.W_QR = q_proj[..., self.qk_nope_head_dim:]\ + .flatten(start_dim=1).contiguous() + + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + # + # Perform matrix-absorption following + # https://github.com/flashinfer-ai/flashinfer/pull/551 + # for decode, as a result we end up with absorbed weights for decode + # and another copy of raw weights for prefill. 
+ # + self.W_UK, self.W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK + # depending q_lora_rank, the former if q_lora_rank is None, the + # latter otherwise + # basically if q_lora_rank is none we are absorbing into q_proj + # instead of UQ + self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ + .flatten(start_dim=1).contiguous() + + W_O = self.o_proj.weight\ + .view(-1, self.num_heads, self.v_head_dim) + self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ + .flatten(start_dim=0, end_dim=1).contiguous() + + tp_size = get_tensor_model_parallel_world_size() + self.o_proj_absorbed = RowParallelLinear( + self.W_UV_O.shape[0] * tp_size, + self.W_UV_O.shape[1], + bias=False, + # TODO(lucas) figure out how to properly forward quant_method + #quant_config=self.o_proj.quant_method, + ) + + self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T) + else: + self.W_UV = W_UV + self.W_UK = W_UK + self.W_Q = W_Q.flatten(start_dim=1) + + @abstractmethod + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + is_decode = attn_metadata.decode_metadata is not None + is_prefill = attn_metadata.prefill_metadata is not None + + if (is_decode and is_prefill): + raise NotImplementedError( + "chunked prefill is not supported for MLAImplBase") + + # Restore head dim (for rotary embedding) + k_pe = k_pe.unsqueeze(1) + assert hasattr(attn_metadata, "input_positions") + + if is_decode: + q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) + q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ + .view(-1, self.num_heads, self.qk_rope_head_dim) + q_pe, k_pe = \ + self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe) + else: + assert is_prefill + q = self.q_proj(hidden_states_or_q_c)[0]\ + .view(-1, self.num_heads, self.qk_head_dim) + + # TODO(lucas): there must be a nicer way to write this line + q[..., self.qk_nope_head_dim:], k_pe = \ + self.rotary_emb( + attn_metadata.input_positions, + q[..., self.qk_nope_head_dim:], k_pe) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if attn_metadata.prefill_metadata is not None: + return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata) + + if attn_metadata.decode_metadata is not None: + return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) + + # Optional common flash-attn based prefill + def _forward_prefill_flash( + self, + q: torch.Tensor, + k_c_normed: torch.Tensor, + k_pe: torch.Tensor, + seq_start_loc: torch.Tensor, + max_prefill_seq_len: int, + ) -> torch.Tensor: + + kv_nope = 
self.kv_b_proj(k_c_normed)[0]\ + .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + attn_output = flash_attn_varlen_func( + q=q, + k=k, + v=v_padded, + cu_seqlens_q=seq_start_loc, + cu_seqlens_k=seq_start_loc, + max_seqlen_q=max_prefill_seq_len, + max_seqlen_k=max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + ) + attn_output = attn_output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ + .reshape(-1, self.num_heads * v.shape[-1]) + + return self.o_proj(attn_output)[0] diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py new file mode 100644 index 0000000000000..da09bb70b4f1a --- /dev/null +++ b/vllm/attention/backends/triton_mla.py @@ -0,0 +1,749 @@ +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeMlaWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class TritonMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_impl_cls() -> Type["TritonMLAImpl"]: + return TritonMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TritonMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]: + return TritonMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["TritonMLAState"]: + return TritonMLAState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + kv_lora_rank: int, # passed via head_size + ) -> Tuple[int, ...]: + # TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank + k_pe_size = kv_lora_rank // 8 + return (num_blocks, block_size, kv_lora_rank + k_pe_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + 
PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [512] + + +class TritonMLAState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + input_positions=self._positions[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + "input_positions": attn_metadata.decode_metadata.input_positions, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_positions = attn_metadata.input_positions + num_positions = input_positions.shape[0] + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + # CUDA graph buffer is padded so only perform a partial copy based on + # num_positions + input_buffers["input_positions"][:num_positions].copy_( + input_positions, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + return + + +@dataclass +class TritonMLAMetadata(MLACommonMetadata): + """Metadata for TritonMLAMetadata. 
+ + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
+ seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None + _cached_decode_metadata: Optional["TritonMLAMetadata"] = None + + num_prefill_tokens: int + + num_kv_splits: int = 4 # TODO(lucas) add heuristic + attn_logits: Optional[torch.Tensor] = None + req_idx: Optional[torch.Tensor] = None + + # The dimension of the attention heads + head_dim: Optional[int] = None + + def __post_init__(self): + supported_head_sizes = TritonMLABackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + @property + def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + input_positions = (None if self.input_positions is None else + self.input_positions[:self.num_prefill_tokens]) + + self._cached_prefill_metadata = TritonMLAMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + input_positions=input_positions, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + head_dim=self.head_dim) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["TritonMLAMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + input_positions = (None if self.input_positions is None else + self.input_positions[self.num_prefill_tokens:]) + + self._cached_decode_metadata = TritonMLAMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + input_positions=input_positions, + head_dim=self.head_dim) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.input_positions: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
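As a made-up illustration of item 3: with block_size = 16 and a block table of [7, 2, 9], the token at position 35 of the sequence sits in physical block block_table[35 // 16] == 9 at offset 35 % 16 == 3, so compute_slot_mapping records the flat slot 9 * 16 + 3 == 147 for it.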
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block, input_positions) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + inter_data.input_positions): + self.input_positions.extend(input_positions) + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + input_positions = async_tensor_h2d(self.input_positions, torch.long, + device, self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + num_kv_splits = 8 + + return TritonMLAMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + input_positions=input_positions, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + num_kv_splits=num_kv_splits, + head_dim=self.runner.model_config.get_head_size(), + ) + + +class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **kwargs) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + 
**kwargs) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + attn_metadata: TritonMLAMetadata, + ) -> torch.Tensor: + assert isinstance(attn_metadata, TritonMLAMetadata) + return self._forward_prefill_flash(q, kv_c_normed, k_pe, + attn_metadata.seq_start_loc, + attn_metadata.max_prefill_seq_len) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: TritonMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 Triton MLA not yet supported") + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + # TODO(lucas) Allocate ahead of time + attn_logits = torch.empty( + ( + B, + self.num_heads, + attn_metadata.num_kv_splits, + # NOTE(lucas) idk why the +1 is here but sglang has it so we + # just mirror that + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device=q.device, + ) + + # Add a head dim of 1 + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + PAGE_SIZE = kv_c_and_k_pe_cache.size(1) + + # Run MQA + decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, attn_logits, + attn_metadata.num_kv_splits, self.scale, + PAGE_SIZE) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 84fe89b7df360..7f2fe7e831064 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -289,7 +289,9 @@ def __init__(self, runner: "ModelRunnerBase"): @contextmanager def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_slot_mapping = torch.full((max_batch_size, ), PAD_SLOT_ID, dtype=torch.long, @@ -299,7 +301,9 @@ def graph_capture(self, max_batch_size: int): device=self.runner.device) self._graph_block_tables = torch.from_numpy( self.runner.graph_block_tables).to(device=self.runner.device) + yield + self._is_graph_capturing = False del self._graph_slot_mapping del self._graph_seq_lens diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 962c45a65ae23..9b804a29a485d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -41,8 +41,10 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, + **extra_impl_args, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -101,13 +103,18 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. 
dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, - block_size, is_attention_free, - blocksparse_params is not None) + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + blocksparse_params is not None, + use_mla=use_mla) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type) + blocksparse_params, logits_soft_cap, attn_type, + **extra_impl_args) self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads @@ -193,6 +200,10 @@ def extra_repr(self) -> str: s += f", backend={self.impl.__class__.__name__}" return s + def process_weights_after_loading(self): + if hasattr(self.impl, "process_weights_after_loading"): + self.impl.process_weights_after_loading() + class MultiHeadAttention(nn.Module): """Multi-headed attention without any cache, used for ViT.""" diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py new file mode 100644 index 0000000000000..675df109b6c0e --- /dev/null +++ b/vllm/attention/ops/triton_decode_attention.py @@ -0,0 +1,667 @@ +# Adapted from +# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +# which was originally adapted from +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +# Changes: +# - Add support for page size >= 1. + +# Copyright 2025 vLLM Team +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for decoding. +It supports page size >= 1. +""" + +import logging + +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +is_hip_ = current_platform.is_rocm() + +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance +# and accuracy. 
+logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored.") + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[None, :]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 64 + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = q.shape[0], q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + num_warps = 4 if kv_group_num == 1 else 2 + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + 
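+    # tl.arange requires power-of-two extents, so the head dims are rounded up
+    # to the next power of two here; the padded lanes are masked out inside the
+    # kernel via mask_d / mask_dv.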
BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-2), + k_buffer.stride(-1), + v_buffer.stride(-2), + v_buffer.stride(-1), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ + None, :] + q = tl.load(Q + offs_q, + mask=(mask_h[:, None]) & (mask_d[None, :]), + other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, + mask=(mask_h[:, None]) & (mask_dpe[None, :]), + other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & + (mask_dpe[:, None]), + 
other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), + qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + if is_hip_ and Lk >= 576: + BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-2), + k_buffer.stride(-1), + v_buffer.stride(-2), + v_buffer.stride(-1), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=4, + num_stages=2, + Lk=Lk, + Lv=Lv, + **extra_kargs, + ) + + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + o, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + 
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, + mask=mask_d, + other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + b_seq_len, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size=1, + logit_cap=0.0, +): + assert num_kv_splits == attn_logits.shape[2] + kv_group_num = q.shape[1] // v_buffer.shape[-2] + + if kv_group_num == 1: + # MHA + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + else: + # GQA/MQA/MLA + decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 1376274d57777..4c6bbc7272280 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -83,6 +83,7 @@ def get_attn_backend( block_size: int, is_attention_free: bool, is_blocksparse: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can 
cause the wrong @@ -97,6 +98,7 @@ def get_attn_backend( is_attention_free=is_attention_free, is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, ) @@ -109,6 +111,7 @@ def _cached_get_attn_backend( is_attention_free: bool, is_blocksparse: bool = False, use_v1: bool = False, + use_mla: bool = False, ) -> Type[AttentionBackend]: if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") @@ -141,7 +144,8 @@ def _cached_get_attn_backend( # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( - selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1) + selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, + use_mla) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") diff --git a/vllm/config.py b/vllm/config.py index 58464eae80b82..f6bd8b1ad8f14 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -736,17 +736,25 @@ def get_vocab_size(self) -> int: def get_hidden_size(self) -> int: return self.hf_text_config.hidden_size + @property + def is_deepseek_mla(self) -> bool: + # TODO add deepseek_v3 + return hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + in ('deepseek_v2')) + def get_head_size(self) -> int: # TODO remove hard code - if hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - in ('deepseek_v2', 'deepseek_v3')): - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", - 0) - qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", - 0) - if qk_rope_head_dim and qk_nope_head_dim: - return qk_rope_head_dim + qk_nope_head_dim + if self.is_deepseek_mla: + if self.use_mla: + return self.hf_text_config.kv_lora_rank + else: + qk_rope_head_dim = getattr(self.hf_text_config, + "qk_rope_head_dim", 0) + qk_nope_head_dim = getattr(self.hf_text_config, + "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim if self.is_attention_free: return 0 @@ -805,6 +813,10 @@ def get_total_num_kv_heads(self) -> int: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU.""" + if self.use_mla: + # When using MLA during decode it becomes MQA + return 1 + total_num_kv_heads = self.get_total_num_kv_heads() # If tensor parallelism is used, we divide the number of KV heads by # the tensor parallel size. 
We will replicate the KV heads in the @@ -955,6 +967,11 @@ def is_cross_encoder(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) return ModelRegistry.is_cross_encoder_model(architectures) + @property + def use_mla(self) -> bool: + use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE) + return use_mla + @property def supported_runner_types(self) -> Set[RunnerType]: return {_TASK_RUNNER[task] for task in self.supported_tasks} diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1f203b6eaeb33..cc7c99e50ac4d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -931,7 +931,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default="auto", help='The worker class to use for distributed execution.') - parser.add_argument( "--generation-config", type=nullable_str, diff --git a/vllm/envs.py b/vllm/envs.py index 8627caec7790d..2a18e3b9bc51d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -77,6 +77,8 @@ V_SCALE_CONSTANT: int = 100 VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 + VLLM_MLA_DISABLE: bool = False + VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True def get_default_cache_root(): @@ -506,6 +508,18 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # TTFT and overall throughput. "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE": lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")), + + # If set, vLLM will disable the MLA attention optimizations. + "VLLM_MLA_DISABLE": + lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), + + # Flag that can control whether or not we perform matrix-absorption for MLA + # decode, i.e. absorb W_UK into W_Q and W_UV into W_O. Absorbing the + # matrices reduces the runtime FLOPs needed to compute MLA but requires + # storing more weights, W_Q_UK and W_UV_O, so it can increase memory usage; + # this is enabled by default. + "VLLM_MLA_PERFORM_MATRIX_ABSORPTION": + lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))) } # end-env-vars-definition diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 712266ee42639..62babcddd61b1 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -23,6 +23,7 @@ from transformers import AutoModelForCausalLM from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.attention import Attention from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config) from vllm.distributed import (get_tensor_model_parallel_rank, @@ -397,6 +398,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: # parameters onto device for processing and back off after.
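The trade-off behind VLLM_MLA_PERFORM_MATRIX_ABSORPTION above is essentially associativity: folding W_UK into the query once lets decode score the cached latents directly, instead of decompressing every cached entry at every step, at the cost of storing the larger fused weights. A minimal single-head sketch with assumed DeepSeek-V2-like sizes; W_UK, q_nope, and c_kv are made-up stand-ins, not tensors from this patch.

import torch

torch.manual_seed(0)
# Assumed sizes for one head, illustration only.
kv_lora_rank, qk_nope_head_dim = 512, 128

W_UK = torch.randn(kv_lora_rank, qk_nope_head_dim, dtype=torch.float64)  # latent -> key up-projection
q_nope = torch.randn(qk_nope_head_dim, dtype=torch.float64)  # one decode query head (non-rope part)
c_kv = torch.randn(7, kv_lora_rank, dtype=torch.float64)  # seven cached per-token latents

# Without absorption: decompress each cached latent into a key, then score.
scores_decompressed = (c_kv @ W_UK) @ q_nope  # [7]

# With absorption: fold W_UK into the query once, then score in latent space.
q_absorbed = W_UK @ q_nope  # [kv_lora_rank]
scores_absorbed = c_kv @ q_absorbed  # [7]

assert torch.allclose(scores_decompressed, scores_absorbed)

Setting VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0 keeps the decompress-then-score form and the smaller original weights.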
with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) + elif isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # When attention modules need to process weights after + # currently only used by MLA + module.process_weights_after_loading() return model.eval() diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index af6810a140b43..73388cd269853 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -326,12 +326,156 @@ def forward( return output +class DeepseekV2MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + """ + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + 
prefix=f"{prefix}.o_proj") + + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, + attn_metadata) + + class DeepseekV2DecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, prefix: str, + model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -344,7 +488,11 @@ def __init__( # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
layer_idx = int(prefix.split(sep='.')[-1]) - self.self_attn = DeepseekV2Attention( + if model_config.use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -421,6 +569,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -440,6 +589,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: DeepseekV2DecoderLayer( config, prefix, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, ), diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 74948202cbe48..159ea94f99a27 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -31,7 +31,8 @@ def get_device_name(cls, device_id: int = 0) -> str: @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Using Torch SDPA backend.") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e4b436edf7588..91dcdff006e3e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -157,10 +157,14 @@ def get_current_memory_usage(cls, @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1) -> str: + kv_cache_dtype, block_size, use_v1, + use_mla) -> str: if use_v1: logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + if use_mla: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend" @@ -171,7 +175,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, pass elif selected_backend: raise ValueError( - f"Invalid attention backend for {cls.device_name}") + f"Invalid attention backend for {cls.device_name}, " + f"with use_v1: {use_v1} use_mla: {use_mla}") target_backend = _Backend.FLASH_ATTN if not cls.has_device_capability(80): diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index a32c262c84efa..0e1c4c0c5949f 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -27,7 +27,8 @@ class HpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: logger.info("Using HPUAttention backend.") return "vllm.attention.backends.hpu_attn.HPUAttentionBackend" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f2ecec3203fb7..186fa54bfc14c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -30,6 +30,7 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() OPENVINO = enum.auto() FLASHINFER = enum.auto() + TRITON_MLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() @@ -139,7 +140,8 @@ def 
is_cuda_alike(self) -> bool: @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: """Get the attention backend class of a device.""" return "" diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 7d414165a8188..3282c061714d3 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -30,7 +30,8 @@ class OpenVinoPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.OPENVINO: logger.info("Cannot use %s backend on OpenVINO.", selected_backend) logger.info("Using OpenVINO Attention backend.") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5ef56406e1935..8888521631481 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -75,7 +75,8 @@ class RocmPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1) -> str: + kv_cache_dtype, block_size, use_v1, + use_mla) -> str: selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 05a3aa4305cfa..494a17633974d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -29,7 +29,8 @@ class TpuPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) logger.info("Using Pallas backend.") diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index c34b5b58672e7..a5ca77f57cf47 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -27,7 +27,8 @@ class XPUPlatform(Platform): @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool) -> str: + block_size: int, use_v1: bool, + use_mla: bool) -> str: if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) logger.info("Using IPEX attention backend.") diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 7ccd4571b19df..08316ba74aad8 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -56,7 +56,8 @@ def __init__( model_config.dtype, cache_config.cache_dtype, self.block_size, - model_config.is_attention_free) + model_config.is_attention_free, + use_mla=model_config.use_mla) # Initialize the cache. 
self.gpu_cache = self._allocate_kv_cache( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bf1a40d48a789..160c0662ce976 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1066,6 +1066,7 @@ def __init__( self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, ) if needs_attn_backend else None if self.attn_backend: self.attn_state = self.attn_backend.get_state_cls()( @@ -1973,7 +1974,8 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) - self.input_buffers["positions"].copy_(positions, non_blocking=True) + if positions is not None: + self.input_buffers["positions"].copy_(positions, non_blocking=True) if self.backend_name != "NO_ATTENTION": self.input_buffers["slot_mapping"].copy_(