From 834ad75f931e6a34366e062b545d1827e61efcef Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Sat, 22 Feb 2025 16:17:41 -0800
Subject: [PATCH] fix updater prefill

---
 python/sglang/srt/configs/model_config.py       |  1 +
 .../layers/attention/flashinfer_mla_backend.py  | 18 +++++++++++-------
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index ab4265ee480..c7f6743bdb6 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -105,6 +105,7 @@ def __init__(
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_config.v_head_dim
         elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 128
             self.attention_arch = AttentionArch.MLA
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 3ee9f78a8a4..55fe91542dd 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -420,7 +420,10 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        self.head_dim = model_runner.model_config.head_dim
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.attn_backend = attn_backend
@@ -502,6 +505,7 @@ def call_begin_forward(
         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
+        sm_scale = 1.0 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)

         # extend part
         if use_ragged:
@@ -510,8 +514,8 @@
                 kv_indptr=qo_indptr,
                 num_qo_heads=self.num_qo_heads,
                 num_kv_heads=self.num_kv_heads,
-                head_dim_qk=192,
-                head_dim_vo=128,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
             )
@@ -524,12 +528,12 @@
             kv_indices,
             kv_len_arr,
             self.num_qo_heads,
-            512,
-            64,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
             1,
             True,
-            1 / math.sqrt(192),
-            self.data_type,
+            sm_scale,
+            self.q_data_type,
             self.data_type,
         )
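
Note (not part of the patch): the diff replaces hard-coded MLA dimensions (192, 128, 512, 64) and the hard-coded 1 / math.sqrt(192) softmax scale with values read from the model config. Below is a minimal sanity-check sketch, assuming DeepSeek-V2-style config values (qk_nope_head_dim=128, qk_rope_head_dim=64, kv_lora_rank=512, v_head_dim=128), which the removed constants appear to correspond to; the standalone variables here are illustrative, not identifiers from the patch.

import math

# Assumed hf_config values for illustration (DeepSeek-V2-style MLA).
qk_nope_head_dim = 128
qk_rope_head_dim = 64
kv_lora_rank = 512
v_head_dim = 128

# Config-derived values reproduce the previously hard-coded constants.
assert qk_nope_head_dim + qk_rope_head_dim == 192  # old head_dim_qk
assert v_head_dim == 128                           # old head_dim_vo
assert kv_lora_rank == 512                         # old paged head dim
assert qk_rope_head_dim == 64                      # old rope head dim

# The new config-derived sm_scale matches the old 1 / math.sqrt(192).
sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
assert abs(sm_scale - 1 / math.sqrt(192)) < 1e-12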