From d4053c9764d4c264463a288a464c3bf9e71f9e58 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 16 Feb 2024 07:31:54 +0000 Subject: [PATCH 1/3] bugfix --- python/csrc/cascade.cu | 18 +++++++++--------- python/csrc/flashinfer_decl.h | 3 ++- python/csrc/page.cu | 7 ++++--- python/csrc/pytorch_extension_utils.h | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/python/csrc/cascade.cu b/python/csrc/cascade.cu index 76fcbf879..4e1e379ee 100644 --- a/python/csrc/cascade.cu +++ b/python/csrc/cascade.cu @@ -44,11 +44,11 @@ std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, tor auto s_merged = torch::empty({seq_len, num_heads}, s_a.options()); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v_a.scalar_type(), c_type, [&] { - cudaError_t status = - MergeState(static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), - static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), - static_cast(v_merged.data_ptr()), - static_cast(s_merged.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream); + cudaError_t status = MergeState( + static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), + static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), + static_cast(v_merged.data_ptr()), static_cast(s_merged.data_ptr()), + seq_len, num_heads, head_dim, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status)); return true; @@ -80,10 +80,10 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] { - cudaError_t status = - MergeStateInPlace(static_cast(v.data_ptr()), static_cast(s.data_ptr()), - static_cast(v_other.data_ptr()), - static_cast(s_other.data_ptr()), seq_len, num_heads, head_dim, torch_current_stream); + cudaError_t status = MergeStateInPlace( + static_cast(v.data_ptr()), static_cast(s.data_ptr()), + static_cast(v_other.data_ptr()), static_cast(s_other.data_ptr()), seq_len, + num_heads, head_dim, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/flashinfer_decl.h b/python/csrc/flashinfer_decl.h index a9144616d..0a85d49ac 100644 --- a/python/csrc/flashinfer_decl.h +++ b/python/csrc/flashinfer_decl.h @@ -24,6 +24,7 @@ template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched< \ PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, \ CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, \ + IdType* q_rope_position, \ paged_kv_t paged_kv, T* o, \ float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \ } @@ -64,7 +65,7 @@ template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, paged_kv_t paged_kv, DTypeOut* o, float* lse, float rope_scale, float rope_theta, cudaStream_t stream); diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 391576cd5..b71751cf3 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -73,9 +73,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, num_heads, page_size, head_dim, batch_size, static_cast(kv_data.data_ptr()), static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), static_cast(kv_last_page_len.data_ptr())); - cudaError_t status = AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), - static_cast(append_value.data_ptr()), - static_cast(append_indptr.data_ptr()), torch_current_stream); + cudaError_t status = + AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), + static_cast(append_value.data_ptr()), + static_cast(append_indptr.data_ptr()), torch_current_stream); TORCH_CHECK(status == cudaSuccess, "AppendPagedKVCache failed with error: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 780fe89a2..54c3aba8a 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ #pragma once -#include #include +#include #include "generated/dispatch.inc" From dd7ea600168d8464d18cf837456bcc1181e6de2d Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 16 Feb 2024 07:37:15 +0000 Subject: [PATCH 2/3] bugfix --- python/csrc/flashinfer_decl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/csrc/flashinfer_decl.h b/python/csrc/flashinfer_decl.h index 0a85d49ac..7e3950c9c 100644 --- a/python/csrc/flashinfer_decl.h +++ b/python/csrc/flashinfer_decl.h @@ -24,7 +24,7 @@ template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched< \ PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, \ CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, \ - IdType* q_rope_position, \ + int32_t* q_rope_position, \ paged_kv_t paged_kv, T* o, \ float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \ } From 4ee0253273587f39ab5d4c5edd4a06aabfac8778 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 16 Feb 2024 11:23:22 +0000 Subject: [PATCH 3/3] bugfix --- include/flashinfer/wrapper.cuh | 5 +++-- python/csrc/batch_prefill.cu | 1 + python/csrc/flashinfer_decl.h | 22 +++++++++++----------- python/flashinfer/__init__.py | 1 + python/flashinfer/cascade.py | 1 + python/flashinfer/decode.py | 7 ++++--- python/flashinfer/page.py | 1 + python/flashinfer/prefill.py | 1 + python/flashinfer/utils.py | 1 + python/setup.py | 1 + python/tests/test_batch_decode_kernels.py | 1 + python/tests/test_batch_prefill_kernels.py | 1 + python/tests/test_shared_prefix_kernels.py | 1 + 13 files changed, 28 insertions(+), 16 deletions(-) diff --git a/include/flashinfer/wrapper.cuh b/include/flashinfer/wrapper.cuh index 64cd09560..dd508127e 100644 --- a/include/flashinfer/wrapper.cuh +++ b/include/flashinfer/wrapper.cuh @@ -207,8 +207,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( return BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size, - num_kv_heads, rope_scale, rope_theta, stream); + handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr, + /*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads, + rope_scale, rope_theta, stream); })})})})})}); return cudaSuccess; } diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 18dc6a91c..bea6b9bcf 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -216,6 +216,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( &handler_, static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), + /*q_rope_position=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, batch_size, num_kv_heads, rope_scale, rope_theta, diff --git a/python/csrc/flashinfer_decl.h b/python/csrc/flashinfer_decl.h index 7e3950c9c..5ae1c2514 100644 --- a/python/csrc/flashinfer_decl.h +++ b/python/csrc/flashinfer_decl.h @@ -29,14 +29,14 @@ float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \ } -#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ - LAYOUT, ROTARY_MODE) \ - namespace flashinfer { \ - template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \ - GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \ - BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr, \ - T* o, float* lse, uint32_t batch_size, uint32_t num_kv_heads, float rope_scale, \ - float rope_theta, cudaStream_t stream); \ +#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ + LAYOUT, ROTARY_MODE) \ + namespace flashinfer { \ + template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \ + GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \ + BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr, \ + int32_t* q_rope_position, int32_t* k_rope_pos_offset, T* o, float* lse, uint32_t batch_size, \ + uint32_t num_kv_heads, float rope_scale, float rope_theta, cudaStream_t stream); \ } #define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \ @@ -57,9 +57,9 @@ template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size, - const uint32_t num_kv_heads, const float rope_scale, const float rope_theta, - cudaStream_t stream); + IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, + const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale, + const float rope_theta, cudaStream_t stream); template