different prefill and decode scales
Signed-off-by: Lucas Wilkinson <[email protected]>
LucasWilkinson committed Jan 30, 2025
1 parent b0f7d3d · commit 8cfa2f4
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions vllm/attention/backends/triton_mla.py
@@ -1,4 +1,3 @@
-import math
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -197,8 +196,6 @@ class TritonMLAMetadata(MLAMetadataCommon):
     # The dimension of the attention heads
     head_dim: Optional[int] = None

-    sm_scale: float = 0.0
-
     def __post_init__(self):
         supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
         if self.head_dim is not None and self.head_dim \
@@ -207,11 +204,6 @@ def __post_init__(self):
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")

-        # Note(simon): for MLA: soft max scale needs to be
-        # `1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)`.
-        assert self.head_dim is not None
-        self.sm_scale = 1.0 / math.sqrt(self.head_dim + self.head_dim // 8)
-
     @property
     def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
         if self.num_prefills == 0:
@@ -684,7 +676,7 @@ def _forward_decode(
         decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                              decode_meta.block_tables,
                              decode_meta.seq_lens_tensor, attn_logits,
-                             attn_metadata.num_kv_splits, decode_meta.sm_scale,
+                             attn_metadata.num_kv_splits, self.scale,
                              PAGE_SIZE)

         return self._v_up_proj_and_o_proj(o)
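For context: the deleted Note(simon) comment states the intended MLA softmax scale, `1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)`, but the deleted code recomputed it from `head_dim` alone as `1 / sqrt(head_dim + head_dim // 8)`. Below is a minimal sketch of why the two can disagree, using hypothetical DeepSeek-V2-style head sizes; the concrete values, and what `head_dim` actually holds in this backend, are assumptions for illustration, not read from the diff.

    import math

    # Hypothetical MLA head sizes (DeepSeek-V2-style, assumed for illustration).
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64

    # Intended scale, per the deleted Note(simon) comment:
    scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)  # 1/sqrt(192)

    # Deleted recomputation from __post_init__. It matches the intended scale
    # only when head_dim == qk_nope_head_dim and
    # head_dim // 8 == qk_rope_head_dim; with these sizes, 128 // 8 == 16,
    # not 64:
    head_dim = qk_nope_head_dim
    sm_scale = 1.0 / math.sqrt(head_dim + head_dim // 8)  # 1/sqrt(144)

    print(f"layer scale: {scale:.6f}  recomputed sm_scale: {sm_scale:.6f}")

Passing `self.scale` instead forwards whatever scale the model configured on the attention layer, so the decode kernel uses the same source of truth as the prefill path rather than a value rederived from `head_dim`.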
