
[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth #13245

Merged
12 commits merged into vllm-project:main from fast_vectorized_dma on Feb 21, 2025

Conversation

@lingfanyu (Contributor) commented on Feb 13, 2025

The previous version of the NKI flash attention kernel did not vectorize KV cache loading and therefore could not fully utilize HBM bandwidth. As a result, the kernel was bottlenecked on fetching the paged KV cache from HBM.

This PR applies vectorization to the KV cache load to fully saturate DMA bandwidth.
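
For intuition, here is a minimal, hedged sketch in plain PyTorch (not the NKI kernel itself) of the difference between fetching the paged KV cache one block at a time and gathering every block of a large tile in a single indexed load. All tensor names, shapes, and the tile size below are illustrative assumptions.

import torch

num_blocks, block_size, head_dim = 1024, 32, 128
kv_cache = torch.randn(num_blocks, block_size, head_dim)   # paged KV cache resident in HBM
block_table = torch.randint(0, num_blocks, (16,))           # block ids for one large tile

# Per-block loads: one small transfer per block, which underutilizes DMA bandwidth.
tile_loop = torch.cat([kv_cache[b] for b in block_table], dim=0)

# Vectorized load: a single gather over the whole tile, i.e. one large DMA-friendly transfer.
tile_vec = kv_cache[block_table].reshape(-1, head_dim)

assert torch.equal(tile_loop, tile_vec)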

@liangfu


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@JF-D (Contributor) left a comment:

LGTM!

)
return

if nisa.get_nc_version() == nisa.nc_version.gen3:


Are these kernels targeting trn2? DMA transpose could bring better performance on trn2 onward.

@lingfanyu (Contributor, Author):

Yes, it targets trn2. But here we simplify the code by removing the option to transpose v inside the kernel, so should_transpose_v is always False. We expect the kernel input layout of value to be (batch, num_kv_head, seqlen_q, D).
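
As a minimal illustration of that layout expectation, a hedged PyTorch sketch (the shapes and the caller-side transpose below are assumptions, not the PR's actual call sites):

import torch

batch, num_kv_head, seqlen_q, D = 2, 8, 128, 128
# Suppose the caller currently holds value as (batch, num_kv_head, D, seqlen_q).
value_bhdq = torch.randn(batch, num_kv_head, D, seqlen_q)
# Transpose once outside the kernel so should_transpose_v can stay False.
value = value_bhdq.transpose(2, 3).contiguous()
assert value.shape == (batch, num_kv_head, seqlen_q, D)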

@lingfanyu (Contributor, Author) commented:

Hi @simon-mo, could you please add me to the Buildkite org so that I can unblock Neuron tests? Thanks!

@liangfu (Contributor) left a comment:

Thanks for contributing!

To my understanding, this PR contains three main changes: 1/ mask reordering, 2/ vectorized KV cache loading, 3/ enabling loading of large block_tables. I think it's better to test these new capabilities individually instead of bundling them into integrated tests at a higher level. But feel free to chime in.

).bool()
attn_mask = torch.concat([prior_mask_padded, active_mask_padded], dim=1)

# reorder_mask_outside = True
Contributor:

clean up?

],
)
@torch.inference_mode()
def test_flash_paged_attention_numerical(
Contributor:

Are we intentionally trying to rename the test function?

(Asking because I was expecting it to be somewhat consistent with the GPU test cases, like kernels/test_prefix_prefill.py, test_batch_prefill_kernels.py, or test_page.py.)

@lingfanyu (Contributor, Author):

Reverted

block_size: int,
large_tile_size,
mixed_precision: bool,
reorder_mask_outside: bool,
Contributor:

it would be better if we could have a separate test function for this new capability (e.g. reorder_mask_outside).

@lingfanyu (Contributor, Author):

Will separate it out in next PR #13455

"constant",
0,
)
assert LARGE_TILE_SZ >= B_P_SIZE
Contributor:

assert with message?
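
For reference, a minimal sketch of the suggested assertion with a message; the message text and the values below are illustrative, not the PR's final wording:

B_P_SIZE = 128
LARGE_TILE_SZ = 2048
assert LARGE_TILE_SZ >= B_P_SIZE, (
    f"LARGE_TILE_SZ ({LARGE_TILE_SZ}) must be at least B_P_SIZE ({B_P_SIZE})")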

Comment on lines +53 to +58
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
Contributor:

it would be better if we could have a unit test for this.

@lingfanyu (Contributor, Author):

Unit test has been added in tests/neuron/test_block_table.py
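
For context, a self-contained sketch of the kind of check such a unit test can perform, using a simplified pure-PyTorch stand-in for the block-table transformation. The helper below and its semantics are illustrative assumptions, not the contents of tests/neuron/test_block_table.py.

import torch

def expand_block_table(block_tables: torch.Tensor, tiling_factor: int) -> torch.Tensor:
    # Split every KV block id into `tiling_factor` sub-block ids so that a single
    # indirect load can gather rows at a finer granularity (illustrative semantics only).
    sub = torch.arange(tiling_factor)
    expanded = block_tables[:, :, None] * tiling_factor + sub
    return expanded.reshape(block_tables.shape[0], -1)

def test_expand_block_table():
    bt = torch.tensor([[3, 7], [0, 5]])
    out = expand_block_table(bt, tiling_factor=2)
    expected = torch.tensor([[6, 7, 14, 15], [0, 1, 10, 11]])
    assert torch.equal(out, expected)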

B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
qk_res_buffer=None,
Contributor:

qk_res_buffer is already None by default. no?

@lingfanyu (Contributor, Author):

removed

Comment on lines 680 to 681
cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :],
                           dtype=cur_k_tile.dtype)
Contributor:

loading while casting can be tricky. make it separate?

@lingfanyu (Contributor, Author):

Done
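
To make the suggestion concrete, a hedged sketch in plain PyTorch (not NKI) of separating the bulk transfer from the dtype cast instead of fusing the cast into the load; shapes and dtypes are illustrative.

import torch

key_hbm = torch.randn(128, 128, dtype=torch.bfloat16)  # stand-in for the key tensor in HBM

# Fused: the cast happens as part of the load itself.
fused = key_hbm.to(torch.float32)

# Separated: transfer in the source dtype first, then cast as an explicit second step.
staged = key_hbm.clone()
cur_k_tile = staged.to(torch.float32)

assert torch.equal(fused, cur_k_tile)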

context_kv_len = total_seq_len - total_query_len

B_P_SIZE = 128
# assuming LARGE_TILE_SIZE >= B_P_SIZE
Contributor:

add assertion?

Comment on lines 831 to 832
mask_reordered=True,
return_debug_tensors=False,
Contributor:

to be more consistent with

flash_attn_varlen_func(
    q=query[:num_actual_tokens],
    k=key_cache,
    v=value_cache,
    out=output[:num_actual_tokens],
    cu_seqlens_q=attn_metadata.query_start_loc,
    max_seqlen_q=attn_metadata.max_query_len,
    seqused_k=attn_metadata.seq_lens,
    max_seqlen_k=attn_metadata.max_seq_len,
    softmax_scale=self.scale,
    causal=True,
    alibi_slopes=self.alibi_slopes,
    window_size=self.sliding_window,
    block_table=attn_metadata.block_table,
    softcap=self.logits_soft_cap,
    fa_version=self.vllm_flash_attn_version,
)

remove both and set these internal arguments inside the function call?

@lingfanyu (Contributor, Author):

Removed
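
One way to read this suggestion, as a hedged sketch with illustrative names (not the PR's actual functions): pin the internal knobs inside a wrapper so the public signature stays closer to flash_attn_varlen_func.

def _paged_attention_kernel(query, key, value, block_table, *,
                            mask_reordered, return_debug_tensors):
    # Stand-in for the NKI kernel invocation; returns the knobs only for demonstration.
    return {"mask_reordered": mask_reordered,
            "return_debug_tensors": return_debug_tensors}

def flash_paged_attention(query, key, value, block_table):
    # Internal arguments are fixed here instead of being passed at every call site.
    return _paged_attention_kernel(query, key, value, block_table,
                                   mask_reordered=True,
                                   return_debug_tensors=False)

print(flash_paged_attention(None, None, None, None))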

Comment on lines 632 to 638
 cur_mask = nl.load(
     mask[
         nl.ds(i * B_P_SIZE, B_P_SIZE),
-        nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
-    ])
+        nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
+    ],
+    dtype=mask.dtype,
+)
Contributor:

loading while casting can be tricky. consider making it separate?

@lingfanyu (Contributor, Author):

Done

@lingfanyu (Contributor, Author) commented:

@liangfu Thanks for the review. I updated following your suggestions.

> To my understanding, this PR contains three main changes: 1/ mask reordering, 2/ vectorized KV cache loading, 3/ enabling loading of large block_tables. I think it's better to test these new capabilities individually instead of bundling them into integrated tests at a higher level. But feel free to chime in.

Will do in PR #13455

@liangfu (Contributor) left a comment:

Thanks for the update.

@simon-mo merged commit 3317008 into vllm-project:main on Feb 21, 2025
19 checks passed
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025
JenZhao pushed a commit to JenZhao/vllm that referenced this pull request Feb 21, 2025
michaelrglass pushed a commit to michaelrglass/vllm that referenced this pull request Feb 21, 2025
…ximize DMA bandwidth (vllm-project#13245)

Signed-off-by: Lingfan Yu <[email protected]>
Signed-off-by: Michael Glass <[email protected]>
@lingfanyu deleted the fast_vectorized_dma branch on February 21, 2025 at 21:13
5 participants