Skip to content

Commit

Permalink
Use FP16-native PA after support in ROCm/aiter#97
Browse files Browse the repository at this point in the history
  • Loading branch information
mawong-amd committed Feb 7, 2025
1 parent c127e9a commit f8d2422
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions vllm/attention/ops/paged_attn_ater.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,18 @@ def write_to_paged_cache(
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
# print(f"{key.shape=}{key.dtype}")
# print(f"{key_cache.shape=}{key_cache.dtype}")
# print(f"{value_cache.shape=}{value_cache.dtype}")
# print(f"{k_scale.shape=}{k_scale.dtype}")
# print(f"{v_scale.shape=}{v_scale.dtype}")
# print(f"{k_scale.numel()=}")
# print(f"{slot_mapping.flatten()=}")
dtypeDict = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp8': torch.int8, 'int8': torch.int8, 'auto': torch.float16 }
kvCacheDtype = dtypeDict[kv_cache_dtype]
if key_cache.dtype.itemsize == 1:
if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz)
value_cache = value_cache.view(torch.float8_e4m3fnuz)
else:
key_cache = key_cache.view(torch.int8)
value_cache = value_cache.view(torch.int8)
aiter.reshape_and_cache_with_pertoken_quant(
key,
value,
key_cache.view(kvCacheDtype),
value_cache.view(kvCacheDtype),
key_cache,
value_cache,
k_scale,
v_scale,
slot_mapping.flatten(),
Expand Down Expand Up @@ -156,14 +153,8 @@ def forward_decode(
elif "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz)
value_cache = value_cache.view(torch.float8_e4m3fnuz)
else:
key_cache = key_cache.view(torch.int8)
value_cache = value_cache.view(torch.int8)
dtype=out.dtype
aiter.pa_fwd_asm(query.to(torch.bfloat16), key_cache, value_cache, block_tables, seq_lens, max_num_blocks_per_seq, k_scale, v_scale,out)
if dtype==torch.float16:
# aiter.pa_fwd_asm only supports bf16 output for now
out.copy_(out.view(torch.bfloat16).to(torch.float16))
aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, seq_lens,
max_num_blocks_per_seq, k_scale, v_scale,out)
return out

@staticmethod
Expand Down

0 comments on commit f8d2422

Please sign in to comment.