[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth #13245
Conversation
Signed-off-by: Lingfan Yu <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Signed-off-by: Lingfan Yu <[email protected]>
LGTM!
)
return

if nisa.get_nc_version() == nisa.nc_version.gen3:
Are these kernels targeting trn2? DMA transpose could bring better performance on trn2 onward.
Yes, it targets trn2. But here we are simplifying the code by removing the option to transpose v inside the kernel, so should_transpose_v is always False. We expect the kernel input layout of value to be (batch, num_kv_head, seqlen_q, D).
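A minimal sketch of that expectation (the shape values below are illustrative placeholders, not defaults taken from the kernel):

import torch

# Illustrative sizes only; the kernel consumes whatever the caller provides.
batch, num_kv_head, seqlen_q, D = 1, 8, 128, 128

# Value is expected in (batch, num_kv_head, seqlen_q, D) layout, so no
# in-kernel transpose is needed (should_transpose_v is effectively always False).
value = torch.randn(batch, num_kv_head, seqlen_q, D, dtype=torch.bfloat16)
assert value.shape == (batch, num_kv_head, seqlen_q, D)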
Hi @simon-mo, could you please add me to the Buildkite org so that I can unblock Neuron tests? Thanks!
Signed-off-by: Lingfan Yu <[email protected]>
Thanks for contributing!
To my understanding, the code change contains three main changes: 1/ mask reordering, 2/ vectorized KV cache loading, 3/ enabling loading of large block_tables. I think it's better to test these new capabilities individually instead of bundling them into integrated tests at a higher level. But feel free to chime in.
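For illustration, individual tests could look roughly like this (test names and parameters are hypothetical sketches, not the tests shipped in this PR):

import pytest


@pytest.mark.parametrize("large_tile_size", [512, 2048])
def test_vectorized_kv_cache_load(large_tile_size):
    # Hypothetical: compare the vectorized KV cache load path against a
    # reference gather implemented with plain PyTorch indexing.
    ...


def test_mask_reordering():
    # Hypothetical: check that reordering the mask outside the kernel
    # matches the in-kernel reordering path.
    ...


def test_large_block_tables():
    # Hypothetical: exercise block tables larger than a single tile.
    ...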
tests/neuron/test_prefix_prefill.py
Outdated
).bool()
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1)

# reorder_mask_outside = True
clean up?
tests/neuron/test_prefix_prefill.py
Outdated
    ],
)
@torch.inference_mode()
def test_flash_paged_attention_numerical(
Are we intentionally trying to rename the test function? (Asking because I was expecting it to be somewhat consistent with the GPU test cases, like kernels/test_prefix_prefill.py, test_batch_prefill_kernels.py, or test_page.py.)
Reverted
tests/neuron/test_prefix_prefill.py
Outdated
    block_size: int,
    large_tile_size,
    mixed_precision: bool,
    reorder_mask_outside: bool,
it would be better if we could have a separate test function for this new capability (e.g. reorder_mask_outside).
Will separate it out in the next PR #13455
tests/neuron/test_prefix_prefill.py
Outdated
    "constant",
    0,
)
assert LARGE_TILE_SZ >= B_P_SIZE
assert with a message?
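For example (a sketch of the suggestion, not the exact message used in the PR; the value for LARGE_TILE_SZ is just a placeholder):

B_P_SIZE = 128          # partition size, as in the kernel code
LARGE_TILE_SZ = 2048    # placeholder; the test parametrizes this

assert LARGE_TILE_SZ >= B_P_SIZE, (
    f"LARGE_TILE_SZ ({LARGE_TILE_SZ}) must be at least B_P_SIZE ({B_P_SIZE})")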
def transform_block_tables_for_indirect_load( | ||
block_tables, | ||
block_size_tiling_factor, | ||
num_head, | ||
head_id, | ||
): |
it would be better if we could have a unit test for this.
Unit test has been added in tests/neuron/test_block_table.py
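For context, a conceptual PyTorch reference of what such a block-table transform might be checked against (the tiling semantics here are my assumption for illustration; the authoritative behavior is defined by the kernel and tests/neuron/test_block_table.py):

import torch


def reference_tile_block_tables(block_tables: torch.Tensor,
                                block_size_tiling_factor: int) -> torch.Tensor:
    # Assumed semantics: each physical block id is expanded into
    # `block_size_tiling_factor` sub-block ids so DMA gathers can operate
    # on a smaller effective block size.
    num_seqs, max_blocks = block_tables.shape
    offsets = torch.arange(block_size_tiling_factor)
    sub = block_tables.unsqueeze(-1) * block_size_tiling_factor + offsets
    return sub.reshape(num_seqs, max_blocks * block_size_tiling_factor)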
vllm/attention/ops/nki_flash_attn.py
Outdated
    B_P_SIZE=B_P_SIZE,
    B_F_SIZE=B_F_SIZE,
    B_D_SIZE=B_D_SIZE,
    qk_res_buffer=None,
qk_res_buffer is already None by default, no?
removed
vllm/attention/ops/nki_flash_attn.py
Outdated
cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :],
                           dtype=cur_k_tile.dtype)
Loading while casting can be tricky. Make it separate?
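Something along these lines (a sketch only; it assumes nl is neuronxcc.nki.language as used elsewhere in the kernel and that nl.copy is available for the dtype conversion):

# Load in the source dtype first...
cur_k_tile_raw = nl.load(key[batch_id, head_id, :, :])
# ...then cast into the working buffer as a separate step.
cur_k_tile[:, :] = nl.copy(cur_k_tile_raw, dtype=cur_k_tile.dtype)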
Done
vllm/attention/ops/nki_flash_attn.py
Outdated
context_kv_len = total_seq_len - total_query_len

B_P_SIZE = 128
# assuming LARGE_TILE_SIZE >= B_P_SIZE
add assertion?
vllm/attention/ops/nki_flash_attn.py
Outdated
    mask_reordered=True,
    return_debug_tensors=False,
to be more consistent with
vllm/vllm/v1/attention/backends/flash_attn.py
Lines 197 to 213 in 4c82229
flash_attn_varlen_func(
    q=query[:num_actual_tokens],
    k=key_cache,
    v=value_cache,
    out=output[:num_actual_tokens],
    cu_seqlens_q=attn_metadata.query_start_loc,
    max_seqlen_q=attn_metadata.max_query_len,
    seqused_k=attn_metadata.seq_lens,
    max_seqlen_k=attn_metadata.max_seq_len,
    softmax_scale=self.scale,
    causal=True,
    alibi_slopes=self.alibi_slopes,
    window_size=self.sliding_window,
    block_table=attn_metadata.block_table,
    softcap=self.logits_soft_cap,
    fa_version=self.vllm_flash_attn_version,
)
remove both and set these internal arguments inside the function call?
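For illustration, the entry point could fix these internally instead of exposing them (a sketch; the parameter list here is abbreviated, not the kernel's actual signature):

def flash_paged_attention(query, key, value, key_cache, value_cache,
                          block_table, attn_mask, **kwargs):
    # Internal-only knobs: callers never need to toggle these, so they are
    # set here rather than plumbed through the public signature.
    mask_reordered = True
    return_debug_tensors = False
    ...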
Removed
vllm/attention/ops/nki_flash_attn.py
Outdated
  cur_mask = nl.load(
      mask[
          nl.ds(i * B_P_SIZE, B_P_SIZE),
-         nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
-     ])
+         nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
+     ],
+     dtype=mask.dtype,
+ )
Loading while casting can be tricky. Consider making it separate?
Done
Signed-off-by: Lingfan Yu <[email protected]>
@liangfu Thanks for the review. I updated the PR following your suggestions.
Will do in PR #13455
Signed-off-by: Lingfan Yu <[email protected]>
Thanks for the update.
The previous version of the NKI flash attention kernel did not vectorize KV cache loading to fully utilize HBM bandwidth. As a result, the kernel was bottlenecked by fetching the paged KV cache from HBM.
This PR applies vectorization to fully saturate DMA bandwidth.
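As a rough PyTorch illustration of the idea (not the NKI implementation itself; shapes are placeholders): instead of issuing one small copy per KV block, the block table is used to gather many blocks in a single vectorized load, which keeps the DMA engines busy.

import torch

num_blocks, block_size, num_kv_heads, head_dim = 1024, 32, 8, 128
kv_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_dim)
block_table = torch.randint(0, num_blocks, (64,))  # blocks of one sequence

# Loop-based load: one small copy per block (poor DMA utilization).
loaded_loop = torch.cat([kv_cache[b] for b in block_table], dim=0)

# Vectorized load: a single gather over all blocks in the table.
loaded_vec = kv_cache[block_table].reshape(-1, num_kv_heads, head_dim)

assert torch.equal(loaded_loop, loaded_vec)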
@liangfu