Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill #4128

Merged
merged 4 commits into from
Apr 18, 2024

Conversation

mmoskal
Copy link
Contributor

@mmoskal mmoskal commented Apr 16, 2024

The existing prefix prefill kernel only supports head dimension that is a power of two. This due to Triton only supporting power of two block sizes. This PR enlarges the Q,K,V tensors to the next power of two and pads them with zeros when reading (and writing).

It doesn't seem to affect performance of the non-padded case.

CC @rkooo567

FIX #4127


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!! Didn't know there was such limitation. Is there any other known limitation from this kernel ?

@@ -195,7 +201,8 @@ def _fwd_kernel(
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
mask=dim_mask[None, :] &
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition seems to be shared across all conditions, is this correct? If so, should we create a separate mask that combines 2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conditions are actually subtly different for each tl.load(). My limited understanding of Triton is that it will inline the condition in the load operation, so I think it's better to keep two separate 1D tensors rather the pre-computing the 2D matrix anyways.

@mmoskal
Copy link
Contributor Author

mmoskal commented Apr 17, 2024

Thank you for quick review!

I haven't run into any other limitations.

@mmoskal mmoskal requested a review from rkooo567 April 17, 2024 17:36
Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for me if tests pass. Probably need the original author of the kernel to take a look (was the person you? or do you happen to know?)

@mmoskal
Copy link
Contributor Author

mmoskal commented Apr 17, 2024

Looks like the original author is @caoshiyi

@mmoskal
Copy link
Contributor Author

mmoskal commented Apr 17, 2024

I synced the fork, that should fix the spec-decoding tests I think

@rkooo567
Copy link
Collaborator

cc @caoshiyi can you have a quick look at the PR?

Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@youkaichao @WoosukKwon can you help take a look? My skim is this is correct but my triton experience is limited.

@caoshiyi
Copy link
Contributor

Looks good to me if the tests passed. Thanks for fixing this!

@simon-mo
Copy link
Collaborator

image

Test passed. Merging. Thanks @caoshiyi @rkooo567 and @mmoskal

@simon-mo simon-mo merged commit e8cc796 into vllm-project:main Apr 18, 2024
44 of 46 checks passed
@mmoskal mmoskal deleted the prefill-pow-2 branch April 18, 2024 16:41
@DefTruth DefTruth mentioned this pull request Apr 19, 2024
9 tasks
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request Apr 22, 2024
@mmoskal
Copy link
Contributor Author

mmoskal commented Apr 25, 2024

@rkooo567 it looks like this kernel (and also rocm_flash_attn) do not support sliding window attention. I guess if one does attention in chunks, one can drop the blocks that are out of window in the block manager (cf #3665) which takes care of most of tokens that are out of window, but not all (also depending on chunk size).

I suspect having a slightly bigger window every now and then shouldn't be bad for model overall performance, but it would definitely affect the exact results.

CC @simon-mo @cadedaniel

@rkooo567
Copy link
Collaborator

Feel like it is better we have a legit support in kernels... @mmoskal is this complicated to support? Also a relevant issue; #4057

@mmoskal
Copy link
Contributor Author

mmoskal commented Apr 25, 2024

@rkooo567 hard to say. This PR was my first ever Triton code... But I will try to at least estimate.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug][Chunked prefill]: head size has to be power of two
4 participants