Commit: fix updater prefill
Fridge003 committed Feb 23, 2025
1 parent d9913bb commit 834ad75
Showing 2 changed files with 12 additions and 7 deletions.
1 change: 1 addition & 0 deletions python/sglang/srt/configs/model_config.py
@@ -105,6 +105,7 @@ def __init__(
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+            self.v_head_dim = self.hf_config.v_head_dim
         elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 128
             self.attention_arch = AttentionArch.MLA
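
For context, the added line reads a field that DeepSeek-style MLA configs expose directly. A minimal, hypothetical inspection sketch (model name and printed values are assumptions, taken from the public DeepSeek-V2/V3 configs, not from this commit):

```python
from transformers import AutoConfig

# DeepSeek-V2-style configs expose the MLA dimensions directly; for
# DeepSeek-V2/V3 these are kv_lora_rank=512, qk_nope_head_dim=128,
# qk_rope_head_dim=64, v_head_dim=128.
hf_config = AutoConfig.from_pretrained(
    "deepseek-ai/DeepSeek-V2", trust_remote_code=True
)
print(hf_config.kv_lora_rank, hf_config.qk_nope_head_dim)
print(hf_config.qk_rope_head_dim, hf_config.v_head_dim)  # field the commit starts reading
```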
18 changes: 11 additions & 7 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -420,7 +420,10 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        self.head_dim = model_runner.model_config.head_dim
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.attn_backend = attn_backend
@@ -502,6 +505,7 @@ def call_begin_forward(
 
         qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
+        sm_scale = 1.0 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
 
         # extend part
         if use_ragged:
@@ -510,8 +514,8 @@
                 kv_indptr=qo_indptr,
                 num_qo_heads=self.num_qo_heads,
                 num_kv_heads=self.num_kv_heads,
-                head_dim_qk=192,
-                head_dim_vo=128,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
                 q_data_type=self.q_data_type,
             )
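
For reference, with DeepSeek-V2/V3's published MLA dimensions (assumed values below, not part of this commit) the new config-driven arguments reproduce the old hardcoded constants:

```python
# Assumed DeepSeek-V2/V3 MLA dimensions, taken from the public model configs.
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128

# head_dim_qk / head_dim_vo as the ragged prefill wrapper now computes them.
assert qk_nope_head_dim + qk_rope_head_dim == 192  # old hardcoded head_dim_qk
assert v_head_dim == 128                           # old hardcoded head_dim_vo
```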

@@ -524,12 +528,12 @@
             kv_indices,
             kv_len_arr,
             self.num_qo_heads,
-            512,
-            64,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
             1,
             True,
-            1 / math.sqrt(192),
-            self.data_type,
+            sm_scale,
+            self.q_data_type,
             self.data_type,
         )

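Similarly for this plan call: a quick sanity check (same assumed DeepSeek-V2/V3 dimensions) that the config-driven arguments and the new sm_scale equal the constants they replace; note the first dtype argument passed to plan is now the query dtype rather than the KV-cache dtype:

```python
import math

# Assumed DeepSeek-V2/V3 MLA dimensions, as in the public configs.
kv_lora_rank = 512       # replaces the hardcoded 512
qk_rope_head_dim = 64    # replaces the hardcoded 64
qk_nope_head_dim = 128

# The new config-driven softmax scale equals the old 1 / math.sqrt(192).
sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
assert sm_scale == 1 / math.sqrt(192)
```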
