[WIP][Feature] Support KV Partition for BatchPrefill kernel for Paged & Ragged KV-Cache. #75
Conversation
Force-pushed from 551858a to bf6e4dc
@yzh119 is this PR good to use? This would be extremely useful for some of my work.
@AgrawalAmey We've done a major code refactor since the last commit of this PR, so I need to rebase and add some new commits. Please stay tuned :)
@yzh119 Looking forward to it! I would be happy to help accelerate this; please let me know if I can help in any way.
Looking forward to it!!
@yzh119 Writing to ask if this is ready for use? I just find
Moved to #310
@chenzhuofu @ZSL98 @AgrawalAmey
Amazing, thanks a lot for the awesome work! 🙏
Duplicate of #75, but rebased on the main branch. Note that to support CUDAGraph, we cannot make `kv_chunk_size` a function argument: it would be passed by value and could not change once captured by CUDAGraph. Instead, we pass `kv_chunk_size` through `kv_chunk_size_ptr`, a pointer to a global-memory address that stores the `kv_chunk_size`; its value can be set in the `BeginForward` functions.
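The pointer-indirection idea can be illustrated with a minimal Python sketch (hypothetical names, not FlashInfer's actual API): a "captured" kernel, like one recorded into a CUDAGraph, cannot see new values of arguments passed by value at capture time, while reading the chunk size through a shared buffer (analogous to `kv_chunk_size_ptr` pointing into global memory) lets a `BeginForward`-style setup change it between replays without re-capturing.

```python
def capture_by_value(kv_chunk_size):
    # the value is baked in at capture time, like a kernel argument
    # captured by CUDAGraph
    def replay():
        return kv_chunk_size
    return replay

def capture_by_pointer(kv_chunk_size_buf):
    # the "kernel" dereferences the buffer on every replay, so updates
    # to the pointed-to slot are visible without re-capture
    def replay():
        return kv_chunk_size_buf[0]
    return replay

kv_chunk_size_buf = [256]                 # stands in for a global-memory slot
frozen = capture_by_value(kv_chunk_size_buf[0])
live = capture_by_pointer(kv_chunk_size_buf)

kv_chunk_size_buf[0] = 512                # setup phase picks a new chunk size
print(frozen(), live())                   # -> 256 512
```

This is only an analogy for the by-value vs. through-pointer distinction; in the actual kernels the dereference happens on the GPU.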
Before this PR, FlashInfer supported KV sequence parallelism for single decode/prefill and batch decode, but not for batch prefill, where the feature is equally important. This PR implements KV partition for the batch prefill kernels (on both Paged and Ragged KV-Cache).
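The core of KV partition is that attention over a long KV sequence can be computed per chunk and the partial results merged exactly via their log-sum-exp statistics. A small numpy sketch of that merge rule (an illustration of the general technique, not the kernel code) for a single query vector:

```python
import numpy as np

def attention(q, k, v):
    # single-query softmax attention; also return the log-sum-exp (lse)
    # of the scores, which is what makes partial results mergeable
    s = k @ q                         # scores, shape (kv_len,)
    m = s.max()
    p = np.exp(s - m)
    out = (p @ v) / p.sum()
    lse = m + np.log(p.sum())
    return out, lse

def merge_states(outs, lses):
    # combine per-chunk partial outputs, weighting each chunk by its
    # share of the total softmax normalizer exp(lse_i) / sum_j exp(lse_j)
    lses = np.array(lses)
    m = lses.max()
    w = np.exp(lses - m)
    w = w / w.sum()
    return sum(wi * oi for wi, oi in zip(w, outs))

rng = np.random.default_rng(0)
d, kv_len, chunk = 8, 64, 16
q = rng.standard_normal(d)
k = rng.standard_normal((kv_len, d))
v = rng.standard_normal((kv_len, d))

ref, _ = attention(q, k, v)           # attention over the full KV sequence
parts = [attention(q, k[i:i + chunk], v[i:i + chunk])
         for i in range(0, kv_len, chunk)]
merged = merge_states([o for o, _ in parts], [l for _, l in parts])
# merged matches ref, so chunks can be processed by independent thread blocks
```

Because the merge is exact, the KV dimension can be split across thread blocks (or SMs) and reduced afterwards, which is what gives batch prefill the same KV parallelism the decode kernels already had.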