diff --git a/README.md b/README.md index 5b10a968c..f97f3bdd9 100644 --- a/README.md +++ b/README.md @@ -72,13 +72,13 @@ num_qo_heads = 32 q = torch.randn(num_qo_heads, head_dim).half().to(0) o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly -o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, rotary_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly +o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly # append attention append_qo_len = 128 q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask -o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, rotary_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask +o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask # prefill attention qo_len = 2048 diff --git a/include/flashinfer.cuh b/include/flashinfer.cuh index c36f1467a..b16b7b38c 100644 --- a/include/flashinfer.cuh +++ b/include/flashinfer.cuh @@ -19,6 +19,6 @@ #include "flashinfer/attention.cuh" #include "flashinfer/layout.cuh" #include "flashinfer/page.cuh" -#include "flashinfer/rope.cuh" +#include "flashinfer/pos_enc.cuh" #endif // FLASHINFER_CUH_ diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index f42cb3f93..ca128beec 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -32,7 +32,7 @@ #include "../layout.cuh" #include "../math.cuh" #include "../page.cuh" -#include "../rope.cuh" +#include "../pos_enc.cuh" #include "../utils.cuh" #include "../vec_dtypes.cuh" #include "cascade.cuh" @@ -48,7 +48,7 @@ namespace { /*! 
* \brief Load k tile from smem and compute qk - * \tparam rotary_mode The rotary mode used in the kernel + * \tparam pos_encoding_mode The positional encoding mode used in the kernel * \tparam head_dim A template integer indicates the head dimension * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension @@ -65,18 +65,20 @@ namespace { * \param s A float indicates the thread-local result of qk * \param st The self-attention state to be updated */ -template +template __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx, const vec_t& q_vec, const vec_t& freq, uint32_t kv_idx_base, - uint32_t iter_base, uint32_t iter_bound, float* s, + uint32_t iter_base, uint32_t iter_bound, + const int32_t q_offset, float alibi_slope, float* s, state_t& st) { uint32_t tx = threadIdx.x, tz = threadIdx.z; float m_prev = st.m; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { vec_t k_vec; - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { // apply rotary embedding for all rows in k matrix of kv-cache k_vec = vec_apply_llama_rope(smem + j * bdx * vec_size, freq, kv_idx_base + tz * tile_size + j); @@ -94,6 +96,9 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage s[j] += math::shfl_xor_sync(s[j], offset); } s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4; + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { + s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset); + } st.m = max(st.m, s[j]); } @@ -175,7 +180,7 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f * \brief FlashAttention decoding cuda kernel with kv-cache for a single request * \tparam kv_layout The layout of k/v matrices (NHD or HND) * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not - * \tparam rotary_mode The rotary mode + * \tparam pos_encoding_mode The positional encoding mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension @@ -196,9 +201,9 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f * of "theta" used in RoPE (Rotary Positional Embeddings) * \param kv_chunk_size A integer indicates the kv-chunk size */ -template +template __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, @@ -215,6 +220,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* uint32_t kv_chunk_idx = blockIdx.x; uint32_t num_kv_chunks = gridDim.x; uint32_t num_qo_heads = info.get_num_qo_heads(); + const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; uint32_t seq_len = info.kv_len; extern __shared__ uint8_t smem[]; @@ -227,7 +233,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; vec_t freq; - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { freq[i] = rope_rcp_scale * @@ -290,10 +296,10 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* 
__restrict__ q, DTypeIn* // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk( + compute_qk( k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, - freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, s, - st_local); + freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, + seq_len - 1, alibi_slope, s, st_local); block.sync(); // load k for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { @@ -349,8 +355,9 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* } } -template +template __global__ void BatchDecodeWithPaddedKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, DTypeOut* __restrict__ o, float* __restrict__ lse, @@ -365,6 +372,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( uint32_t batch_idx = blockIdx.x; uint32_t num_qo_heads = info.get_num_qo_heads(); uint32_t num_kv_heads = info.get_num_kv_heads(); + const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; uint32_t seq_len = info.kv_len; extern __shared__ uint8_t smem[]; @@ -375,7 +383,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; vec_t freq; - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { freq[i] = rope_rcp_scale * @@ -429,9 +437,9 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk(k_smem + (stage_idx * bdz + tz) * bdy * head_dim, - stage_idx, q_vec, freq, consumer_kv_idx_base, - iter * bdy * bdz, seq_len, s, st_local); + compute_qk( + k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq, + consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local); block.sync(); // load k cp_async::pred_load( @@ -481,7 +489,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( /*! * \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests * \tparam partition_kv Whether to partition kv-cache on sequence length dimension or not - * \tparam rotary_mode The rotary mode + * \tparam pos_encoding_mode The positional encoding mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension @@ -502,12 +510,12 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( * \param rope_rcp_theta A floating number indicate the reciprocal * of "theta" used in RoPE (Rotary Positional Embeddings) */ -template __global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeIn* __restrict__ q, IdType* __restrict__ q_rope_position, + DTypeIn* __restrict__ q, IdType* __restrict__ q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale, @@ -520,6 +528,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( const uint32_t kv_head_idx = blockIdx.y; const uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; const uint32_t num_qo_heads = gridDim.y * bdy; + const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; const uint32_t cur_chunk_start = partition_kv ? 
kv_partition_info.chunk_start_pos[batch_idx] : 0U; const uint32_t cur_page_indptr_begin = paged_kv.indptr[batch_idx], cur_page_indptr_end = paged_kv.indptr[batch_idx + 1]; @@ -546,7 +555,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; vec_t freq; - if constexpr (rotary_mode == RotaryMode::kLlama) { + int32_t q_offset_val = q_offset == nullptr ? (seq_len - 1) : q_offset[mapped_batch_idx]; + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { freq[i] = rope_rcp_scale * @@ -555,8 +565,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( } // apply rotary embedding to q matrix q_vec = vec_apply_llama_rope( - q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, - q_rope_position == nullptr ? (seq_len - 1) : q_rope_position[mapped_batch_idx]); + q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, q_offset_val); } else { // do not apply rotary embedding to q matrix q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); @@ -631,12 +640,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - compute_qk( + compute_qk( k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec, freq, (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) + cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz, - iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, s, st); + iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, q_offset_val, alibi_slope, s, st); block.sync(); #pragma unroll @@ -732,17 +741,15 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \param seq_len A integer indicates the sequence length * \param head_dim A integer indicates the head dimension * \param kv_layout The layout of k/v matrices - * \param rotary_mode The rotary mode + * \param pos_encoding_mode The positional encoding mode * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ template -cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& max_grid_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t seq_len, uint32_t head_dim, - QKVLayout kv_layout = QKVLayout::kNHD, - RotaryMode rotary_mode = RotaryMode::kNone, - cudaStream_t stream = nullptr) { +cudaError_t SingleDecodeWithKVCacheWorkEstimation( + uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t seq_len, uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, cudaStream_t stream = nullptr) { const uint32_t GROUP_SIZE = num_qo_heads / num_kv_heads; if (seq_len <= 256U) { tmp_size = 0; @@ -751,8 +758,8 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, - {DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { + {DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; @@ -767,10 +774,9 @@ cudaError_t 
SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& head_dim * sizeof(DTypeIn) + 2U * bdy * bdz * sizeof(float); - auto kernel = - SingleDecodeWithKVCacheKernel; + auto kernel = SingleDecodeWithKVCacheKernel< + KV_LAYOUT, /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, + tile_size_per_bdx, vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -806,7 +812,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t& * \param seq_len A integer indicates the sequence length * \param head_dim A integer indicates the head dimension * \param kv_layout The layout of q/k/v matrices - * \param rotary_mode The rotary mode + * \param pos_encoding_mode The positional encoding mode * \param rope_scale The scaling factor used in RoPE Interpolation * \param rope_theta The theta used in RoPE * \param stream The cuda stream to launch the kernel @@ -816,7 +822,7 @@ template cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, - RotaryMode rotary_mode = RotaryMode::kNone, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { @@ -834,8 +840,8 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, - {DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { + {DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; @@ -852,10 +858,9 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut 2U * bdy * bdz * sizeof(float); if (seq_len <= 256 || tmp == nullptr) { // no need to use partition-kv kernel - auto kernel = - SingleDecodeWithKVCacheKernel; + auto kernel = SingleDecodeWithKVCacheKernel< + KV_LAYOUT, /*partition_kv=*/false, POS_ENCODING_MODE, num_stages_smem, + tile_size_per_bdx, vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -875,10 +880,9 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // use partition-kv kernel - auto kernel = - SingleDecodeWithKVCacheKernel; + auto kernel = SingleDecodeWithKVCacheKernel< + KV_LAYOUT, /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, + tile_size_per_bdx, vec_size, bdx, bdy, bdz, DTypeIn, DTypeOut>; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -1057,7 +1061,7 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \param new_batch_size The new batch size after the partition * \param paged_kv The paged kv cache data structure * \param num_qo_heads A integer indicates the number of heads of query and output - * \param rotary_mode The rotary mode + * \param pos_encoding_mode The positional encoding mode * \param stream The cuda stream to launch the kernel * \return status Indicates 
whether CUDA calls are successful */ @@ -1067,11 +1071,12 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim, const uint32_t page_size, - const RotaryMode rotary_mode = RotaryMode::kNone, cudaStream_t stream = nullptr) { + const PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + cudaStream_t stream = nullptr) { DISPATCH_GQA_GROUP_SIZE( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_HEAD_DIM( - head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { + head_dim, HEAD_DIM, {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; @@ -1087,8 +1092,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( 2 * bdy * bdz * sizeof(float)); auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< - /*partition_kv=*/true, ROTARY_MODE, num_stages_smem, tile_size_per_bdx, vec_size, - bdx, bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>; + /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, + vec_size, bdx, bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -1131,10 +1136,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation( } template + PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeIn* q, IdType* q_rope_position, - paged_kv_t paged_kv, + DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; @@ -1160,13 +1164,13 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( dim3 nblks(batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); auto kernel = - BatchDecodeWithPagedKVCacheKernel; + BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, - (void*)&q_rope_position, + (void*)&q_offset, (void*)&paged_kv, (void*)&kv_partition_info, (void*)&o, @@ -1179,13 +1183,13 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( } else { // use partition-kv kernel auto partition_kv_kernel = - BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, - (void*)&q_rope_position, + (void*)&q_offset, (void*)&paged_kv, (void*)&kv_partition_info, (void*)&o, @@ -1218,7 +1222,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( * \param tmp Used-allocated temporary buffer * \param lse The logsumexp values. * \param num_qo_heads A integer indicates the number of heads of query and output - * \param rotary_mode The rotary mode + * \param pos_encoding_mode The positional encoding mode * \param rope_scale The scaling ratio used in RoPE Interpolation. 
* \param rope_theta A floating point number indicate the "theta" used in RoPE * \param stream The cuda stream to launch the kernel @@ -1227,10 +1231,9 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( template cudaError_t BatchDecodeWithPagedKVCache( - DTypeIn* q, IdType* q_rope_position, - paged_kv_t paged_kv, + DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, - uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, + uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t num_kv_heads = paged_kv.num_heads; @@ -1247,19 +1250,19 @@ cudaError_t BatchDecodeWithPagedKVCache( DISPATCH_GQA_GROUP_SIZE( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_HEAD_DIM( - head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { + head_dim, HEAD_DIM, {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { return BatchDecodeWithPagedKVCacheDispatched( - q, q_rope_position, paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, + kv_layout, POS_ENCODING_MODE, DTypeIn, + DTypeOut, IdType>( + q, q_offset, paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta, stream); })})}); return cudaSuccess; } -template +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, @@ -1282,7 +1285,7 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType dim3 nblks(batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); - auto kernel = BatchDecodeWithPaddedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -1301,15 +1304,16 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType } template -cudaError_t BatchDecodeWithPaddedKVCache( - DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, DTypeOut* tmp, float* lse, uint32_t batch_size, - uint32_t padded_kv_len, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, - std::optional sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, - cudaStream_t stream = nullptr) { - if (!sm_scale.has_value()) { - sm_scale = 1.f / std::sqrt(float(head_dim)); - } +cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, + DTypeOut* tmp, float* lse, uint32_t batch_size, + uint32_t padded_kv_len, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim, + QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, + cudaStream_t stream = nullptr) { + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " @@ -1321,10 +1325,10 @@ cudaError_t BatchDecodeWithPaddedKVCache( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, - {DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { + 
{DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { return BatchDecodeWithPaddedKVCacheDispatched( + POS_ENCODING_MODE, DTypeIn, DTypeOut>( q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, stream); })})})}); diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index eaa1b68f8..61c7e8f51 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -21,7 +21,7 @@ #include #include -#include "../rope.cuh" +#include "../pos_enc.cuh" #include "../utils.cuh" #include "decode.cuh" @@ -81,7 +81,7 @@ class BatchDecodeHandler { cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, - RotaryMode rotary_mode) { + PosEncodingMode pos_encoding_mode) { batch_size_before_partition_ = batch_size; uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; auto work_estimation_func = @@ -89,7 +89,7 @@ class BatchDecodeHandler { IdType>; FLASHINFER_CUDA_CALL(work_estimation_func( tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, - num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_)); + num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode, stream_)); batch_size_after_partition_ = new_batch_size; if (tmp_size > 0) { AlignedAlloactor allocator(buffer, workspace_size_in_bytes); diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 4a17d066f..2b93adc18 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -32,7 +32,7 @@ #include "../mma.cuh" #include "../page.cuh" #include "../permuted_smem.cuh" -#include "../rope.cuh" +#include "../pos_enc.cuh" #include "../utils.cuh" #include "cascade.cuh" #include "state.cuh" @@ -109,10 +109,10 @@ template (q_rope_position[offset]), - static_cast(q_rope_position[offset + (8 / group_size)])}; + float pos[2] = {static_cast(q_offset[offset]), + static_cast(q_offset[offset + (8 / group_size)])}; #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; @@ -350,8 +350,8 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_multiply_sm_scale( template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - const uint32_t q_idx_base, const IdType* q_rope_position, smem_t* q_smem, - uint32_t* q_smem_offset_r, float (*rope_freq)[4], const float sm_scale) { + const uint32_t q_idx_base, const IdType* q_offset, smem_t* q_smem, uint32_t* q_smem_offset_r, + float (*rope_freq)[4], const float sm_scale) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b(); const uint32_t tx = threadIdx.x; @@ -368,8 +368,8 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos_multiply_sm q_smem->advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); frag_apply_llama_rope_with_pos( - (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], q_idx, - q_rope_position, sm_scale); + (DTypeIn*)q_frag_local[0], (DTypeIn*)q_frag_local[1], rope_freq[fyi], q_idx, q_offset, + sm_scale); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); 
q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = @@ -522,6 +522,28 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offs *k_smem_offset_r -= num_frags_y * 2; } +template +__device__ __forceinline__ void apply_alibi_bias(const uint32_t qo_idx_base, + const uint32_t kv_idx_base, const int32_t q_offset, + float (*alibi_slope)[2], + T (*s_frag)[num_frags_z][8]) { + const int32_t tx = threadIdx.x; +#pragma unroll + for (int32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (int32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (int32_t reg_id = 0; reg_id < 8; ++reg_id) { + const int32_t q_idx = + qo_idx_base + (fx * 16 + tx / 4 + 8 * ((reg_id % 4) / 2)) / group_size, + kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; + s_frag[fx][fz][reg_id] += + T(alibi_slope[fx][(reg_id % 4) / 2]) * T(kv_idx - q_idx - q_offset); + } + } + } +} + template __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_t kv_idx_base, @@ -841,7 +863,7 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * \tparam group_size The number of qo heads that maps to a kv head (used in GQA). * \tparam causal Whether to use causal attention. * \tparam kv_layout The layout of the input tensor. - * \tparam rotary_mode The rotary mode. + * \tparam pos_encoding_mode The positional encoding mode. * \tparam num_frags_x The number of fragments in x dimension. * \tparam num_frags_y The number of fragments in y dimension. * \tparam num_frags_z The number of fragments in z dimension. @@ -862,8 +884,9 @@ __device__ __forceinline__ void write_o_reg_gmem(float (*o_frag)[num_frags_y][8] * used in RoPE. */ template + PosEncodingMode pos_encoding_mode, uint32_t num_frags_x, uint32_t num_frags_y, + uint32_t num_frags_z, uint32_t num_warps, typename DTypeIn, typename DTypeQKAccum, + typename DTypeOut> __global__ void SinglePrefillWithKVCacheKernel( DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, DTypeOut* __restrict__ o, void* __restrict__ tmp, float* __restrict__ lse, @@ -876,6 +899,20 @@ __global__ void SinglePrefillWithKVCacheKernel( const uint32_t kv_len = qkv_info.kv_len; const uint32_t tx = threadIdx.x, ty = threadIdx.y; const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z; + float alibi_slopes[num_frags_x][2]; + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_head_idx = + kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % group_size; + const uint32_t num_qo_heads = qkv_info.get_num_qo_heads(); + alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; + } + } + } + const uint32_t num_chunks = gridDim.y; const uint32_t chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; const uint32_t chunk_start = partition_kv ? 
chunk_idx * chunk_size : 0; @@ -896,7 +933,7 @@ __global__ void SinglePrefillWithKVCacheKernel( DTypeQKAccum m[num_frags_x][2]; float d[num_frags_x][2]; float rope_freq[num_frags_y / 2][4]; - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } init_states(o_frag, m, d); @@ -925,7 +962,7 @@ __global__ void SinglePrefillWithKVCacheKernel( cp_async::wait_group<0>(); block.sync(); - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); @@ -973,7 +1010,7 @@ __global__ void SinglePrefillWithKVCacheKernel( cp_async::wait_group<1>(); block.sync(); - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( chunk_start + iter * 16 * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); @@ -983,6 +1020,11 @@ __global__ void SinglePrefillWithKVCacheKernel( compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { + apply_alibi_bias( + qo_idx_base, chunk_start + iter * 16 * num_frags_z, int(kv_len) - int(qo_len), + alibi_slopes, s_frag); + } // apply mask if (iter >= mask_iteration) { mask_s( @@ -1046,13 +1088,13 @@ __global__ void SinglePrefillWithKVCacheKernel( } } -template __global__ void BatchPrefillWithRaggedKVCacheKernel( DTypeIn* __restrict__ q, IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, IdType* __restrict__ qo_indptr, DTypeIn* __restrict__ k, - DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, IdType* __restrict__ q_rope_position, + DTypeIn* __restrict__ v, IdType* __restrict__ kv_indptr, IdType* __restrict__ q_offset, IdType* __restrict__ k_rope_pos_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, const uint32_t batch_size, float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { @@ -1070,6 +1112,19 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx]; const tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); + float alibi_slopes[num_frags_x][2]; + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_head_idx = + kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % group_size; + const uint32_t num_qo_heads = qkv_info.get_num_qo_heads(); + alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; + } + } + } const uint32_t qo_upper_bound = min(qo_len, (tile_idx + 1) * (num_rows_per_cta / group_size)); constexpr bool partition_kv = false; @@ -1087,7 +1142,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( float d[num_frags_x][2]; float rope_freq[num_frags_y / 2][4]; - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } init_states(o_frag, m, d); @@ -1114,16 +1169,16 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( cp_async::wait_group<0>(); block.sync(); - if constexpr (rotary_mode == RotaryMode::kLlama) { - if 
(!q_rope_position) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if (!q_offset) { q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } else { q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - qo_indptr[request_idx] + qo_idx_base, q_rope_position, &qo_smem, &q_smem_offset_r, - rope_freq, sm_scale); + qo_indptr[request_idx] + qo_idx_base, q_offset, &qo_smem, &q_smem_offset_r, rope_freq, + sm_scale); } } else { q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); @@ -1169,7 +1224,7 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( cp_async::wait_group<1>(); block.sync(); - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( (k_rope_pos_offset == nullptr ? 0 : k_rope_pos_offset[request_idx]) + iter * 16 * num_frags_z, @@ -1181,6 +1236,11 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { + // TODO(Zihao): handle the case that q_offset is specified + apply_alibi_bias( + qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); + } // apply mask if (iter >= mask_iteration) { mask_s( @@ -1235,14 +1295,14 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel( } } -template __global__ void BatchPrefillWithPagedKVCacheKernel( IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices, DTypeIn* __restrict__ q, paged_kv_t paged_kv, - IdType* __restrict__ qo_indptr, IdType* __restrict__ q_rope_position, DTypeOut* __restrict__ o, + IdType* __restrict__ qo_indptr, IdType* __restrict__ q_offset, DTypeOut* __restrict__ o, float* __restrict__ tmp, float* __restrict__ lse, float sm_scale, const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { static_assert(sizeof(DTypeIn) == 2); @@ -1252,6 +1312,18 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( const uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y, kv_head_idx = blockIdx.z; const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; + float alibi_slopes[num_frags_x][2]; + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_head_idx = + kv_head_idx * group_size + (tx / 4 + j * 8 + fx * 16) % group_size; + alibi_slopes[fx][j] = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; + } + } + } const uint32_t request_idx = request_indices[bx], tile_idx = tile_indices[bx]; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; const uint32_t qo_len = qo_indptr[request_idx + 1] - qo_indptr[request_idx], @@ -1276,7 +1348,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( float d[num_frags_x][2]; float rope_freq[num_frags_y / 2][4]; - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } init_states(o_frag, m, d); @@ -1301,16 +1373,16 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( cp_async::wait_group<0>(); block.sync(); - if constexpr (rotary_mode == RotaryMode::kLlama) { - if (q_rope_position == nullptr) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + if (q_offset == nullptr) { 
q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } else { q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( - qo_indptr[request_idx] + qo_idx_base, q_rope_position, &qo_smem, &q_smem_offset_r, - rope_freq, sm_scale); + qo_indptr[request_idx] + qo_idx_base, q_offset, &qo_smem, &q_smem_offset_r, rope_freq, + sm_scale); } } else { q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); @@ -1351,7 +1423,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( cp_async::wait_group<1>(); block.sync(); - if constexpr (rotary_mode == RotaryMode::kLlama) { + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[request_idx]) + iter * 16 * num_frags_z, @@ -1363,6 +1435,11 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { + // TODO(Zihao): handle the case that q_offset is specified + apply_alibi_bias( + qo_idx_base, iter * 16 * num_frags_z, int(kv_len) - int(qo_len), alibi_slopes, s_frag); + } // apply mask if (iter >= mask_iteration) { mask_s( @@ -1433,7 +1510,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( * \param head_dim The dimension of each head. * \param causal Whether to use causal attention. * \param kv_layout The layout of KV Cache. - * \param rotary_mode The rotary mode. + * \param pos_encoding_mode The positional encoding mode. * \param allow_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. * \param stream The cuda stream to execute the kernel on. * \return status Indicates whether CUDA calls are successful @@ -1442,7 +1519,8 @@ template cudaError_t SinglePrefillWithKVCacheWorkEstimation( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, - QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, + QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, cudaStream_t stream = nullptr) { if (kv_len < qo_len && causal) { std::ostringstream err_msg; @@ -1461,8 +1539,9 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( {DISPATCH_CAUSAL( causal, CAUSAL, {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t num_frags_y = HEAD_DIM / 16; - DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { + DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, pos_encoding_mode, + {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { using DTypeQKAccum = typename std::conditional::value, @@ -1480,7 +1559,8 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( constexpr uint32_t num_warps = 4UL; const uint32_t max_num_frags_z_reg = (HEAD_DIM == 128 && num_frags_x == 2 && - ROTARY_MODE == RotaryMode::kLlama && !allow_fp16_qk_reduction) + pos_encoding_mode == PosEncodingMode::kRoPELlama && + !allow_fp16_qk_reduction) ? 
2 : 4; const uint32_t max_num_frags_z_smem = @@ -1510,8 +1590,8 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( num_frags_x * num_warps * 16; auto partition_kv_kernel = SinglePrefillWithKVCacheKernel< /*partition_kv=*/true, GROUP_SIZE, CAUSAL, KV_LAYOUT, - ROTARY_MODE, num_frags_x, num_frags_y, num_frags_z, num_warps, - DTypeIn, DTypeQKAccum, DTypeOut>; + pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, + num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; tensor_info_t qkv_info( qo_len, kv_len, num_kv_heads); uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * @@ -1560,8 +1640,9 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation( return cudaSuccess; } -template +template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, @@ -1592,7 +1673,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* constexpr uint32_t num_warps = 4UL; const uint32_t max_num_frags_z_reg = - (HEAD_DIM == 128 && num_frags_x == 2 && ROTARY_MODE == RotaryMode::kLlama && + (HEAD_DIM == 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 2 : 4; @@ -1616,7 +1697,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps * 16; auto partition_kv_kernel = SinglePrefillWithKVCacheKernel; tensor_info_t qkv_info(qo_len, kv_len, num_kv_heads); uint32_t smem_size = @@ -1642,10 +1723,9 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv - auto kernel = - SinglePrefillWithKVCacheKernel; + auto kernel = SinglePrefillWithKVCacheKernel< + /*partition_kv=*/false, GROUP_SIZE, CAUSAL, KV_LAYOUT, pos_encoding_mode, num_frags_x, + num_frags_y, num_frags_z, num_warps, DTypeIn, DTypeQKAccum, DTypeOut>; void* args[] = {(void*)&q, (void*)&k, (void*)&v, @@ -1707,7 +1787,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* * \param head_dim The dimension of each head. * \param causal Whether to use causal attention. * \param kv_layout The layout of input and output. - * \param rotary_mode The rotary mode. + * \param pos_encoding_mode The positional encoding mode. * \param allow_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. * \param rope_scale The scaling factor used in RoPE interpolation. * \param rope_theta The theta used in RoPE. 
@@ -1715,12 +1795,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* * \return status Indicates whether CUDA calls are successful */ template -cudaError_t SinglePrefillWithKVCache( - DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, - QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, - bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, - float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { +cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, + float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool allow_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, + cudaStream_t stream = nullptr) { const uint32_t group_size = num_qo_heads / num_kv_heads; const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); DISPATCH_ALLOW_FP16_QK_REDUCTION( @@ -1731,11 +1814,11 @@ cudaError_t SinglePrefillWithKVCache( causal, CAUSAL, {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, - {DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { + {DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, pos_encoding_mode, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { SinglePrefillWithKVCacheDispatched( + pos_encoding_mode, + ALLOW_FP16_QK_REDUCTION, CAUSAL>( q, k, v, o, tmp, lse, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, stream); })})})})})}); @@ -1743,11 +1826,11 @@ cudaError_t SinglePrefillWithKVCache( } template + PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, + typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, - DTypeIn* v, IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, + DTypeIn* v, IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr) { @@ -1771,7 +1854,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( const int max_smem_per_threadblock = max_smem_per_sm / 2; const uint32_t max_num_frags_z_reg = - (HEAD_DIM == 128 && num_frags_x == 2 && ROTARY_MODE == RotaryMode::kLlama && + (HEAD_DIM == 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 
2 : 4; @@ -1790,7 +1873,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = - BatchPrefillWithRaggedKVCacheKernel; uint32_t smem_size = @@ -1804,7 +1887,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( (void*)&k, (void*)&v, (void*)&kv_indptr, - (void*)&q_rope_position, + (void*)&q_offset, (void*)&k_rope_pos_offset, (void*)&o, (void*)&tmp, @@ -1821,13 +1904,13 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( template cudaError_t BatchPrefillWithRaggedKVCache( - DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, - IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, - const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, - const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, - RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, - std::optional maybe_sm_scale = std::nullopt, const float rope_scale = 1.f, - const float rope_theta = 1e4, cudaStream_t stream = nullptr) { + DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, IdType* q_offset, + IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, + const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, + const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t group_size = num_qo_heads / num_kv_heads; const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); @@ -1861,14 +1944,14 @@ cudaError_t BatchPrefillWithRaggedKVCache( {DISPATCH_CAUSAL( causal, CAUSAL, {DISPATCH_HEAD_DIM( - head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { + head_dim, HEAD_DIM, + {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, pos_encoding_mode, { return BatchPrefillWithRaggedKVCacheDispatched< - NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE, + NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( q, request_indices_d, tile_indices_d, qo_indptr, k, v, kv_indptr, - q_rope_position, k_rope_pos_offset, o, tmp, lse, batch_size, - num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, - stream); + q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, + num_kv_heads, sm_scale, rope_scale, rope_theta, stream); })})})})})})}); FLASHINFER_CUDA_CALL(cudaFreeAsync(request_indices_d, stream)); @@ -1890,7 +1973,7 @@ cudaError_t BatchPrefillWithRaggedKVCache( * \param lse The logsumexp value. * \param num_qo_heads The number of query and output heads. * \param causal Whether to use causal attention. - * \param rotary_mode The rotary mode. + * \param pos_encoding_mode The positional encoding mode. * \param allow_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. * \param rope_scale The scaling factor used in RoPE interpolation. * \param rope_theta The theta used in RoPE. @@ -1899,14 +1982,13 @@ cudaError_t BatchPrefillWithRaggedKVCache( * \note This implementation executes requests one by one, which is not efficient. 
*/ template + uint32_t HEAD_DIM, PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, + bool CAUSAL, typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, - IdType* q_rope_position, paged_kv_t paged_kv, - DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, - std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, - float rope_theta = 1e4, cudaStream_t stream = nullptr) { + DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, + paged_kv_t paged_kv, DTypeOut* o, float* tmp, + float* lse, uint32_t num_qo_tiles, std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD; const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; @@ -1930,10 +2012,10 @@ cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched( cudaMemcpyHostToDevice, stream)); FLASHINFER_CUDA_CALL(PagedKVCacheToRaggedTensor(paged_kv, keys, values, kv_indptr, stream)); - BatchPrefillWithRaggedKVCacheDispatched( - q, request_indices, tile_indices, qo_indptr, keys, values, kv_indptr, q_rope_position, + BatchPrefillWithRaggedKVCacheDispatched( + q, request_indices, tile_indices, qo_indptr, keys, values, kv_indptr, q_offset, paged_kv.rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); @@ -1945,14 +2027,14 @@ cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched( } template cudaError_t BatchPrefillWithPagedKVCacheDispatched( - DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, - IdType* q_rope_position, paged_kv_t paged_kv, - DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream) { + DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset, + paged_kv_t paged_kv, DTypeOut* o, float* tmp, + float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; @@ -1976,7 +2058,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( const int max_smem_per_threadblock = max_smem_per_sm / 2; const uint32_t max_num_frags_z_reg = - (HEAD_DIM == 128 && num_frags_x == 2 && ROTARY_MODE == RotaryMode::kLlama && + (HEAD_DIM == 128 && num_frags_x == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 
2 : 4; @@ -1995,7 +2077,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( throw std::invalid_argument(err_msg.str()); } else { auto kernel = BatchPrefillWithPagedKVCacheKernel< - GROUP_SIZE, PAGE_SIZE, CAUSAL, ROTARY_MODE, num_frags_x, num_frags_y, num_frags_z, + GROUP_SIZE, PAGE_SIZE, CAUSAL, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>; uint32_t smem_size = (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn); @@ -2006,7 +2088,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( (void*)&q, (void*)&paged_kv, (void*)&qo_indptr, - (void*)&q_rope_position, + (void*)&q_offset, (void*)&o, (void*)&tmp, (void*)&lse, @@ -2022,12 +2104,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( template cudaError_t BatchPrefillWithPagedKVCache( - DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, + DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, bool causal = true, - RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, - std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, - float rope_theta = 1e4, cudaStream_t stream = nullptr) { + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const uint32_t batch_size = paged_kv.batch_size; @@ -2063,27 +2145,27 @@ cudaError_t BatchPrefillWithPagedKVCache( causal, CAUSAL, {DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, - {DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, + {DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, pos_encoding_mode, {DISPATCH_PAGE_SIZE( paged_kv.page_size, PAGE_SIZE, { if constexpr (PAGE_SIZE == 0) { return BatchPrefillWithPagedKVCacheFallbackDispatched< page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, - ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, + pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(q, request_indices_d, tile_indices_d, - qo_indptr, q_rope_position, paged_kv, o, - tmp, lse, num_qo_tiles, sm_scale, - rope_scale, rope_theta, stream); + qo_indptr, q_offset, paged_kv, o, tmp, lse, + num_qo_tiles, sm_scale, rope_scale, + rope_theta, stream); } else { return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, - HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, + HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices_d, tile_indices_d, qo_indptr, - q_rope_position, paged_kv, o, tmp, lse, num_qo_tiles, - sm_scale, rope_scale, rope_theta, stream); + q, request_indices_d, tile_indices_d, qo_indptr, q_offset, + paged_kv, o, tmp, lse, num_qo_tiles, sm_scale, rope_scale, + rope_theta, stream); } }) diff --git a/include/flashinfer/attention/wrapper.cuh b/include/flashinfer/attention/wrapper.cuh index 73d0fc739..bf4dd2efd 100644 --- a/include/flashinfer/attention/wrapper.cuh +++ b/include/flashinfer/attention/wrapper.cuh @@ -36,7 +36,7 @@ namespace flashinfer { * \param o The output tensor. * \param lse The logsumexp values. * \param num_qo_heads The number of heads. - * \param rotary_mode The rotary mode. + * \param pos_encoding_mode The positional encoding mode. 
* \param rope_scale The scale of rope. * \param rope_theta The theta of rope. * \param stream The CUDA stream. @@ -46,9 +46,9 @@ namespace flashinfer { template cudaError_t BatchDecodeWithPagedKVCacheWrapper( - BatchDecodeHandler* handler, DTypeIn* q, IdType* q_rope_position, + BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, + uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { paged_kv_t new_paged_kv = paged_kv; @@ -73,15 +73,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( throw std::runtime_error(err_msg.str()); } return BatchDecodeWithPagedKVCache( - q, q_rope_position, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, + q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, pos_encoding_mode, maybe_sm_scale, rope_scale, rope_theta, stream); } template + PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, + typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; @@ -105,15 +105,15 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( num_frags_x, NUM_FRAGS_X, {DISPATCH_PAGE_SIZE(paged_kv.page_size, PAGE_SIZE, { if constexpr (PAGE_SIZE == 0) { return BatchPrefillWithPagedKVCacheFallbackDispatched< - page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, + page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse, + q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); } else { return BatchPrefillWithPagedKVCacheDispatched< - page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, - ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse, + page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, + pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( + q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); } })}); @@ -123,9 +123,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( template cudaError_t BatchPrefillWithPagedKVCacheWrapper( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone, + uint32_t num_qo_heads, bool causal = true, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float 
rope_theta = 1e4, cudaStream_t stream = nullptr) { const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); @@ -137,25 +138,25 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( head_dim, HEAD_DIM, {DISPATCH_CAUSAL( causal, CAUSAL, - {DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, + {DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, pos_encoding_mode, {DISPATCH_ALLOW_FP16_QK_REDUCTION( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { return BatchPrefillWithPagedKVCacheWrapperDispatched< - page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, + page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, sm_scale, - rope_scale, rope_theta, stream); + handler, q, qo_indptr, q_offset, paged_kv, o, lse, sm_scale, rope_scale, + rope_theta, stream); })})})})}); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, + IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, const uint32_t batch_size, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream) { float* tmp = nullptr; @@ -177,11 +178,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, { return BatchPrefillWithRaggedKVCacheDispatched( - q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_rope_position, - k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, - rope_scale, rope_theta, stream); + pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, + CAUSAL, DTypeIn, DTypeOut, IdType>( + q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset, + o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, + stream); }); return cudaSuccess; } @@ -192,9 +193,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, - RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, - std::optional maybe_sm_scale = std::nullopt, const float rope_scale = 1.f, - const float rope_theta = 1e4, cudaStream_t stream = nullptr) { + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, + const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); DISPATCH_LAYOUT( kv_layout, KV_LAYOUT, @@ -204,14 +205,14 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( head_dim, HEAD_DIM, {DISPATCH_CAUSAL( causal, CAUSAL, - {DISPATCH_ROTARY_MODE( - rotary_mode, ROTARY_MODE, + {DISPATCH_POS_ENCODING_MODE( + pos_encoding_mode, pos_encoding_mode, {DISPATCH_ALLOW_FP16_QK_REDUCTION( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { return BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE, + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, pos_encoding_mode, 
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr, + handler, q, qo_indptr, k, v, kv_indptr, /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); })})})})})}); diff --git a/include/flashinfer/rope.cuh b/include/flashinfer/pos_enc.cuh similarity index 88% rename from include/flashinfer/rope.cuh rename to include/flashinfer/pos_enc.cuh index 3c794aa17..64d621f27 100644 --- a/include/flashinfer/rope.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef FLASHINFER_ROPE_CUH_ -#define FLASHINFER_ROPE_CUH_ +#ifndef FLASHINFER_POS_ENC_CUH_ +#define FLASHINFER_POS_ENC_CUH_ #include #include "layout.cuh" +#include "math.cuh" #include "utils.cuh" #include "vec_dtypes.cuh" @@ -28,28 +29,37 @@ namespace flashinfer { * \brief An enumeration class that defines different modes for applying RoPE * (Rotary Positional Embeddings). */ -enum class RotaryMode { +enum class PosEncodingMode { // No rotary positional embeddings kNone = 0U, // Apply Llama-style rope. - kLlama = 1U, + kRoPELlama = 1U, + // Apply ALiBi bias + kALiBi = 2U }; /*! - * \brief Convert RotaryMode to string - * \param rotary_mode A RotaryMode value + * \brief Convert PosEncodingMode to string + * \param pos_encoding_mode A PosEncodingMode value */ -inline std::string RotaryModeToString(const RotaryMode& rotary_mode) { - switch (rotary_mode) { - case RotaryMode::kNone: +inline std::string PosEncodingModeToString(const PosEncodingMode& pos_encoding_mode) { + switch (pos_encoding_mode) { + case PosEncodingMode::kNone: return "None"; - case RotaryMode::kLlama: + case PosEncodingMode::kRoPELlama: return "Llama"; + case PosEncodingMode::kALiBi: + return "ALiBi"; default: return "Unknown"; } } +__device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) { + // NOTE(Zihao): here we assume that num_heads is a power of 2 + return math::ptx_exp2(-8. * float(head_idx + 1) / float(num_heads)); +} + /*! * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim], * return thread-local vector @@ -63,7 +73,7 @@ inline std::string RotaryModeToString(const RotaryMode& rotary_mode) { */ template __device__ __forceinline__ vec_t vec_apply_llama_rope( - const T* x, const vec_t& freq, uint32_t offset) { + const T* x, const vec_t& freq, int32_t offset) { constexpr uint32_t head_dim = vec_size * bdx; vec_t permuted_vec, vec; vec.cast_load(x + threadIdx.x * vec_size); @@ -170,4 +180,4 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ } // namespace flashinfer -#endif // FLASHINFER_ROPE_CUH_ +#endif // FLASHINFER_POS_ENC_CUH_ diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 1c778c670..d1ffa5f93 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -178,23 +178,28 @@ } \ } -#define DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) 
\ - switch (rotary_mode) { \ - case RotaryMode::kNone: { \ - constexpr RotaryMode ROTARY_MODE = RotaryMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case RotaryMode::kLlama: { \ - constexpr RotaryMode ROTARY_MODE = RotaryMode::kLlama; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported rotary_mode: " << int(rotary_mode); \ - throw std::invalid_argument(err_msg.str()); \ - } \ +#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \ + switch (pos_encoding_mode) { \ + case PosEncodingMode::kNone: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kRoPELlama: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kALiBi: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \ + throw std::invalid_argument(err_msg.str()); \ + } \ } namespace flashinfer { diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index b5b65ae42..271fd0a64 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -22,7 +22,8 @@ using namespace flashinfer; std::vector batch_decode_with_padded_kv_cache( torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout, - unsigned int rotary_mode, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(k_padded); CHECK_INPUT(v_padded); @@ -59,8 +60,8 @@ std::vector batch_decode_with_padded_kv_cache( static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), /*tmp=*/tmp, /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, batch_size, - padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout, RotaryMode(rotary_mode), - rope_scale, rope_theta, torch_current_stream); + padded_kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout, + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPaddedKVCache failed with error code ", status); return true; @@ -77,7 +78,7 @@ std::vector batch_decode_with_padded_kv_cache( void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, unsigned int page_size, unsigned int rotary_mode, + unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, torch::Tensor empty_data) { // NOTE(zihao): not necessary to be CUDA tensor CHECK_CONTIGUOUS(indptr); @@ -99,7 +100,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, static_cast(indptr.data_ptr()), static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, RotaryMode(rotary_mode)); + num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode)); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); return true; @@ -114,8 +115,9 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_.EndForwa std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, - torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, unsigned int rotary_mode, - float rope_scale, float rope_theta, bool return_lse) { + torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, + unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(paged_kv_data); CHECK_INPUT(paged_kv_indptr); @@ -164,10 +166,11 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( static_cast(paged_kv_last_page_len.data_ptr())); cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &handler_, static_cast(q.data_ptr()), /*q_rope_position=*/nullptr, paged_kv, + &handler_, static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), num_qo_heads, - RotaryMode(rotary_mode), rope_scale, rope_theta, /*stream=*/torch_current_stream); + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", cudaGetErrorString(status)); }); diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 3c0073b38..f43c0f795 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -48,8 +48,9 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_.EndForw std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, bool causal, unsigned int rotary_mode, - bool allow_fp16_qk_reduction, float rope_scale, float rope_theta, bool return_lse) { + torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(paged_kv_data); @@ -107,15 +108,15 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( return DISPATCH_head_dim(head_dim, [&] { DISPATCH_CAUSAL(causal, CAUSAL, { DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { - DISPATCH_ROTARY_MODE(RotaryMode(rotary_mode), ROTARY_MODE, { + DISPATCH_POS_ENCODING_MODE(PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, { cudaError_t status = BatchPrefillWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, + PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( &handler_, static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), - /*q_rope_position=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, rope_scale, - rope_theta, + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, sm_scale, + rope_scale, rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error code ", @@ -167,8 +168,9 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_.EndFor std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, - torch::Tensor kv_indptr, bool causal, unsigned int rotary_mode, bool allow_fp16_qk_reduction, - float rope_scale, float rope_theta, bool return_lse) { + torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(k); @@ -206,18 +208,18 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( return DISPATCH_head_dim(head_dim, [&] { DISPATCH_CAUSAL(causal, CAUSAL, { DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { - DISPATCH_ROTARY_MODE(RotaryMode(rotary_mode), ROTARY_MODE, { + DISPATCH_POS_ENCODING_MODE(PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, { DISPATCH_LAYOUT(kv_layout_, KV_LAYOUT, { cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, - c_type, c_type, int32_t>( + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, + CAUSAL, c_type, c_type, int32_t>( &handler_, static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), - /*q_rope_position=*/nullptr, /*k_rope_pos_offset=*/nullptr, + /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, batch_size, - num_kv_heads, rope_scale, rope_theta, + num_kv_heads, sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ", diff --git a/python/csrc/flashinfer_decl.h b/python/csrc/flashinfer_decl.h index 359d8d9ff..e373b199c 100644 --- a/python/csrc/flashinfer_decl.h +++ b/python/csrc/flashinfer_decl.h @@ -16,36 +16,35 @@ #pragma once #include #include -#include +#include -#define INST_BatchPrefillPagedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ - LAYOUT, ROTARY_MODE) \ - namespace flashinfer { \ - template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched< \ - PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, \ - CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, \ - int32_t* q_rope_position, \ - paged_kv_t paged_kv, T* o, \ - float* lse, float sm_scale, float rope_scale, float rope_theta, \ - cudaStream_t stream); \ +#define INST_BatchPrefillPagedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ + LAYOUT, pos_encoding_mode) \ + namespace flashinfer { \ + template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched< \ + PageStorage::kIndices, LAYOUT, GROUP_SIZE, HEAD_DIM, pos_encoding_mode, \ + ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \ + BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, int32_t* q_offset, \ + paged_kv_t paged_kv, T* o, float* lse, \ + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); \ } -#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ - LAYOUT, ROTARY_MODE) \ - namespace flashinfer { \ - template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \ - GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \ - BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr, \ - int32_t* q_rope_position, int32_t* k_rope_pos_offset, T* o, float* lse, uint32_t batch_size, \ - uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, \ - cudaStream_t stream); \ +#define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ + LAYOUT, pos_encoding_mode) \ + namespace flashinfer { \ + template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \ + GROUP_SIZE, HEAD_DIM, LAYOUT, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, \ + int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, \ + int32_t* kv_indptr, int32_t* q_offset, int32_t* k_rope_pos_offset, T* o, \ + float* lse, uint32_t batch_size, uint32_t num_kv_heads, float sm_scale, \ + float rope_scale, float rope_theta, cudaStream_t stream); \ } #define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \ - ROTARY_MODE) \ + pos_encoding_mode) \ namespace flashinfer { \ template cudaError_t SinglePrefillWithKVCacheDispatched< \ - GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T>( \ + GROUP_SIZE, HEAD_DIM, LAYOUT, pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T>( \ T * q, T* k, T* v, T* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, \ uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); \ } @@ -54,25 +53,26 @@ namespace flashinfer { class BatchPrefillHandler; -template +template 
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, + IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, const uint32_t batch_size, const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream); template + PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, + typename DTypeIn, typename DTypeOut, typename IdType> cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template +template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, float sm_scale, diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index b8060ce19..feedd7524 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -25,14 +25,14 @@ class BatchDecodeHandler; } // namespace flashinfer torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, unsigned int rotary_mode, + torch::Tensor tmp, unsigned int pos_encoding_mode, unsigned int layout, float sm_scale, float rope_scale, float rope_theta); std::vector single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float sm_scale, - float rope_scale, float rope_theta, bool return_lse); + unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + float sm_scale, float rope_scale, float rope_theta, bool return_lse); void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, torch::Tensor append_indptr, torch::Tensor kv_data, @@ -49,7 +49,8 @@ std::vector merge_states(torch::Tensor v, torch::Tensor s); std::vector batch_decode_with_padded_kv_cache( torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout, - unsigned int rotary_mode, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta, + bool return_lse); class BatchDecodeWithPagedKVCachePyTorchWrapper { public: @@ -59,13 +60,13 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int rotary_mode, torch::Tensor empty_data); + unsigned int pos_encoding_mode, torch::Tensor empty_data); void EndForward(); std::vector Forward(torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, - torch::Tensor paged_kv_last_page_len, unsigned int rotary_mode, - float sm_scale, float rope_scale, float rope_theta, - bool return_lse); + torch::Tensor paged_kv_last_page_len, + unsigned int pos_encoding_mode, float sm_scale, + float rope_scale, 
float rope_theta, bool return_lse); private: BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout) @@ -87,7 +88,7 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, - unsigned int rotary_mode, bool allow_fp16_qk_reduction, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); @@ -109,7 +110,7 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { void EndForward(); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, - unsigned int rotary_mode, bool allow_fp16_qk_reduction, + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 5f7707a39..1c8991581 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -21,7 +21,7 @@ using namespace flashinfer; torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, - torch::Tensor tmp, unsigned int rotary_mode, + torch::Tensor tmp, unsigned int pos_encoding_mode, unsigned int layout, float sm_scale, float rope_scale, float rope_theta) { CHECK_INPUT(q); @@ -51,7 +51,8 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, head_dim, - kv_layout, RotaryMode(rotary_mode), rope_scale, rope_theta, torch_current_stream); + kv_layout, PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, + torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SingleDecodeWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); return true; diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 8e194ac64..e8d4f7b51 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -21,8 +21,8 @@ using namespace flashinfer; std::vector single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale, - float rope_theta, bool return_lse) { + unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, + float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); @@ -59,15 +59,17 @@ std::vector single_prefill_with_kv_cache( DISPATCH_CAUSAL(causal, CAUSAL, { DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, { - DISPATCH_ROTARY_MODE(RotaryMode(rotary_mode), ROTARY_MODE, { + DISPATCH_POS_ENCODING_MODE(PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, { cudaError_t status = - SinglePrefillWithKVCacheDispatched( + SinglePrefillWithKVCacheDispatched( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(o.data_ptr()), static_cast(tmp.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, - num_kv_heads, qo_len, kv_len, rope_scale, rope_theta, torch_current_stream); + num_kv_heads, qo_len, kv_len, sm_scale, rope_scale, rope_theta, + torch_current_stream); TORCH_CHECK(status == cudaSuccess, "SinglePrefillWithKVCache kernel launch failed, error: " + std::string(cudaGetErrorString(status))); diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 279222378..3054c04ad 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -40,9 +40,9 @@ ) from .utils import ( expand_5d, - check_rotary_mode, + check_pos_encoding_mode, check_kv_layout, - RotaryMode, + PosEncodingMode, TensorLayout, ) @@ -269,7 +269,7 @@ def batch_decode_with_shared_prefix_padded_kv_cache( k_shared, v_shared, causal=False, - rotary_mode="NONE", + pos_encoding_mode="NONE", kv_layout=kv_layout, allow_fp16_qk_reduction=allow_fp16_qk_reduction, sm_scale=sm_scale, @@ -281,7 +281,7 @@ def batch_decode_with_shared_prefix_padded_kv_cache( k_unique, v_unique, kv_layout=kv_layout, - rotary_mode="NONE", + pos_encoding_mode="NONE", sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, @@ -419,7 +419,7 @@ def begin_forward( The dimension of the heads page_size : int The page size of the paged kv cache - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). data_type : Union[str, torch.dtype] @@ -444,7 +444,7 @@ def begin_forward( num_kv_heads, head_dim, page_size, - rotary_mode="NONE", + pos_encoding_mode="NONE", data_type=data_type, ) @@ -508,7 +508,7 @@ def forward( k_shared, v_shared, causal=False, - rotary_mode="NONE", + pos_encoding_mode="NONE", kv_layout=self._kv_layout, allow_fp16_qk_reduction=allow_fp16_qk_reduction, sm_scale=sm_scale, @@ -518,7 +518,7 @@ def forward( V_unique, S_unique = self._batch_decode_wrapper.forward_return_lse( q, unique_kv_data, - rotary_mode="NONE", + pos_encoding_mode="NONE", sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, @@ -755,7 +755,7 @@ def forward( k_shared, v_shared, causal=False, - rotary_mode="NONE", + pos_encoding_mode="NONE", kv_layout=self._kv_layout, allow_fp16_qk_reduction=allow_fp16_qk_reduction, sm_scale=sm_scale, @@ -766,7 +766,7 @@ def forward( q, unique_kv_data, causal=causal, - rotary_mode="NONE", + pos_encoding_mode="NONE", allow_fp16_qk_reduction=allow_fp16_qk_reduction, sm_scale=sm_scale, rope_scale=rope_scale, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 0742edff8..66ab7ed12 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -32,10 +32,10 @@ from .utils import ( - RotaryMode, + PosEncodingMode, TensorLayout, expand_5d, - check_rotary_mode, + check_pos_encoding_mode, check_kv_layout, ) @@ -56,7 +56,7 @@ def single_decode_with_kv_cache( k: torch.Tensor, v: torch.Tensor, kv_layout: str = "NHD", - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, @@ -77,7 +77,7 @@ def single_decode_with_kv_cache( :attr:`kv_layout` is ``HND``. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). 
sm_scale : Optional[float] @@ -114,7 +114,7 @@ def single_decode_with_kv_cache( not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) check_kv_layout(kv_layout) tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 8 * 1024 * 1024, q.device) if sm_scale is None: @@ -129,8 +129,8 @@ def single_decode_with_kv_cache( k, v, tmp, - getattr(RotaryMode, rotary_mode), - getattr(TensorLayout, kv_layout), + PosEncodingMode[pos_encoding_mode].value, + TensorLayout[kv_layout].value, sm_scale, rope_scale, rope_theta, @@ -142,7 +142,7 @@ def batch_decode_with_padded_kv_cache( k_padded: torch.Tensor, v_padded: torch.Tensor, kv_layout: str = "NHD", - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, @@ -166,7 +166,7 @@ def batch_decode_with_padded_kv_cache( :attr:`kv_layout` is ``HND``. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). sm_scale : Optional[float] @@ -216,8 +216,8 @@ def batch_decode_with_padded_kv_cache( q, k_padded, v_padded, - getattr(TensorLayout, kv_layout), - getattr(RotaryMode, rotary_mode), + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, sm_scale, rope_scale, rope_theta, @@ -230,7 +230,7 @@ def batch_decode_with_padded_kv_cache_return_lse( k_padded: torch.Tensor, v_padded: torch.Tensor, kv_layout: str = "NHD", - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, @@ -255,7 +255,7 @@ def batch_decode_with_padded_kv_cache_return_lse( :attr:`kv_layout` is ``HND``. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). sm_scale : Optional[float] @@ -312,8 +312,8 @@ def batch_decode_with_padded_kv_cache_return_lse( q, k_padded, v_padded, - getattr(TensorLayout, kv_layout), - getattr(RotaryMode, rotary_mode), + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, sm_scale, rope_scale, rope_theta, @@ -365,7 +365,7 @@ class BatchDecodeWithPagedKVCacheWrapper: ... num_kv_heads, ... head_dim, ... page_size, - ... rotary_mode="NONE", + ... pos_encoding_mode="NONE", ... data_type=torch.float16 ... 
) >>> outputs = [] @@ -405,7 +405,7 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchDecodeWithPagedKVCachePyTorchWrapper( - getattr(TensorLayout, kv_layout) + TensorLayout[kv_layout].value ) self._paged_kv_indptr = None self._paged_kv_indices = None @@ -431,7 +431,7 @@ def begin_forward( num_kv_heads: int, head_dim: int, page_size: int, - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", data_type: Union[str, torch.dtype] = "float16", ): r"""Create auxiliary data structures for batch decode for multiple forward calls @@ -454,7 +454,7 @@ def begin_forward( The dimension of the heads page_size : int The page size of the paged kv cache - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). data_type : Union[str, torch.dtype] @@ -491,7 +491,7 @@ def begin_forward( num_kv_heads, head_dim, page_size, - getattr(RotaryMode, rotary_mode), + PosEncodingMode[pos_encoding_mode].value, empty_data, ) @@ -506,7 +506,7 @@ def forward( self, q: torch.Tensor, paged_kv_data: torch.Tensor, - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, @@ -523,7 +523,7 @@ def forward( :attr:`kv_layout` is ``NHD``, or ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). sm_scale : Optional[float] @@ -539,7 +539,7 @@ def forward( torch.Tensor The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. """ - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) if sm_scale is None: head_dim = q.shape[-1] sm_scale = 1.0 / math.sqrt(head_dim) @@ -555,7 +555,7 @@ def forward( self._paged_kv_indptr, self._paged_kv_indices, self._paged_kv_last_page_len, - getattr(RotaryMode, rotary_mode), + PosEncodingMode[pos_encoding_mode].value, sm_scale, rope_scale, rope_theta, @@ -566,7 +566,7 @@ def forward_return_lse( self, q: torch.Tensor, paged_kv_data: torch.Tensor, - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, @@ -584,7 +584,7 @@ def forward_return_lse( :attr:`kv_layout` is ``NHD``, or ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). sm_scale : Optional[float] @@ -607,7 +607,7 @@ def forward_return_lse( Please refer to the :ref:`tutorial ` for a detailed explanation of the log-sum-exp function and attention states. 
""" - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) if sm_scale is None: head_dim = q.shape[-1] sm_scale = 1.0 / math.sqrt(head_dim) @@ -622,7 +622,7 @@ def forward_return_lse( self._paged_kv_indptr, self._paged_kv_indices, self._paged_kv_last_page_len, - getattr(RotaryMode, rotary_mode), + PosEncodingMode[pos_encoding_mode].value, sm_scale, rope_scale, rope_theta, diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index 112ed30a7..6c7c7558c 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -126,5 +126,5 @@ def append_paged_kv_cache( kv_indices, kv_indptr, kv_last_page_len, - getattr(TensorLayout, kv_layout), + TensorLayout[kv_layout].value, ) diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index b5c082da5..bf780c847 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -31,10 +31,10 @@ raise e from .utils import ( - RotaryMode, + PosEncodingMode, TensorLayout, expand_5d, - check_rotary_mode, + check_pos_encoding_mode, check_kv_layout, ) @@ -66,7 +66,7 @@ def single_prefill_with_kv_cache( v: torch.Tensor, causal: bool = False, kv_layout: str = "NHD", - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -91,7 +91,7 @@ def single_prefill_with_kv_cache( Whether to apply causal mask to the attention matrix. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool @@ -133,7 +133,7 @@ def single_prefill_with_kv_cache( not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) check_kv_layout(kv_layout) tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 8 * 1024 * 1024, q.device) if sm_scale is None: @@ -148,8 +148,8 @@ def single_prefill_with_kv_cache( v, tmp, causal, - getattr(TensorLayout, kv_layout), - getattr(RotaryMode, rotary_mode), + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -164,7 +164,7 @@ def single_prefill_with_kv_cache_return_lse( v: torch.Tensor, causal: bool = False, kv_layout: str = "NHD", - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -189,7 +189,7 @@ def single_prefill_with_kv_cache_return_lse( Whether to apply causal mask to the attention matrix. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool @@ -237,7 +237,7 @@ def single_prefill_with_kv_cache_return_lse( not equal to ``num_kv_heads``, the function will use `grouped query attention `_. 
""" - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) check_kv_layout(kv_layout) tmp = _get_cache_buf( "single_prefill_with_kv_cache_return_lse_tmp", 8 * 1024 * 1024, q.device @@ -254,8 +254,8 @@ def single_prefill_with_kv_cache_return_lse( v, tmp, causal, - getattr(TensorLayout, kv_layout), - getattr(RotaryMode, rotary_mode), + TensorLayout[kv_layout].value, + PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -352,7 +352,7 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( - getattr(TensorLayout, kv_layout) + TensorLayout[kv_layout].value ) self._qo_indptr = None self._paged_kv_indptr = None @@ -438,7 +438,7 @@ def forward( q: torch.Tensor, paged_kv_data: torch.Tensor, causal: bool = True, - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -458,7 +458,7 @@ def forward( if :attr:`kv_layout` is ``HND``. causal : bool Whether to apply causal mask to the attention matrix. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool @@ -478,7 +478,7 @@ def forward( torch.Tensor The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: @@ -494,7 +494,7 @@ def forward( self._paged_kv_indices, self._paged_kv_last_page_len, causal, - getattr(RotaryMode, rotary_mode), + PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -507,7 +507,7 @@ def forward_return_lse( q: torch.Tensor, paged_kv_data: torch.Tensor, causal: bool = True, - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -527,7 +527,7 @@ def forward_return_lse( :attr:`kv_layout` is ``HND``. causal : bool Whether to apply causal mask to the attention matrix. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool @@ -550,7 +550,7 @@ def forward_return_lse( The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. 
""" - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: @@ -566,7 +566,7 @@ def forward_return_lse( self._paged_kv_indices, self._paged_kv_last_page_len, causal, - getattr(RotaryMode, rotary_mode), + PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -649,7 +649,7 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithRaggedKVCachePyTorchWrapper( - getattr(TensorLayout, kv_layout) + TensorLayout[kv_layout].value ) self._qo_indptr = None self._kv_indptr = None @@ -723,7 +723,7 @@ def forward( k: torch.Tensor, v: torch.Tensor, causal: bool = True, - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -742,7 +742,7 @@ def forward( The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` causal : bool Whether to apply causal mask to the attention matrix. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool @@ -762,7 +762,7 @@ def forward( torch.Tensor The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: @@ -776,7 +776,7 @@ def forward( v, self._kv_indptr, causal, - getattr(RotaryMode, rotary_mode), + PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, rope_scale, @@ -790,7 +790,7 @@ def forward_return_lse( k: torch.Tensor, v: torch.Tensor, causal: bool = True, - rotary_mode: str = "NONE", + pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -809,7 +809,7 @@ def forward_return_lse( The value tensor, shape: ``[kv_indptr[-1], num_kv_heads, head_dim]`` causal : bool Whether to apply causal mask to the attention matrix. - rotary_mode : str + pos_encoding_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). allow_fp16_qk_reduction : bool @@ -831,7 +831,7 @@ def forward_return_lse( The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. 
""" - check_rotary_mode(rotary_mode) + check_pos_encoding_mode(pos_encoding_mode) if sm_scale is None: sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: @@ -845,7 +845,7 @@ def forward_return_lse( v, self._kv_indptr, causal, - getattr(RotaryMode, rotary_mode), + PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, rope_scale, diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 14beb9be0..133e839be 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -15,21 +15,19 @@ """ import torch +from enum import Enum -class RotaryMode: +class PosEncodingMode(Enum): NONE = 0 - LLAMA = 1 + ROPE_LLAMA = 1 + ALIBI = 2 - FORMAT2STR = {0: "NONE", 1: "LLAMA"} - -class TensorLayout: +class TensorLayout(Enum): NHD = 0 HND = 1 - FORMAT2STR = {0: "NHD", 1: "HND"} - def expand_5d(x: torch.Tensor, kv_layout: str): if not x.ndim in [4, 5]: @@ -47,9 +45,9 @@ def expand_5d(x: torch.Tensor, kv_layout: str): return x -def check_rotary_mode(rotary_mode: str): - if not hasattr(RotaryMode, rotary_mode): - raise KeyError("Invalid rotary_mode {}".format(rotary_mode)) +def check_pos_encoding_mode(pos_encoding_mode: str): + if not hasattr(PosEncodingMode, pos_encoding_mode): + raise KeyError("Invalid pos_encoding_mode {}".format(pos_encoding_mode)) def check_kv_layout(kv_layout: str): diff --git a/python/setup.py b/python/setup.py index 090998cf7..dc3463eb5 100644 --- a/python/setup.py +++ b/python/setup.py @@ -55,7 +55,7 @@ def get_instantiation_cu() -> List[str]: causal_options = [False, True] allow_fp16_qk_reduction_options = [False, True] layout_options = ["HND", "NHD"] - rotary_mode_options = ["None", "Llama"] + pos_encoding_mode_options = ["None", "RoPELlama", "ALiBi"] # dispatch.inc path = root / prefix / "dispatch.inc" @@ -80,7 +80,7 @@ def get_instantiation_cu() -> List[str]: causal, allow_fp16_qk_reduction, layout, - rotary_mode, + pos_encoding_mode, ) in itertools.product( group_sizes, head_dims, @@ -88,10 +88,10 @@ def get_instantiation_cu() -> List[str]: causal_options, allow_fp16_qk_reduction_options, layout_options, - rotary_mode_options, + pos_encoding_mode_options, ): # paged batch prefill - fname = f"paged_batch_prefill_group{group_size}_head{head_dim}_causal{causal}_fp16qk{allow_fp16_qk_reduction}_layout{layout}_rotary{rotary_mode}_{dtype}.cu" + fname = f"paged_batch_prefill_group{group_size}_head{head_dim}_causal{causal}_fp16qk{allow_fp16_qk_reduction}_layout{layout}_pe{pos_encoding_mode}_{dtype}.cu" files.append(prefix + "/" + fname) if not (root / prefix / fname).exists(): with open(root / prefix / fname, "w") as f: @@ -106,12 +106,12 @@ def get_instantiation_cu() -> List[str]: str(causal).lower(), str(allow_fp16_qk_reduction).lower(), "QKVLayout::k" + layout, - "RotaryMode::k" + rotary_mode, + "PosEncodingMode::k" + pos_encoding_mode, ) ) # ragged batch prefill - fname = f"ragged_batch_prefill_group{group_size}_head{head_dim}_causal{causal}_fp16qk{allow_fp16_qk_reduction}_layout{layout}_rotary{rotary_mode}_{dtype}.cu" + fname = f"ragged_batch_prefill_group{group_size}_head{head_dim}_causal{causal}_fp16qk{allow_fp16_qk_reduction}_layout{layout}_pe{pos_encoding_mode}_{dtype}.cu" files.append(prefix + "/" + fname) if not (root / prefix / fname).exists(): with open(root / prefix / fname, "w") as f: @@ -126,12 +126,12 @@ def get_instantiation_cu() -> List[str]: str(causal).lower(), str(allow_fp16_qk_reduction).lower(), "QKVLayout::k" + layout, - "RotaryMode::k" + rotary_mode, + "PosEncodingMode::k" + pos_encoding_mode, ) ) # 
single prefill
-        fname = f"single_prefill_group{group_size}_head{head_dim}_causal{causal}_fp16qk{allow_fp16_qk_reduction}_layout{layout}_rotary{rotary_mode}_{dtype}.cu"
+        fname = f"single_prefill_group{group_size}_head{head_dim}_causal{causal}_fp16qk{allow_fp16_qk_reduction}_layout{layout}_pe{pos_encoding_mode}_{dtype}.cu"
         files.append(prefix + "/" + fname)
         if not (root / prefix / fname).exists():
             with open(root / prefix / fname, "w") as f:
@@ -146,7 +146,7 @@ def get_instantiation_cu() -> List[str]:
                         str(causal).lower(),
                         str(allow_fp16_qk_reduction).lower(),
                         "QKVLayout::k" + layout,
-                        "RotaryMode::k" + rotary_mode,
+                        "PosEncodingMode::k" + pos_encoding_mode,
                     )
                 )
diff --git a/python/tests/alibi_reference.py b/python/tests/alibi_reference.py
new file mode 100644
index 000000000..cb4356b24
--- /dev/null
+++ b/python/tests/alibi_reference.py
@@ -0,0 +1,123 @@
+"""
+Attention with Linear Biases (ALiBi) reference implementation.
+
+Code adapted from https://github.com/labmlai/annotated_deep_learning_paper_implementations
+
+Licensed under MIT, you may obtain a copy of the License at
+
+    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license
+
+Source:
+- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/mha.py
+- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/285cb3735bde02fbc8c19ddeb24d0ae7e77135c1/labml_nn/transformers/alibi/__init__.py
+"""
+
+import torch
+import math
+from torch import nn
+from typing import Optional, List
+
+
+def get_slopes(n_heads: int):
+    """
+    ## Get head-specific slope $m$ for each head
+
+    * `n_heads` is the number of heads in the attention layer $n$
+
+    The slope for the first head is
+
+    $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$
+
+    The slopes for the rest of the heads are in a geometric series with a ratio same as above.
+
+    For instance when the number of heads is $8$ the slopes are
+    $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
+    """
+
+    # Get the closest power of 2 to `n_heads`.
+    # If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2,
+    # and then add the remaining slopes.
+    n = 2 ** math.floor(math.log2(n_heads))
+    # $2^{-\frac{8}{n}}$
+    m_0 = 2.0 ** (-8.0 / n)
+    # $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
+    m = torch.pow(m_0, torch.arange(1, 1 + n))
+
+    # If `n_heads` is not a power of 2, then we add the remaining slopes.
+    # We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously).
+    # And pick the slopes up to `n_heads`.
+    if n < n_heads:
+        # $2^{-\frac{8}{2n}}$
+        m_hat_0 = 2.0 ** (-4.0 / n)
+        # $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$
+        # Note that we take steps by $2$ to avoid slopes added previously.
+        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (n_heads - n), 2))
+        # Concatenate the slopes with the remaining slopes.
+        m = torch.cat([m, m_hat])
+
+    return m
+
+
+@torch.no_grad()
+def get_alibi_biases(n_heads: int, mask: torch.Tensor):
+    """
+    ## Calculate the attention biases matrix
+
+    * `n_heads` is the number of heads in the attention layer
+    * `mask` is the attention mask of shape `[seq_len_q, seq_len_k]`
+
+    This returns a matrix broadcastable to `[seq_len_q, seq_len_k, n_heads]` with ALiBi attention biases.
+ """ + + # Get slopes $m$ for each head + m = get_slopes(n_heads).to(mask.device) + + # Calculate distances $[0, 1, \dots, N]$ + # Here we calculate the distances using the mask. + # + # Since it's causal mask we can just use $[0, 1, \dots, N]$ too. + distance = torch.arange(mask.shape[1], dtype=torch.long, device=mask.device)[ + None, : + ] + + # Multiply them pair-wise to get the AliBi bias matrix + return distance[:, :, None] * m[None, None, :] + + +def alibi_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, +): + """ + query: [q_len, num_heads, head_dim] + key: [kv_len, num_heads, head_dim] + value: [kv_len, num_heads, head_dim] + mask: [q_len, kv_len] + """ + q_len, num_heads, head_dim = query.shape + kv_len = key.shape[0] + + scores = torch.einsum("qhd,khd->qkh", query.float(), key.float()) + # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$ + scores *= 1.0 / math.sqrt(head_dim) + + # Create AliBi biases if it's not cached + alibi_biases = get_alibi_biases(num_heads, mask) + + # Add AliBi biases to attention scores. + # ALiBi biases has shape `[seq_len, seq_len, n_heads]` + # and `scores` has shape `[seq_len, seq_len, batch_size, n_heads]` + scores += alibi_biases + + # Apply mask + scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float("-inf")) + + # $softmax$ attention along the key sequence dimension + # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$ + attn = torch.softmax(scores, dim=1) + + # Multiply by values + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$ + return torch.einsum("ovh,vhd->ohd", attn, value.float()).to(query) diff --git a/python/tests/test_alibi.py b/python/tests/test_alibi.py new file mode 100644 index 000000000..d96d21268 --- /dev/null +++ b/python/tests/test_alibi.py @@ -0,0 +1,78 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import numpy +import pytest +import torch + +import flashinfer + +from alibi_reference import alibi_attention + + +@pytest.mark.parametrize("seq_len", [1, 9, 81, 729, 33001]) +@pytest.mark.parametrize("num_heads", [4, 8, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +def test_single_decode_alibi( + seq_len, + num_heads, + head_dim, +): + q = torch.randn(num_heads, head_dim).to(0).half() + k = torch.randn(seq_len, num_heads, head_dim).to(0).half() + v = torch.randn(seq_len, num_heads, head_dim).to(0).half() + + o = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ALIBI") + mask = torch.ones(1, seq_len, dtype=torch.bool).to(0) + o_ref = alibi_attention(q.unsqueeze(0), k, v, mask).squeeze(0) + numpy.testing.assert_allclose( + o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.parametrize("q_len", [1, 17, 81, 987]) +@pytest.mark.parametrize("kv_len", [1, 17, 81, 987, 31111]) +@pytest.mark.parametrize("num_heads", [4, 8, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("causal", [False, True]) +def test_single_prefill_alibi( + q_len, + kv_len, + num_heads, + head_dim, + causal, +): + if causal and q_len > kv_len: + pytest.skip("Causal attention requires q_len <= kv_len") + q = torch.randn(q_len, num_heads, head_dim).to(0).half() + k = torch.randn(kv_len, num_heads, head_dim).to(0).half() + v = torch.randn(kv_len, num_heads, head_dim).to(0).half() + + o = flashinfer.single_prefill_with_kv_cache( + q, k, v, causal, pos_encoding_mode="ALIBI" + ) + mask = torch.ones(q_len, kv_len, dtype=torch.bool).to(0) + if causal: + mask = torch.tril(mask, diagonal=kv_len - q_len) + o_ref = alibi_attention(q, k, v, mask) + numpy.testing.assert_allclose( + o.cpu().numpy(), o_ref.cpu().numpy(), rtol=1e-2, atol=1e-2 + ) + + +if __name__ == "__main__": + test_single_decode_alibi(9, 32, 128) + test_single_prefill_alibi(1, 64, 1, 128, False) diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index 0b83d2de2..213ceea30 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -29,6 +29,7 @@ @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) def test_batch_decode_with_paged_kv_cache( batch_size, kv_len, @@ -38,6 +39,7 @@ def test_batch_decode_with_paged_kv_cache( num_qo_heads, head_dim, kv_layout, + pos_encoding_mode, ): q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half() num_pages_per_seq = (kv_len + page_size - 1) // page_size @@ -68,7 +70,7 @@ def test_batch_decode_with_paged_kv_cache( "NONE", "float16", ) - o = wrapper.forward(q, kv_data) + o = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] @@ -104,12 +106,14 @@ def test_batch_decode_with_paged_kv_cache( ], dim=0, ) - o_ref_i = flashinfer.single_decode_with_kv_cache(qi, ki, vi) + o_ref_i = flashinfer.single_decode_with_kv_cache( + qi, ki, vi, pos_encoding_mode=pos_encoding_mode + ) o_i_np = o[i].cpu().numpy() o_ref_i_np = o_ref_i.cpu().numpy() numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) if __name__ == "__main__": - test_batch_decode_with_paged_kv_cache(12, 54, 37, 8, 8, 8, 128, "HND") - test_batch_decode_with_paged_kv_cache(12, 54, 37, 1, 8, 8, 128, 
"HND") + test_batch_decode_with_paged_kv_cache(12, 54, 37, 8, 8, 8, 128, "HND", "NONE") + test_batch_decode_with_paged_kv_cache(12, 54, 37, 1, 8, 8, 128, "HND", "NONE") diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index e1e4f7ba9..1bb2819c9 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -30,6 +30,7 @@ @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) def test_batch_prefill_with_paged_kv_cache( batch_size, kv_len, @@ -40,6 +41,7 @@ def test_batch_prefill_with_paged_kv_cache( head_dim, causal, kv_layout, + pos_encoding_mode, ): q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len @@ -71,7 +73,7 @@ def test_batch_prefill_with_paged_kv_cache( num_kv_heads, head_dim, ) - o = wrapper.forward(q, kv_data, causal=causal) + o = wrapper.forward(q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] @@ -107,7 +109,9 @@ def test_batch_prefill_with_paged_kv_cache( ], dim=0, ) - o_ref_i = flashinfer.single_prefill_with_kv_cache(qi, ki, vi, causal=causal) + o_ref_i = flashinfer.single_prefill_with_kv_cache( + qi, ki, vi, causal=causal, pos_encoding_mode=pos_encoding_mode + ) o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() o_ref_i_np = o_ref_i.cpu().numpy() numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) @@ -120,8 +124,16 @@ def test_batch_prefill_with_paged_kv_cache( @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) def test_batch_prefill_with_ragged_kv_cache( - batch_size, kv_len, qo_len, num_kv_heads, num_qo_heads, head_dim, causal + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + pos_encoding_mode, ): kv_layout = "NHD" q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() @@ -142,7 +154,7 @@ def test_batch_prefill_with_ragged_kv_cache( num_kv_heads, head_dim, ) - o = wrapper.forward(q, k, v, causal=causal) + o = wrapper.forward(q, k, v, causal=causal, pos_encoding_mode=pos_encoding_mode) for i in range(batch_size): o_ref_i = flashinfer.single_prefill_with_kv_cache( @@ -150,6 +162,7 @@ def test_batch_prefill_with_ragged_kv_cache( k[kv_indptr[i] : kv_indptr[i + 1]], v[kv_indptr[i] : kv_indptr[i + 1]], causal=causal, + pos_encoding_mode=pos_encoding_mode, ) o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() o_ref_i_np = o_ref_i.cpu().numpy() @@ -157,6 +170,10 @@ def test_batch_prefill_with_ragged_kv_cache( if __name__ == "__main__": - test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 8, 128, True, "HND") - test_batch_prefill_with_paged_kv_cache(12, 54, 37, 1, 8, 8, 128, True, "HND") - test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True) + test_batch_prefill_with_paged_kv_cache( + 12, 54, 37, 8, 8, 8, 128, True, "HND", "NONE" + ) + test_batch_prefill_with_paged_kv_cache( + 12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE" + ) + test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True, "NONE") diff --git 
a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index e5b9445e5..160fe3a87 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -30,7 +30,7 @@ constexpr QKVLayout kv_layout = QKVLayout::kNHD; template void bench_flashinfer_batch_decode(nvbench::state& state) { constexpr size_t head_dim = 128; - constexpr auto rotary_mode = RotaryMode::kNone; + constexpr auto pos_encoding_mode = PosEncodingMode::kNone; size_t seqlen = state.get_int64("seqlen"); size_t batch_size = state.get_int64("batch_size"); size_t page_size = state.get_int64("page_size"); @@ -76,12 +76,12 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { handler.BeginForward( (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, rotary_mode); + head_dim, page_size, pos_encoding_mode); state.exec([&](nvbench::launch&) { cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q.data()), /*q_rope_position=*/nullptr, paged_kv, - thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, rotary_mode); + &handler, thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); } @@ -90,9 +90,9 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { state.exec([&](nvbench::launch&) { cudaError_t status = BatchDecodeWithPagedKVCache( - thrust::raw_pointer_cast(q.data()), /*q_rope_position=*/nullptr, paged_kv, + thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, kv_partition_info_t(), thrust::raw_pointer_cast(o.data()), nullptr, - /*lse=*/nullptr, num_qo_heads, rotary_mode); + /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); } @@ -103,7 +103,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { template void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { constexpr size_t head_dim = 128; - constexpr auto rotary_mode = RotaryMode::kNone; + constexpr auto pos_encoding_mode = PosEncodingMode::kNone; size_t seqlen = state.get_int64("seqlen"); size_t batch_size = state.get_int64("batch_size"); size_t page_size = state.get_int64("page_size"); @@ -157,9 +157,9 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), - /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), + /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, - /*causal=*/false, rotary_mode); + /*causal=*/false, pos_encoding_mode); }); } diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index d5e939f0a..5e6d64803 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -112,7 +112,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { cascade_handler.BeginForward( (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, RotaryMode::kNone); + num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); state.exec(nvbench::exec_tag::timer, 
[&](nvbench::launch& launch, auto& timer) { timer.start(); @@ -123,7 +123,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::raw_pointer_cast(lse_cascade_0_d.data()), num_qo_heads, num_kv_heads, /*qo_len=*/batch_size, /*kv_len=*/shared_prefix_length, head_dim, /*causal=*/false, /*kv_layout=*/QKVLayout::kNHD, - /*rotary_mode=*/RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); + /*pos_encoding_mode=*/PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); if (status != cudaSuccess) { state.skip("Cascade implementation prefill failed with error: " + @@ -132,10 +132,9 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { status = BatchDecodeWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), - /*q_rope_position=*/nullptr, paged_kv_casacde_d, - thrust::raw_pointer_cast(o_cascade_1_d.data()), + /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, - RotaryMode::kNone); + PosEncodingMode::kNone); if (status != cudaSuccess) { state.skip("Cascade implementation decode failed with error: " + @@ -170,16 +169,16 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { baseline_handler.BeginForward( (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, RotaryMode::kNone); + num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), - /*q_rope_position=*/nullptr, paged_kv_baseline_d, + /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), - /*lse=*/nullptr, num_qo_heads, RotaryMode::kNone); + /*lse=*/nullptr, num_qo_heads, PosEncodingMode::kNone); if (status != cudaSuccess) { state.skip("Cascade implementation decode failed with error: " + std::string(cudaGetErrorString(status))); @@ -259,7 +258,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { /*qo_len=*/batch_size * qo_append_length, /*kv_len=*/shared_prefix_length, head_dim, /*causal=*/false, /*kv_layout=*/QKVLayout::kNHD, - /*rotary_mode=*/RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); + /*pos_encoding_mode=*/PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); if (status != cudaSuccess) { state.skip("Cascade implementation prefill failed with error: " + @@ -269,10 +268,9 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { status = BatchPrefillWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), - /*q_rope_position=*/nullptr, paged_kv_casacde_d, - thrust::raw_pointer_cast(o_cascade_1_d.data()), + /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, /*causal=*/true, - RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); + PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); if (status != cudaSuccess) { state.skip("Cascade implementation unique kv prefill failed with error: " + @@ -312,9 +310,9 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { 
BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), - /*q_rope_position=*/nullptr, paged_kv_baseline_d, + /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), - /*lse=*/nullptr, num_qo_heads, /*causal=*/true, RotaryMode::kNone, + /*lse=*/nullptr, num_qo_heads, /*causal=*/true, PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); if (status != cudaSuccess) { diff --git a/src/bench_single_decode.cu b/src/bench_single_decode.cu index d157e6eee..9b93b92f9 100644 --- a/src/bench_single_decode.cu +++ b/src/bench_single_decode.cu @@ -19,8 +19,8 @@ #include #include +using flashinfer::PosEncodingMode; using flashinfer::QKVLayout; -using flashinfer::RotaryMode; template void bench_flashinfer_single_decode(nvbench::state& state) { @@ -28,7 +28,7 @@ void bench_flashinfer_single_decode(nvbench::state& state) { size_t num_qo_heads = state.get_int64("num_qo_heads"); size_t num_kv_heads = state.get_int64("num_kv_heads"); size_t head_dim = state.get_int64("head_dim"); - size_t rotary_mode = state.get_int64("rotary_mode"); + size_t pos_encoding_mode = state.get_int64("pos_encoding_mode"); size_t kv_layout = state.get_int64("kv_layout"); bool cooperative = state.get_int64("cooperative"); // Allocate input data: @@ -49,7 +49,7 @@ void bench_flashinfer_single_decode(nvbench::state& state) { thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), cooperative ? thrust::raw_pointer_cast(tmp.data()) : nullptr, num_qo_heads, num_kv_heads, - seq_len, head_dim, QKVLayout(kv_layout), RotaryMode(rotary_mode), + seq_len, head_dim, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), /*maybe_sm_scale=*/std::nullopt, /*rope_scale=*/1.f, /*rope_theta=*/1e4, launch.get_stream()); @@ -68,7 +68,7 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) { size_t num_qo_heads = state.get_int64("num_qo_heads"); size_t num_kv_heads = state.get_int64("num_kv_heads"); size_t head_dim = state.get_int64("head_dim"); - size_t rotary_mode = state.get_int64("rotary_mode"); + size_t pos_encoding_mode = state.get_int64("pos_encoding_mode"); size_t kv_layout = state.get_int64("kv_layout"); bool cooperative = state.get_int64("cooperative"); // Allocate input data: @@ -92,7 +92,7 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) { /*lse=*/nullptr, num_qo_heads, num_kv_heads, /*qo_len=*/1, /*kv_len=*/seq_len, head_dim, - /*causal=*/false, QKVLayout(kv_layout), RotaryMode(rotary_mode), + /*causal=*/false, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), /*allow_fp16_qk_reduction=*/false, /*maybe_sm_scale=*/std::nullopt, /*rope_scale=*/1.f, @@ -116,7 +116,7 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) { .add_int64_axis("num_qo_heads", {32}) \ .add_int64_axis("num_kv_heads", {32, 4}) \ .add_int64_axis("head_dim", {128}) \ - .add_int64_axis("rotary_mode", {0, 1}) \ + .add_int64_axis("pos_encoding_mode", {0, 1}) \ .add_int64_axis("kv_layout", {0, 1}) \ .add_int64_axis("cooperative", {1}) @@ -130,7 +130,7 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) { .add_int64_axis("num_qo_heads", {32}) \ .add_int64_axis("num_kv_heads", {32, 4}) \ .add_int64_axis("head_dim", {128}) \ - .add_int64_axis("rotary_mode", {0, 1}) \ + .add_int64_axis("pos_encoding_mode", {0, 1}) \ .add_int64_axis("kv_layout", {0, 1}) \ 
.add_int64_axis("cooperative", {1}) diff --git a/src/bench_single_prefill.cu b/src/bench_single_prefill.cu index 14466583a..98d141c94 100644 --- a/src/bench_single_prefill.cu +++ b/src/bench_single_prefill.cu @@ -18,8 +18,8 @@ #include #include +using flashinfer::PosEncodingMode; using flashinfer::QKVLayout; -using flashinfer::RotaryMode; template void bench_flashinfer_single_prefill(nvbench::state& state) { @@ -34,7 +34,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { size_t num_qo_heads = state.get_int64("num_qo_heads"); size_t num_kv_heads = state.get_int64("num_kv_heads"); size_t head_dim = state.get_int64("head_dim"); - size_t rotary_mode = state.get_int64("rotary_mode"); + size_t pos_encoding_mode = state.get_int64("pos_encoding_mode"); size_t kv_layout = state.get_int64("kv_layout"); bool causal = state.get_int64("causal"); bool cooperative = state.get_int64("cooperative"); @@ -58,7 +58,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), /*tmp=*/cooperative ? thrust::raw_pointer_cast(tmp.data()) : nullptr, /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, - QKVLayout(kv_layout), RotaryMode(rotary_mode), allow_fp16_qk_reduction, + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction, /*maybe_sm_scale=*/std::nullopt, /*rope_scale=*/1.f, /*rope_theta=*/1e4, launch.get_stream()); @@ -96,7 +96,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("head_dim", {128}) \ .add_int64_axis("causal", {0, 1}) \ .add_int64_axis("kv_layout", {0, 1}) \ - .add_int64_axis("rotary_mode", {0, 1}) \ + .add_int64_axis("pos_encoding_mode", {0, 1}) \ .add_int64_axis("allow_fp16_qk_reduction", {0, 1}) \ .add_int64_axis("cooperative", {1}) @@ -112,7 +112,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { .add_int64_axis("head_dim", {128}) \ .add_int64_axis("causal", {0, 1}) \ .add_int64_axis("kv_layout", {0, 1}) \ - .add_int64_axis("rotary_mode", {0, 1}) \ + .add_int64_axis("pos_encoding_mode", {0, 1}) \ .add_int64_axis("allow_fp16_qk_reduction", {0, 1}) \ .add_int64_axis("cooperative", {0, 1}) diff --git a/src/cpu_reference.h b/src/cpu_reference.h index a1cacd204..cd3c1c6be 100644 --- a/src/cpu_reference.h +++ b/src/cpu_reference.h @@ -16,7 +16,7 @@ #pragma once #include -#include +#include #include #include "utils.h" @@ -49,7 +49,7 @@ std::vector single_mha(const std::vector& q, const std::vec const std::vector& v, size_t qo_len, size_t kv_len, size_t num_q_heads, size_t num_kv_heads, size_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, - RotaryMode rotary_mode = RotaryMode::kNone, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, float rope_scale = 1.f, float rope_theta = 1e4) { assert(qo_len <= kv_len); assert(num_q_heads % num_kv_heads == 0); @@ -67,15 +67,15 @@ std::vector single_mha(const std::vector& q, const std::vec const size_t kv_head_idx = qo_head_idx / GROUP_SIZE; for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { float max_val = -5e4; - if (rotary_mode == RotaryMode::kLlama) { + if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { q_rotary_local = std::move(cpu_reference::apply_llama_rope( q.data() + info.get_qo_elem_offset(q_idx, qo_head_idx, 0), head_dim, q_idx + kv_len - qo_len, rope_scale, rope_theta)); } for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { att[kv_idx] = 0.; - switch (rotary_mode) { - case RotaryMode::kNone: { + switch 
(pos_encoding_mode) { + case PosEncodingMode::kNone: { for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { att[kv_idx] += float(q[info.get_qo_elem_offset(q_idx, qo_head_idx, feat_idx)]) * @@ -84,7 +84,7 @@ std::vector single_mha(const std::vector& q, const std::vec } break; } - case RotaryMode::kLlama: { + case PosEncodingMode::kRoPELlama: { k_rotary_local = std::move(cpu_reference::apply_llama_rope( k.data() + info.get_kv_elem_offset(kv_idx, kv_head_idx, 0), head_dim, kv_idx, rope_scale, rope_theta)); diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 22e17c428..cf509e9ef 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -28,7 +28,8 @@ constexpr QKVLayout kv_layout = QKVLayout::kNHD; template void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, - flashinfer::RotaryMode rotary_mode, bool cooperative) { + flashinfer::PosEncodingMode pos_encoding_mode, + bool cooperative) { std::vector seq_lens(batch_size); utils::vec_randint_(seq_lens, 1, 1024); std::vector append_indptr{0}; @@ -57,7 +58,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si // compute reference output std::vector o_ref_i = cpu_reference::single_mha(qi, ki, vi, 1, seq_len, num_qo_heads, num_kv_heads, - head_dim, false, QKVLayout::kNHD, rotary_mode); + head_dim, false, QKVLayout::kNHD, pos_encoding_mode); keys.push_back(ki); values.push_back(vi); // append new q and o_ref @@ -102,21 +103,22 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si handler.BeginForward( (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr.data(), kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, - rotary_mode); + pos_encoding_mode); if (!cooperative) { // use non-cooperative kernel cudaError_t status = flashinfer::BatchDecodeWithPagedKVCache( - thrust::raw_pointer_cast(q_device.data()), /*q_rope_position=*/nullptr, paged_kv, + thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, kv_partition_info_t(), thrust::raw_pointer_cast(o_device.data()), - /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, rotary_mode); + /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } else { cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheWrapper( - &handler, thrust::raw_pointer_cast(q_device.data()), /*q_rope_position=*/nullptr, paged_kv, - thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, rotary_mode); + &handler, thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, + thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, + pos_encoding_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } // compare result @@ -135,7 +137,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si std::cout << "page_size=" << page_size << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads << ", batch_size=" << batch_size << ", head_dim=" << head_dim - << ", rotary_mode=" << flashinfer::RotaryModeToString(rotary_mode) + << ", pos_encoding_mode=" << flashinfer::PosEncodingModeToString(pos_encoding_mode) << ", result accuracy (atol=1e-3, rtol=1e-3): " << result_accuracy << std::endl; EXPECT_GT(result_accuracy, 0.90) << "Result 
correctness test failed."; EXPECT_EQ(nan_detected, false) << "NaN detected."; @@ -148,10 +150,10 @@ void TestBatchDecodeKernelCorrectness() { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {32, 8, 4}) { for (size_t head_dim : {64, 128, 256}) { - for (size_t rotary_mode : {0U, 1U}) { - _TestBatchDecodingKernelCorrectness(page_size, batch_size, num_qo_heads, - num_kv_heads, head_dim, - flashinfer::RotaryMode(rotary_mode), false); + for (size_t pos_encoding_mode : {0U, 1U}) { + _TestBatchDecodingKernelCorrectness( + page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, + flashinfer::PosEncodingMode(pos_encoding_mode), false); } } } @@ -167,10 +169,10 @@ void TestCooperativeBatchDecodeKernelCorrectness() { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {32, 8, 4}) { for (size_t head_dim : {64, 128, 256}) { - for (size_t rotary_mode : {0U, 1U}) { - _TestBatchDecodingKernelCorrectness(page_size, batch_size, num_qo_heads, - num_kv_heads, head_dim, - flashinfer::RotaryMode(rotary_mode), true); + for (size_t pos_encoding_mode : {0U, 1U}) { + _TestBatchDecodingKernelCorrectness( + page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, + flashinfer::PosEncodingMode(pos_encoding_mode), true); } } } diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index b5bb8c787..aafe1cc5a 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -27,7 +27,7 @@ constexpr QKVLayout kv_layout = QKVLayout::kNHD; template void _TestBatchPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim, bool causal, - RotaryMode rotary_mode, + PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction) { uint32_t batch_size = 9; std::vector q_lens(batch_size), kv_lens(batch_size); @@ -92,7 +92,7 @@ void _TestBatchPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t num_qo std::vector o_ref = cpu_reference::single_mha( q, key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads, num_kv_heads, - head_dim, causal, QKVLayout::kNHD, rotary_mode); + head_dim, causal, QKVLayout::kNHD, pos_encoding_mode); thrust::device_vector q_indptr_device(q_indptr); thrust::device_vector q_device(q); @@ -101,9 +101,9 @@ void _TestBatchPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t num_qo for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) { auto status = BatchPrefillWithPagedKVCache( thrust::raw_pointer_cast(q_device.data()), - thrust::raw_pointer_cast(q_indptr_device.data()), /*q_rope_position=*/nullptr, paged_kv, + thrust::raw_pointer_cast(q_indptr_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*tmp=*/nullptr, - /*lse=*/nullptr, num_qo_heads, causal, rotary_mode, allow_fp16_qk_reduction); + /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } @@ -122,7 +122,8 @@ void _TestBatchPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t num_qo std::cout << "request_idx=" << request_idx << ", page_size=" << page_size << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads << ", q_len=" << q_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim - << ", causal=" << causal << ", rotary_mode=" << RotaryModeToString(rotary_mode) + << ", causal=" << causal + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) << ", result_accuracy=" << result_accuracy << std::endl; EXPECT_GT(result_accuracy, 
0.99) << "Result correctness test failed."; EXPECT_EQ(nan_detected, false) << "NaN detected in output."; @@ -132,7 +133,7 @@ void _TestBatchPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t num_qo template void _TestBatchPrefillKernelShortContextCorrectness(size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim, bool causal, - RotaryMode rotary_mode, + PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction) { uint32_t batch_size = 7; std::vector q_lens(batch_size); @@ -200,7 +201,7 @@ void _TestBatchPrefillKernelShortContextCorrectness(size_t num_kv_heads, size_t int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx]; std::vector o_ref_i = cpu_reference::single_mha( q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads, - num_kv_heads, head_dim, causal, QKVLayout::kNHD, rotary_mode); + num_kv_heads, head_dim, causal, QKVLayout::kNHD, pos_encoding_mode); o_ref.push_back(o_ref_i); } @@ -216,9 +217,9 @@ void _TestBatchPrefillKernelShortContextCorrectness(size_t num_kv_heads, size_t auto status = BatchPrefillWithPagedKVCache( thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), - /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), + /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*tmp=*/nullptr, - /*lse=*/nullptr, num_qo_heads, causal, rotary_mode, allow_fp16_qk_reduction); + /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); thrust::host_vector o_host(o_device); @@ -235,7 +236,8 @@ void _TestBatchPrefillKernelShortContextCorrectness(size_t num_kv_heads, size_t 1. 
- float(num_result_errors_atol_1e_3_rtol_1e_3) / max(float(o_concat_ref.size()), 1.f); std::cout << "page_size=" << page_size << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads << ", head_dim=" << head_dim - << ", causal=" << causal << ", rotary_mode=" << RotaryModeToString(rotary_mode) + << ", causal=" << causal + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) << ", result_accuracy=" << result_accuracy << std::endl; EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed."; EXPECT_EQ(nan_detected, false) << "NaN detected in output."; @@ -244,7 +246,7 @@ void _TestBatchPrefillKernelShortContextCorrectness(size_t num_kv_heads, size_t template void _TestBatchPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim, bool causal, - RotaryMode rotary_mode, + PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction) { std::vector>> keys, values; std::vector q_lens{63}, kv_lens{2047}; @@ -293,7 +295,7 @@ void _TestBatchPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t n std::vector o_ref = cpu_reference::single_mha(q, k, v, q_lens[0], kv_lens[0], num_qo_heads, num_kv_heads, - head_dim, causal, QKVLayout::kNHD, rotary_mode); + head_dim, causal, QKVLayout::kNHD, pos_encoding_mode); thrust::device_vector q_indptr_device(q_indptr); thrust::device_vector q_device(q); @@ -301,8 +303,9 @@ void _TestBatchPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t n auto status = BatchPrefillWithPagedKVCache( thrust::raw_pointer_cast(q_device.data()), thrust::raw_pointer_cast(q_indptr_device.data()), - /*q_rope_position=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), - /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, causal, rotary_mode, allow_fp16_qk_reduction); + /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), + /*tmp=*/nullptr, /*lse=*/nullptr, num_qo_heads, causal, pos_encoding_mode, + allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); thrust::host_vector o_host(o_device); @@ -320,7 +323,7 @@ void _TestBatchPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t n std::cout << ", page_size=" << page_size << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads << ", q_len=" << q_lens[0] << ", kv_len=" << kv_lens[0] << ", head_dim=" << head_dim << ", causal=" << causal - << ", rotary_mode=" << RotaryModeToString(rotary_mode) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) << ", result_accuracy=" << result_accuracy << std::endl; EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed."; EXPECT_EQ(nan_detected, false) << "NaN detected in output."; @@ -333,10 +336,10 @@ void TestBatchPrefillKernelOneHotCorrectness(bool allow_fp16_qk_reduction) { for (size_t page_size : {1, 7, 16}) { for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { - for (size_t rotary_mode : {0, 1}) { - _TestBatchPrefillKernelOneHotCorrectness(num_kv_heads, num_qo_heads, page_size, - head_dim, causal, RotaryMode(rotary_mode), - allow_fp16_qk_reduction); + for (size_t pos_encoding_mode : {0, 1}) { + _TestBatchPrefillKernelOneHotCorrectness( + num_kv_heads, num_qo_heads, page_size, head_dim, causal, + PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); } } } @@ -352,10 +355,10 @@ void TestBatchPrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) for (size_t page_size : {1, 7, 16}) 
{ for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { - for (size_t rotary_mode : {0, 1}) { + for (size_t pos_encoding_mode : {0, 1}) { _TestBatchPrefillKernelShortContextCorrectness( - num_kv_heads, num_qo_heads, page_size, head_dim, causal, RotaryMode(rotary_mode), - allow_fp16_qk_reduction); + num_kv_heads, num_qo_heads, page_size, head_dim, causal, + PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); } } } @@ -371,10 +374,10 @@ void TestBatchPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) for (size_t page_size : {1, 7, 16}) { for (size_t head_dim : {64, 128, 256}) { for (size_t causal : {false, true}) { - for (size_t rotary_mode : {0, 1}) { + for (size_t pos_encoding_mode : {0, 1}) { _TestBatchPrefillKernelLongContextCorrectness( - num_kv_heads, num_qo_heads, page_size, head_dim, causal, RotaryMode(rotary_mode), - allow_fp16_qk_reduction); + num_kv_heads, num_qo_heads, page_size, head_dim, causal, + PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); } } } diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 21f561b4b..6201f4373 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -286,19 +286,18 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, baseline_handler.BeginForward( (void*)thrust::raw_pointer_cast(buffer_baseline.data()), workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, RotaryMode::kNone); + num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); cascade_handler.BeginForward( (void*)thrust::raw_pointer_cast(buffer_cascade.data()), workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, RotaryMode::kNone); + num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); // Compute result using baseline implementation cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), - /*q_rope_position=*/nullptr, paged_kv_baseline_d, - thrust::raw_pointer_cast(o_baseline_d.data()), - /*lse=*/nullptr, num_qo_heads, RotaryMode::kNone); + /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), + /*lse=*/nullptr, num_qo_heads, PosEncodingMode::kNone); EXPECT_EQ(status, cudaSuccess) << "Baseline implementation failed with error: " << cudaGetErrorString(status); @@ -310,16 +309,16 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, thrust::raw_pointer_cast(tmp_0_d.data()), thrust::raw_pointer_cast(lse_cascade_0_d.data()), num_qo_heads, num_kv_heads, /*qo_len=*/batch_size, /*kv_len=*/shared_prefix_length, head_dim, /*causal=*/false, /*kv_layout=*/QKVLayout::kNHD, - /*rotary_mode=*/RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); + /*pos_encoding_mode=*/PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); EXPECT_EQ(status, cudaSuccess) << "Cascade implementation prefill failed with error: " << cudaGetErrorString(status); status = BatchDecodeWithPagedKVCacheWrapper( &cascade_handler, thrust::raw_pointer_cast(q_d.data()), - /*q_rope_position=*/nullptr, paged_kv_casacde_d, - thrust::raw_pointer_cast(o_cascade_1_d.data()), - /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, RotaryMode::kNone); + /*q_offset=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), + /*lse=*/thrust::raw_pointer_cast(lse_cascade_1_d.data()), 
num_qo_heads, + PosEncodingMode::kNone); EXPECT_EQ(status, cudaSuccess) << "Cascade implementation decode failed with error: " << cudaGetErrorString(status); @@ -418,9 +417,8 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &baseline_handler, thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), - /*q_rope_position=*/nullptr, paged_kv_baseline_d, - thrust::raw_pointer_cast(o_baseline_d.data()), - /*lse=*/nullptr, num_qo_heads, /*causal=*/true, RotaryMode::kNone, + /*q_offset=*/nullptr, paged_kv_baseline_d, thrust::raw_pointer_cast(o_baseline_d.data()), + /*lse=*/nullptr, num_qo_heads, /*causal=*/true, PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); EXPECT_EQ(status, cudaSuccess) << "Baseline implementation failed with error: " @@ -433,7 +431,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, num_qo_heads, num_kv_heads, /*qo_len=*/batch_size * qo_append_length, /*kv_len=*/shared_prefix_length, head_dim, /*causal=*/false, /*kv_layout=*/QKVLayout::kNHD, - /*rotary_mode=*/RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); + /*pos_encoding_mode=*/PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); EXPECT_EQ(status, cudaSuccess) << "Cascade implementation shared prefix prefill failed with error: " @@ -445,7 +443,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, /*r_rope_position=*/nullptr, paged_kv_casacde_d, thrust::raw_pointer_cast(o_cascade_1_d.data()), thrust::raw_pointer_cast(lse_cascade_1_d.data()), num_qo_heads, /*causal=*/true, - RotaryMode::kNone, /*allow_fp16_qk_reduction=*/false); + PosEncodingMode::kNone, /*allow_fp16_qk_reduction=*/false); EXPECT_EQ(status, cudaSuccess) << "Cascade implementation unique kv prefill failed with error: " << cudaGetErrorString(status); diff --git a/src/test_single_decode.cu b/src/test_single_decode.cu index 87b9db6fe..9bcb73b09 100644 --- a/src/test_single_decode.cu +++ b/src/test_single_decode.cu @@ -25,7 +25,8 @@ using namespace flashinfer; template void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, size_t seq_len, - size_t head_dim, QKVLayout kv_layout, RotaryMode rotary_mode) { + size_t head_dim, QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode) { std::vector Q_host(num_qo_heads * head_dim); std::vector K_host(seq_len * num_kv_heads * head_dim); std::vector V_host(seq_len * num_kv_heads * head_dim); @@ -45,13 +46,13 @@ void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, si o_ref_host = cpu_reference::single_mha(Q_host, K_host, V_host, 1, seq_len, num_qo_heads, - num_kv_heads, head_dim, false, kv_layout, rotary_mode); + num_kv_heads, head_dim, false, kv_layout, pos_encoding_mode); cudaError_t status = SingleDecodeWithKVCache( thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), thrust::raw_pointer_cast(tmp.data()), num_qo_heads, num_kv_heads, seq_len, head_dim, - kv_layout, rotary_mode); + kv_layout, pos_encoding_mode); EXPECT_EQ(status, cudaSuccess) << "SingleDecodeWithKVCache kernel launch failed, error message: " << cudaGetErrorString(status); @@ -72,7 +73,7 @@ void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, si std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads << ", seq_len=" << seq_len << ", head_dim=" << head_dim << ", kv_layout=" << 
QKVLayoutToString(kv_layout) - << ", rotary_mode=" << RotaryModeToString(rotary_mode) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) << ", result accuracy (atol=1e-3, rtol=1e-3): " << result_accuracy << std::endl; EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; EXPECT_FALSE(nan_detected) << "NaN detected."; @@ -86,9 +87,10 @@ void TestSingleDecodeKernelCorrectness() { {1, 3, 9, 27, 81, 129, 257, 512, 1024, 2048, 4096, 8192, 16384, 32768}) { for (size_t head_dim : {64, 128, 256}) { for (unsigned int kv_layout : {0U, 1U}) { - for (unsigned int rotary_mode : {0U, 1U}) { + for (unsigned int pos_encoding_mode : {0U, 1U}) { _TestDecodingKernelCorrectness(num_qo_heads, num_kv_heads, seq_len, head_dim, - QKVLayout(kv_layout), RotaryMode(rotary_mode)); + QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode)); } } } diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index c22b5591b..3e69f8c7d 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -25,7 +25,7 @@ using namespace flashinfer; template void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, bool causal, - QKVLayout kv_layout, RotaryMode rotary_mode, + QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction, float rtol = 1e-3, float atol = 1e-3) { std::vector q(qo_len * num_qo_heads * head_dim); @@ -49,7 +49,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu thrust::raw_pointer_cast(v_d.data()), thrust::raw_pointer_cast(o_d.data()), thrust::raw_pointer_cast(tmp_d.data()), /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, - rotary_mode, allow_fp16_qk_reduction); + pos_encoding_mode, allow_fp16_qk_reduction); EXPECT_EQ(status, cudaSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " << cudaGetErrorString(status); @@ -57,7 +57,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu thrust::host_vector o_h(o_d); std::vector o_ref = cpu_reference::single_mha( q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout, - rotary_mode); + pos_encoding_mode); size_t num_results_error_atol = 0; bool nan_detected = false; @@ -72,7 +72,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim << ", causal=" << causal << ", kv_layout=" << QKVLayoutToString(kv_layout) - << ", rotary_mode=" << RotaryModeToString(rotary_mode) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) << ", result_accuracy=" << result_accuracy << std::endl; EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; EXPECT_FALSE(nan_detected) << "Nan detected in the result."; @@ -85,11 +85,11 @@ void TestSinglePrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) for (size_t num_heads : {1}) { for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { - for (size_t rotary_mode : {0, 1}) { + for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), - RotaryMode(rotary_mode), allow_fp16_qk_reduction); + PosEncodingMode(pos_encoding_mode), 
allow_fp16_qk_reduction); } } } @@ -108,12 +108,12 @@ void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction for (size_t num_kv_heads : {4, 8, 32}) { for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { - for (size_t rotary_mode : {0, 1}) { + for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { _TestSinglePrefillKernelCorrectness( qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal, - QKVLayout(kv_layout), RotaryMode(rotary_mode), allow_fp16_qk_reduction, rtol, - atol); + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), + allow_fp16_qk_reduction, rtol, atol); } } } @@ -130,11 +130,11 @@ void TestSinglePrefillKernelCorrectness(bool allow_fp16_qk_reduction) { for (size_t num_heads : {12}) { for (size_t head_dim : {64, 128, 256}) { for (bool causal : {false, true}) { - for (size_t rotary_mode : {0, 1}) { + for (size_t pos_encoding_mode : {0, 1}) { for (size_t kv_layout : {0, 1}) { _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), - RotaryMode(rotary_mode), allow_fp16_qk_reduction); + PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); } } } diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 17899bd4b..04b2b7001 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -52,7 +52,8 @@ template cudaError_t _SinglePrefillWithKVCacheNoLSE( DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, - QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, + QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { CHECK(head_dim == 128) << "The head dimension must be 128"; @@ -64,18 +65,19 @@ cudaError_t _SinglePrefillWithKVCacheNoLSE( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {DISPATCH_GQA_GROUP_SIZE( group_size, GROUP_SIZE, - {DISPATCH_CAUSAL(causal, CAUSAL, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { - SinglePrefillWithKVCacheDispatched< - GROUP_SIZE, /*head_dim=*/128, /*layout=*/QKVLayout::kNHD, - ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL>( - q, k, v, o, tmp, /*lse=*/nullptr, num_kv_heads, qo_len, kv_len, - sm_scale, rope_scale, rope_theta, stream); - })})})}); + {DISPATCH_CAUSAL( + causal, CAUSAL, {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { + SinglePrefillWithKVCacheDispatched( + q, k, v, o, tmp, /*lse=*/nullptr, num_kv_heads, qo_len, kv_len, sm_scale, + rope_scale, rope_theta, stream); + })})})}); return cudaSuccess; } int _FlashInferSinglePrefillWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp, - bool causal, int64_t kv_layout, int64_t rotary_mode, + bool causal, int64_t kv_layout, int64_t pos_encoding_mode, bool allow_fp16_qk_reduction, double rope_scale, double rope_theta, DLTensor* o) { // `tmp` is user-provided scratch space of at least 16MB, e.g. 4 * 1024 * 1024 float32. 
@@ -117,8 +119,8 @@ int _FlashInferSinglePrefillWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, D cudaError_t status = _SinglePrefillWithKVCacheNoLSE( (dtype_in*)q->data, (dtype_in*)k->data, (dtype_in*)v->data, (dtype_out*)o->data, (float*)tmp->data, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, - QKVLayout(kv_layout), RotaryMode(rotary_mode), allow_fp16_qk_reduction, rope_scale, - rope_theta, 0); + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction, + rope_scale, rope_theta, 0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } @@ -127,8 +129,8 @@ int _FlashInferSinglePrefillWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, D } int _FlashInferSingleDecodeWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp, - int64_t kv_layout, int64_t rotary_mode, double rope_scale, - double rope_theta, DLTensor* o) { + int64_t kv_layout, int64_t pos_encoding_mode, + double rope_scale, double rope_theta, DLTensor* o) { // `tmp` is user-provided scratch space of at least 16MB, e.g. 4 * 1024 * 1024 float32. CHECK_EQ(q->device.device_type, kDLCUDA) << "The device of q matrix must be CUDA."; CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA."; @@ -166,7 +168,7 @@ int _FlashInferSingleDecodeWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DL cudaError_t status = SingleDecodeWithKVCache( (dtype_in*)q->data, (dtype_in*)k->data, (dtype_in*)v->data, (dtype_out*)o->data, (dtype_out*)tmp->data, num_qo_heads, num_kv_heads, seq_len, head_dim, - QKVLayout(kv_layout), RotaryMode(rotary_mode), rope_scale, rope_theta, 0); + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), rope_scale, rope_theta, 0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } @@ -185,10 +187,11 @@ thread_local BatchPrefillHandler batch_prefill_ragged_kv_handler; template cudaError_t _BatchPrefillWithPagedKVCacheWrapper( - BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, + BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, bool causal, RotaryMode rotary_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + uint32_t num_qo_heads, bool causal, PosEncodingMode pos_encoding_mode, + bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, + cudaStream_t stream) { CHECK(lse != nullptr) << "The lse buffer must be provided"; CHECK(allow_fp16_qk_reduction == false) << "The fp16 qk reduction is not supported"; CHECK(paged_kv.head_dim == 128) << "The head dimension must be 128"; @@ -198,13 +201,14 @@ cudaError_t _BatchPrefillWithPagedKVCacheWrapper( const uint32_t group_size = num_qo_heads / num_kv_heads; DISPATCH_GQA_GROUP_SIZE( group_size, GROUP_SIZE, - {DISPATCH_CAUSAL(causal, CAUSAL, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { - return BatchPrefillWithPagedKVCacheWrapperDispatched< - page_storage, kv_layout, GROUP_SIZE, /*head_dim=*/128, ROTARY_MODE, - /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, sm_scale, - rope_scale, rope_theta, stream); - })})}); + {DISPATCH_CAUSAL( + causal, CAUSAL, {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { + return BatchPrefillWithPagedKVCacheWrapperDispatched< + page_storage, 
kv_layout, GROUP_SIZE, /*head_dim=*/128, POS_ENCODING_MODE, + /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, q_offset, paged_kv, o, lse, sm_scale, rope_scale, rope_theta, + stream); + })})}); return cudaSuccess; } @@ -215,11 +219,11 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q DLTensor* page_table_values, // DLTensor* last_page_len, // DLTensor* k_rope_pos_offset, // - DLTensor* q_rope_position, // + DLTensor* q_offset, // DLTensor* output, // DLTensor* lse, // int64_t causal, // - int64_t rotary_mode, // + int64_t pos_encoding_mode, // double rope_scale, // double rope_theta, double attn_score_scaling_factor = 1.0f) { @@ -232,8 +236,7 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q << "The device of page_table_values matrix must be CUDA."; CHECK_EQ(last_page_len->device.device_type, kDLCUDA) << "The device of last_page_len matrix must be CUDA."; - CHECK_EQ(q_rope_position->device.device_type, kDLCUDA) - << "The device of q_rope_position matrix must be CUDA."; + CHECK_EQ(q_offset->device.device_type, kDLCUDA) << "The device of q_offset matrix must be CUDA."; CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) << "The device of k_rope_pos_offset matrix must be CUDA."; CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) @@ -245,7 +248,7 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q CHECK_EQ(page_table_indptr->device.device_id, dev_id); CHECK_EQ(page_table_values->device.device_id, dev_id); CHECK_EQ(last_page_len->device.device_id, dev_id); - CHECK_EQ(q_rope_position->device.device_id, dev_id); + CHECK_EQ(q_offset->device.device_id, dev_id); CHECK_EQ(k_rope_pos_offset->device.device_id, dev_id); CHECK_EQ(qo_indptr->device.device_id, dev_id); CHECK_EQ(output->device.device_id, dev_id); @@ -253,14 +256,14 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1); CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code); CHECK(page_table_indptr->dtype.lanes == 1 && page_table_values->dtype.lanes == 1 && - last_page_len->dtype.lanes == 1 && q_rope_position->dtype.lanes == 1 && + last_page_len->dtype.lanes == 1 && q_offset->dtype.lanes == 1 && k_rope_pos_offset->dtype.lanes == 1 && qo_indptr->dtype.lanes == 1); CHECK(page_table_indptr->dtype.bits == page_table_values->dtype.bits && page_table_indptr->dtype.bits == last_page_len->dtype.bits && page_table_indptr->dtype.bits == qo_indptr->dtype.bits && page_table_indptr->dtype.code == page_table_values->dtype.code && page_table_indptr->dtype.code == last_page_len->dtype.code && - page_table_indptr->dtype.code == q_rope_position->dtype.code && + page_table_indptr->dtype.code == q_offset->dtype.code && page_table_indptr->dtype.code == k_rope_pos_offset->dtype.code && page_table_indptr->dtype.code == qo_indptr->dtype.code); @@ -286,8 +289,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q CHECK_EQ(q_data->shape[2], nfeat); CHECK_EQ(output->shape[1], nhead_qo); CHECK_EQ(output->shape[2], nfeat); - CHECK_EQ(q_rope_position->ndim, 1); - CHECK_EQ(q_rope_position->shape[0], q_data->shape[0]); + CHECK_EQ(q_offset->ndim, 1); + CHECK_EQ(q_offset->shape[0], q_data->shape[0]); CHECK_EQ(k_rope_pos_offset->ndim, 1); CHECK_EQ(k_rope_pos_offset->shape[0], num_total_seqs); @@ -311,11 +314,11 @@ void 
_FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q dtype_idx>( &batch_prefill_paged_kv_handlers[handler_id], static_cast(q_data->data), static_cast(qo_indptr->data), - static_cast(q_rope_position->data), cache, + static_cast(q_offset->data), cache, static_cast(output->data), /*lse=*/static_cast(lse->data), nhead_qo, - /*causal=*/causal, RotaryMode(rotary_mode), /*allow_fp16_qk_reduction=*/false, - sm_scale, rope_scale, rope_theta, + /*causal=*/causal, PosEncodingMode(pos_encoding_mode), + /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta, /*stream=*/0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); @@ -349,15 +352,15 @@ thread_local BatchDecodeHandler batch_decode_handlers[max_num_handlers]; void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_data, DLTensor* pages, - DLTensor* page_table_indptr, // - DLTensor* page_table_values, // - DLTensor* last_page_len, // - DLTensor* k_rope_pos_offset, // - DLTensor* q_rope_position, // - DLTensor* output, // - DLTensor* lse, // - int64_t rotary_mode = 0, // - double rope_scale = 1.0f, // + DLTensor* page_table_indptr, // + DLTensor* page_table_values, // + DLTensor* last_page_len, // + DLTensor* k_rope_pos_offset, // + DLTensor* q_offset, // + DLTensor* output, // + DLTensor* lse, // + int64_t pos_encoding_mode = 0, // + double rope_scale = 1.0f, // double rope_theta = 1e4, double attn_score_scaling_factor = 1.0f) { CHECK_LT(handler_id, max_num_handlers) << "The handler id must be less than " << max_num_handlers; @@ -369,8 +372,7 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ << "The device of page_table_values matrix must be CUDA."; CHECK_EQ(last_page_len->device.device_type, kDLCUDA) << "The device of last_page_len matrix must be CUDA."; - CHECK_EQ(q_rope_position->device.device_type, kDLCUDA) - << "The device of q_rope_position matrix must be CUDA."; + CHECK_EQ(q_offset->device.device_type, kDLCUDA) << "The device of q_offset matrix must be CUDA."; CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) << "The device of k_rope_pos_offset matrix must be CUDA."; CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; @@ -380,20 +382,20 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ CHECK_EQ(page_table_indptr->device.device_id, dev_id); CHECK_EQ(page_table_values->device.device_id, dev_id); CHECK_EQ(last_page_len->device.device_id, dev_id); - CHECK_EQ(q_rope_position->device.device_id, dev_id); + CHECK_EQ(q_offset->device.device_id, dev_id); CHECK_EQ(k_rope_pos_offset->device.device_id, dev_id); CHECK_EQ(output->device.device_id, dev_id); CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1); CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code); CHECK(page_table_indptr->dtype.lanes == 1 && page_table_values->dtype.lanes == 1 && - last_page_len->dtype.lanes == 1 && q_rope_position->dtype.lanes == 1 && + last_page_len->dtype.lanes == 1 && q_offset->dtype.lanes == 1 && k_rope_pos_offset->dtype.lanes == 1); CHECK(page_table_indptr->dtype.bits == page_table_values->dtype.bits && page_table_indptr->dtype.bits == last_page_len->dtype.bits && page_table_indptr->dtype.code == page_table_values->dtype.code && page_table_indptr->dtype.code == last_page_len->dtype.code && - page_table_indptr->dtype.code == q_rope_position->dtype.code && + 
page_table_indptr->dtype.code == q_offset->dtype.code && page_table_indptr->dtype.code == k_rope_pos_offset->dtype.code); CHECK_EQ(pages->ndim, 5); @@ -417,8 +419,8 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ int64_t nhead_qo = q_data->shape[1]; CHECK_EQ(output->shape[1], nhead_qo); CHECK_EQ(output->shape[2], nfeat); - CHECK_EQ(q_rope_position->ndim, 1); - CHECK_EQ(q_rope_position->shape[0], num_total_seqs); + CHECK_EQ(q_offset->ndim, 1); + CHECK_EQ(q_offset->shape[0], num_total_seqs); CHECK_EQ(k_rope_pos_offset->ndim, 1); CHECK_EQ(k_rope_pos_offset->shape[0], num_total_seqs); @@ -440,10 +442,10 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( &batch_decode_handlers[handler_id], static_cast(q_data->data), - static_cast(q_rope_position->data), cache, + static_cast(q_offset->data), cache, static_cast(output->data), - /*lse=*/static_cast(lse->data), nhead_qo, RotaryMode(rotary_mode), sm_scale, - rope_scale, rope_theta, + /*lse=*/static_cast(lse->data), nhead_qo, + PosEncodingMode(pos_encoding_mode), sm_scale, rope_scale, rope_theta, /*stream=*/0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); @@ -454,7 +456,7 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward( int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* page_table_indptr, DLTensor* last_page_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, - int64_t page_size, int64_t rotary_mode) { + int64_t page_size, int64_t pos_encoding_mode) { CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor"; size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8; CHECK_LT(handler_idx, max_num_handlers) @@ -472,7 +474,7 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward( static_cast(workspace_buffer->data), workspace_size_in_bytes, static_cast(page_table_indptr->data), static_cast(last_page_len->data), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, RotaryMode(rotary_mode)); + num_kv_heads, head_dim, page_size, PosEncodingMode(pos_encoding_mode)); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer decode BeginForward error " << cudaGetErrorString(status); } @@ -491,9 +493,9 @@ void _FlashInferAttentionDecodeWithPagedKVCacheEndForward(int64_t handler_id) { template cudaError_t _BatchPrefillWithRaggedKVCacheWrapper( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, - IdType* kv_indptr, IdType* q_rope_position_map, IdType* k_rope_pos_offset, DTypeOut* o, - float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, - const uint32_t head_dim, bool causal, QKVLayout kv_layout, RotaryMode rotary_mode, + IdType* kv_indptr, IdType* q_offset_map, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, + const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, + const uint32_t head_dim, bool causal, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, bool allow_fp16_qk_reduction, const float sm_scale, const float rope_scale, const float rope_theta, cudaStream_t stream) { CHECK(lse != nullptr) << "The lse buffer must be provided"; @@ -503,22 +505,22 @@ cudaError_t _BatchPrefillWithRaggedKVCacheWrapper( DISPATCH_GQA_GROUP_SIZE( num_qo_heads / num_kv_heads, GROUP_SIZE, - 
{DISPATCH_CAUSAL(causal, CAUSAL, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { - return BatchPrefillWithRaggedKVCacheWrapperDispatched< - GROUP_SIZE, /*head_dim=*/128, /*layout=*/QKVLayout::kNHD, ROTARY_MODE, - /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, k, v, kv_indptr, q_rope_position_map, - k_rope_pos_offset, o, lse, batch_size, num_kv_heads, sm_scale, - rope_scale, rope_theta, stream); - })})}); + {DISPATCH_CAUSAL( + causal, CAUSAL, {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { + return BatchPrefillWithRaggedKVCacheWrapperDispatched< + GROUP_SIZE, /*head_dim=*/128, /*layout=*/QKVLayout::kNHD, POS_ENCODING_MODE, + /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( + handler, q, qo_indptr, k, v, kv_indptr, q_offset_map, k_rope_pos_offset, o, lse, + batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream); + })})}); return cudaSuccess; } void _FlashInferAttentionPrefillWithRaggedKVCache( DLTensor* q_data, DLTensor* qo_indptr, DLTensor* k_data, DLTensor* v_data, DLTensor* kv_indptr, - DLTensor* q_rope_position_map, DLTensor* k_rope_pos_offset, DLTensor* output, DLTensor* lse, - int64_t causal = 1, int64_t rotary_mode = 0, double rope_scale = 1.0f, double rope_theta = 1e4, - double attn_score_scaling_factor = 1.0f) { + DLTensor* q_offset_map, DLTensor* k_rope_pos_offset, DLTensor* output, DLTensor* lse, + int64_t causal = 1, int64_t pos_encoding_mode = 0, double rope_scale = 1.0f, + double rope_theta = 1e4, double attn_score_scaling_factor = 1.0f) { CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) << "The device of qo_indptr must be CUDA."; CHECK_EQ(k_data->device.device_type, kDLCUDA) << "The device of k_data must be CUDA."; @@ -526,8 +528,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( CHECK_EQ(kv_indptr->device.device_type, kDLCUDA) << "The device of kv_indptr must be CUDA."; CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; CHECK_EQ(lse->device.device_type, kDLCUDA) << "The lse of output must be CUDA."; - CHECK_EQ(q_rope_position_map->device.device_type, kDLCUDA) - << "The device of q_rope_position_map must be CUDA."; + CHECK_EQ(q_offset_map->device.device_type, kDLCUDA) << "The device of q_offset_map must be CUDA."; CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) << "The device of k_rope_pos_offset must be CUDA."; @@ -538,19 +539,19 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( CHECK_EQ(kv_indptr->device.device_id, dev_id); CHECK_EQ(output->device.device_id, dev_id); CHECK_EQ(lse->device.device_id, dev_id); - CHECK_EQ(q_rope_position_map->device.device_id, dev_id); + CHECK_EQ(q_offset_map->device.device_id, dev_id); CHECK_EQ(k_rope_pos_offset->device.device_id, dev_id); CHECK(q_data->dtype.lanes == 1 && qo_indptr->dtype.lanes == 1 && k_data->dtype.lanes == 1 && v_data->dtype.lanes == 1 && kv_indptr->dtype.lanes == 1 && output->dtype.lanes == 1 && - lse->dtype.lanes == 1 && q_rope_position_map->dtype.lanes == 1 && + lse->dtype.lanes == 1 && q_offset_map->dtype.lanes == 1 && k_rope_pos_offset->dtype.lanes == 1); CHECK(q_data->dtype.bits == k_data->dtype.bits && q_data->dtype.code == v_data->dtype.code); CHECK(qo_indptr->dtype.bits == kv_indptr->dtype.bits); CHECK(lse->dtype.bits == 32); CHECK(q_data->dtype.code == k_data->dtype.code && q_data->dtype.code == v_data->dtype.code); CHECK(qo_indptr->dtype.code == kv_indptr->dtype.code); - 
CHECK(q_rope_position_map->dtype.code == kv_indptr->dtype.code); + CHECK(q_offset_map->dtype.code == kv_indptr->dtype.code); CHECK(k_rope_pos_offset->dtype.code == kv_indptr->dtype.code); CHECK(lse->dtype.code == kDLFloat); @@ -577,8 +578,8 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( int64_t batch_size = qo_indptr->shape[0] - 1; CHECK_EQ(kv_indptr->shape[0], batch_size + 1); - CHECK_EQ(q_rope_position_map->ndim, 1); - CHECK_EQ(q_rope_position_map->shape[0], q_data->shape[0]); + CHECK_EQ(q_offset_map->ndim, 1); + CHECK_EQ(q_offset_map->shape[0], q_data->shape[0]); CHECK_EQ(k_rope_pos_offset->ndim, 1); CHECK_EQ(k_rope_pos_offset->shape[0], batch_size); @@ -593,11 +594,11 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( &batch_prefill_ragged_kv_handler, static_cast(q_data->data), static_cast(qo_indptr->data), static_cast(k_data->data), static_cast(v_data->data), static_cast(kv_indptr->data), - static_cast(q_rope_position_map->data), + static_cast(q_offset_map->data), static_cast(k_rope_pos_offset->data), static_cast(output->data), /*lse=*/static_cast(lse->data), batch_size, nhead_qo, nhead_kv, nfeat, - /*causal=*/bool(causal), QKVLayout::kNHD, RotaryMode(rotary_mode), + /*causal=*/bool(causal), QKVLayout::kNHD, PosEncodingMode(pos_encoding_mode), /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta, /*sm_scale=*/0); })})})
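A minimal usage sketch of the renamed `pos_encoding_mode` argument, mirroring the calls exercised by the tests above; tensor shapes and the seq length are illustrative only.

import torch
import flashinfer

num_heads, head_dim, kv_len = 32, 128, 1024
q = torch.randn(num_heads, head_dim).half().to(0)
k = torch.randn(kv_len, num_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_heads, head_dim).half().to(0)

# No positional encoding applied on the fly (default).
o_none = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="NONE")
# LLaMA-style RoPE applied on the fly inside the kernel.
o_rope = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA")
# ALiBi bias added to the pre-softmax attention scores.
o_alibi = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ALIBI")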
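The new ALiBi tests import `alibi_attention` from `alibi_reference`, which is not part of this diff. The sketch below is an assumed dense reference in the standard ALiBi formulation (per-head slopes 2^(-8*(h+1)/num_heads), bias slope * (key_pos - query_pos) on the logits, queries aligned to the end of the KV sequence); the actual module may differ in details.

import math
import torch


def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Closed-form slopes for power-of-two head counts; the tests above use 4/8/32 heads.
    return torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])


def alibi_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # q: [q_len, num_heads, head_dim], k/v: [kv_len, num_heads, head_dim],
    # mask: [q_len, kv_len] boolean (True = attend).
    q_len, num_heads, head_dim = q.shape
    kv_len = k.shape[0]
    scale = 1.0 / math.sqrt(head_dim)
    logits = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    # Query positions occupy the last q_len slots of the KV sequence.
    q_pos = torch.arange(kv_len - q_len, kv_len, device=q.device)
    k_pos = torch.arange(kv_len, device=q.device)
    bias = (k_pos[None, :] - q_pos[:, None]).float()  # [q_len, kv_len]
    logits += alibi_slopes(num_heads).to(q.device)[:, None, None] * bias[None]
    logits = logits.masked_fill(~mask[None], float("-inf"))
    p = torch.softmax(logits, dim=-1)
    return torch.einsum("hqk,khd->qhd", p, v.float()).to(v.dtype)

Under these assumptions, test_single_decode_alibi above reduces to a single query row (q.unsqueeze(0)) attending to the whole cache with an all-True mask, and test_single_prefill_alibi applies the tril mask with diagonal kv_len - q_len before the softmax.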