diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index d79a5ff00..cee7788c7 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -115,6 +115,32 @@ struct paged_kv_t { last_page_len(nullptr), rope_pos_offset(nullptr) {} + __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, + uint32_t batch_size, QKVLayout layout, DType* kv_data, + DType* k_data, DType* v_data, const int64_t * kv_strides, + IdType* indices, IdType* indptr, + IdType* last_page_len, IdType* rope_pos_offset = nullptr) + : num_heads(num_heads), + page_size(page_size), + head_dim(head_dim), + batch_size(batch_size), + indices(indices), + indptr(indptr), + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) { + bool kv_defined = kv_data != nullptr; + if (kv_defined) { + this->k_data = kv_data; + this->v_data = kv_data + kv_strides[1]; + } else { + this->k_data = k_data; + this->v_data = v_data; + } + stride_page = kv_strides[0]; + stride_n = layout == QKVLayout::kHND ? kv_strides[2 + kv_defined] : kv_strides[1 + kv_defined]; + stride_h = layout == QKVLayout::kHND ? kv_strides[1 + kv_defined] : kv_strides[2 + kv_defined]; + } + /*! 
* \brief Construct a paged key-value cache * \param num_heads The number of heads diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 776c7c636..f246a938b 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -76,10 +76,13 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Run( CHECK_INPUT(q); CHECK_INPUT(qo_indptr); if (paged_kv_defined) { - CHECK_INPUT(paged_kv_cache.value()); + CHECK_CUDA(paged_kv_cache.value()); + CHECK_LAST_DIM_CONTIGUOUS(paged_kv_cache.value()); } else { - CHECK_INPUT(paged_k_cache.value()); - CHECK_INPUT(paged_v_cache.value()); + CHECK_CUDA(paged_k_cache.value()); + CHECK_LAST_DIM_CONTIGUOUS(paged_k_cache.value()); + CHECK_CUDA(paged_v_cache.value()); + CHECK_LAST_DIM_CONTIGUOUS(paged_v_cache.value()); } CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); @@ -163,6 +166,17 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Run( auto kv_scalar_type = paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + + const int64_t* kv_cache_strides = nullptr; + if (paged_kv_cache.has_value()) { + kv_cache_strides = paged_kv_cache->strides().data(); + } else { + auto k_strides = paged_k_cache->strides(); + auto v_strides = paged_v_cache->strides(); + TORCH_CHECK(k_strides == v_strides, "k/v cache strides not match"); + kv_cache_strides = k_strides.data(); + } + if (q_scalar_type == kv_scalar_type) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, c_type, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { @@ -171,6 +185,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Run( static_cast<c_type*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), static_cast<c_type*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), static_cast<c_type*>(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() : nullptr), + kv_cache_strides, static_cast<int32_t*>(paged_kv_indices.data_ptr()), static_cast<int32_t*>(paged_kv_indptr.data_ptr()), static_cast<int32_t*>(paged_kv_last_page_len.data_ptr())); @@ -214,6 +229,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Run( : nullptr), static_cast<kv_type*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), + kv_cache_strides, static_cast<int32_t*>(paged_kv_indices.data_ptr()), static_cast<int32_t*>(paged_kv_indptr.data_ptr()), static_cast<int32_t*>(paged_kv_last_page_len.data_ptr())); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index d6895041c..a637d472f 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -235,6 +235,9 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "'s last dim must be contiguous") + #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x)