[Bugfix] Add kv_scale input parameter to CPU backend (#3840)
WoosukKwon authored Apr 4, 2024
1 parent 537ee25 commit 498eb5c
Showing 4 changed files with 12 additions and 5 deletions.
csrc/cpu/attention.cpp: 6 changes (4 additions, 2 deletions)
@@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
     CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
                         torch::Tensor &context_lens, int block_size,
                         int max_context_len,
                         const c10::optional<torch::Tensor> &alibi_slopes,
-                        const std::string &kv_cache_dtype) {
+                        const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                                [&] {
     CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
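Context for these two hunks (not part of the diff): kv_scale is the scaling factor used by the quantized (FP8) KV-cache path on the GPU backends, where cached keys and values are roughly speaking rescaled before attention is computed. The CPU backend keeps the cache in the model dtype, so the new TORCH_CHECK only admits the neutral value 1.0; the parameter exists to keep the op signature identical across backends. The sketch below is a toy, unbatched illustration of where such a scale would enter a decode step; the names, shapes, and the exact point of application are assumptions, not the vLLM kernel.

# Toy, unbatched sketch of a decode attention step with a KV scale.
# Illustrative only: this is not the vLLM CPU kernel, and the shapes,
# names, and the point where kv_scale is applied are assumptions.
import torch


def toy_decode_attention(query: torch.Tensor,
                         key_cache: torch.Tensor,
                         value_cache: torch.Tensor,
                         kv_scale: float = 1.0) -> torch.Tensor:
    # query:       [num_heads, head_size] for the single new token
    # key_cache:   [num_cached_tokens, num_heads, head_size]
    # value_cache: [num_cached_tokens, num_heads, head_size]
    scale = query.shape[-1] ** -0.5
    # With a quantized cache, kv_scale would rescale the cached entries
    # here; the CPU backend stores them unscaled, hence kv_scale == 1.0.
    keys = key_cache.float() * kv_scale
    values = value_cache.float() * kv_scale
    scores = torch.einsum("hd,thd->ht", query.float() * scale, keys)
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("ht,thd->hd", probs, values)


q, k, v = torch.randn(8, 64), torch.randn(16, 8, 64), torch.randn(16, 8, 64)
out = toy_decode_attention(q, k, v, kv_scale=1.0)  # shape [8, 64]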
csrc/cpu/cache.cpp: 4 changes (3 additions, 1 deletion)
@@ -111,7 +111,9 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
 void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                        torch::Tensor &key_cache, torch::Tensor &value_cache,
                        torch::Tensor &slot_mapping,
-                       const std::string &kv_cache_dtype) {
+                       const std::string &kv_cache_dtype, float kv_scale) {
+  TORCH_CHECK(kv_scale == 1.0f);
+
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
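For orientation: reshape_and_cache writes each incoming token's key and value into the block-structured KV cache at the physical slot given by slot_mapping; in the quantized GPU path the scale participates in that dtype conversion, which is why the parameter now reaches this function even though the CPU path only stores unscaled values. Below is a minimal Python sketch of the contract, using a flat slot-indexed layout rather than vLLM's actual block layout; all names are illustrative.

# Toy sketch of the reshape_and_cache contract on CPU: scatter per-token
# K/V into a cache via slot_mapping, and reject any non-neutral kv_scale.
# The flat [num_slots, num_heads, head_size] layout is a simplification.
import torch


def toy_reshape_and_cache(key: torch.Tensor, value: torch.Tensor,
                          key_cache: torch.Tensor, value_cache: torch.Tensor,
                          slot_mapping: torch.Tensor,
                          kv_scale: float = 1.0) -> None:
    # Mirrors TORCH_CHECK(kv_scale == 1.0f): the CPU cache is unquantized,
    # so a scaled write is refused rather than silently mishandled.
    if kv_scale != 1.0:
        raise ValueError("CPU-style cache stores unscaled values; kv_scale must be 1.0")
    key_cache[slot_mapping] = key      # key:   [num_tokens, num_heads, head_size]
    value_cache[slot_mapping] = value  # value: [num_tokens, num_heads, head_size]


key_cache = torch.zeros(32, 8, 64)
value_cache = torch.zeros(32, 8, 64)
slots = torch.tensor([3, 7, 8, 20])
toy_reshape_and_cache(torch.randn(4, 8, 64), torch.randn(4, 8, 64),
                      key_cache, value_cache, slots)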
vllm/attention/backends/torch_sdpa.py: 5 changes (4 additions, 1 deletion)
@@ -114,6 +114,7 @@ def forward(
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: TorchSDPAMetadata,
+        kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -138,7 +139,8 @@ def forward(
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype)
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)

         if attn_metadata.is_prompt:
             if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
@@ -199,6 +201,7 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                kv_scale,
             )

         # Reshape the output tensor.
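Taken together, the three hunks above thread kv_scale from the backend's forward() into both the cache write and the decode kernel without interpreting it. A structural sketch of that plumbing follows; the stub functions and the simplified metadata object are stand-ins for the PagedAttention helpers, not vLLM APIs.

# Structural sketch of the kv_scale plumbing added in this commit.
# The stubs stand in for PagedAttention.write_to_paged_cache and
# PagedAttention.forward_decode; they are placeholders, not vLLM code.
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class ToyMetadata:
    slot_mapping: torch.Tensor
    kv_cache_dtype: str
    is_prompt: bool


def write_to_paged_cache_stub(key, value, key_cache, value_cache,
                              slot_mapping, kv_cache_dtype, kv_scale):
    pass  # placeholder for the cache-write op changed in csrc/cpu/cache.cpp


def forward_decode_stub(query, key_cache, value_cache, kv_scale):
    return torch.empty_like(query)  # placeholder for the decode kernel


def forward(query, key, value, kv_cache, attn_metadata, kv_scale: float):
    if kv_cache is not None:
        key_cache, value_cache = kv_cache
        # kv_scale is handed through unchanged to the cache write...
        write_to_paged_cache_stub(key, value, key_cache, value_cache,
                                  attn_metadata.slot_mapping,
                                  attn_metadata.kv_cache_dtype, kv_scale)
    if attn_metadata.is_prompt or kv_cache is None:
        return F.scaled_dot_product_attention(query, key, value)
    # ...and to the decode path, exactly as in the last two hunks above.
    return forward_decode_stub(query, key_cache, value_cache, kv_scale)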
vllm/attention/ops/paged_attn.py: 2 changes (1 addition, 1 deletion)
@@ -97,7 +97,7 @@ def forward_decode(
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
-        kv_scale,
+        kv_scale: float,
     ) -> torch.Tensor:
         output = torch.empty_like(query)
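The only change in this file is the annotation on kv_scale, which completes the typed signature of PagedAttention.forward_decode shared by every backend. A hypothetical, abridged stand-in is shown below to illustrate the kind of call-site mistake a static checker can now flag; the parameter list is shortened and is not the real signature.

# Hypothetical, abridged stand-in, not vLLM code: with kv_scale annotated
# as float, a checker such as mypy rejects call sites passing None or a tensor.
from typing import Optional

import torch


def forward_decode(query: torch.Tensor,
                   num_kv_heads: int,
                   scale: float,
                   alibi_slopes: Optional[torch.Tensor],
                   kv_scale: float) -> torch.Tensor:
    return torch.empty_like(query)


out = forward_decode(torch.randn(1, 8, 64), 8, 64 ** -0.5, None, 1.0)  # OK
# forward_decode(torch.randn(1, 8, 64), 8, 64 ** -0.5, None, None)  # flagged by mypy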
