diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 690559ee265e9..7782ef30288e4 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -57,7 +57,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "-mavx512dq") find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) - if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) + if (AVX512BF16_FOUND AND ENABLE_AVX512BF16) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 8367093325314..7d93eb9bb9126 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -2,7 +2,7 @@ namespace { -template +template struct KernelVecType { using q_load_vec_type = void; using q_vec_type = void; @@ -22,6 +22,16 @@ struct KernelVecType { using v_load_vec_type = vec_op::FP32Vec16; }; +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::FP32Vec16; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP8Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP8Vec16; +}; + #ifdef __AVX512BF16__ template <> struct KernelVecType { @@ -32,6 +42,16 @@ struct KernelVecType { using qk_acc_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::BF16Vec16; }; + +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::BF16Vec32; + using k_load_vec_type = vec_op::FP8Vec16; + using k_vec_type = vec_op::BF16Vec32; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP8Vec16; +}; #else template <> struct KernelVecType { @@ -42,6 +62,16 @@ struct KernelVecType { using qk_acc_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::BF16Vec16; }; + +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec16; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::FP8Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP8Vec16; +}; #endif template @@ -121,28 +151,27 @@ FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, } } -template +template struct reduceQKBlockKernel { - using q_load_vec_type = typename KernelVecType::q_load_vec_type; - using q_vec_type = typename KernelVecType::q_vec_type; - using k_load_vec_type = typename KernelVecType::k_load_vec_type; - using k_vec_type = typename KernelVecType::k_vec_type; - using qk_acc_vec_type = typename KernelVecType::qk_acc_vec_type; + using q_load_vec_type = + typename KernelVecType::q_load_vec_type; + using q_vec_type = typename KernelVecType::q_vec_type; + using k_load_vec_type = + typename KernelVecType::k_load_vec_type; + using k_vec_type = typename KernelVecType::k_vec_type; + using qk_acc_vec_type = + typename KernelVecType::qk_acc_vec_type; constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x; constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP; constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4; - static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4); - static_assert(k_load_vec_type::get_elem_num() % x == 0); - static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - FORCE_INLINE static void call(const scalar_t* __restrict__ q, - const scalar_t* __restrict__ k_block, + const cache_t* __restrict__ k_block, float* __restrict__ logits, float scale, const int 
token_num) { const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; - qk_acc_vec_type group_accums[MAX_GROUP_NUM]; if (token_num == BLOCK_SIZE) { for (int q_offset = 0; q_offset < HEAD_SIZE; @@ -200,10 +229,11 @@ struct reduceQKBlockKernel { }; template -FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, + int HEAD_PARTITION_SIZE, typename cache_t, typename acc_t> +FORCE_INLINE void reduceValueBlock(const float* prob, const cache_t* v_block, acc_t&& acc) { - using v_load_vec_type = typename KernelVecType::v_load_vec_type; + using v_load_vec_type = + typename KernelVecType::v_load_vec_type; constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); static_assert(BLOCK_SIZE == ELEM_NUM); vec_op::FP32Vec16 prob_vec(prob); @@ -218,15 +248,16 @@ FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, // Paged attention v1 namespace { -template +template struct paged_attention_v1_impl { static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, // max_num_blocks_per_seq] @@ -235,7 +266,7 @@ struct paged_attention_v1_impl { const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads) { - constexpr int x = 16 / sizeof(scalar_t); + constexpr int x = 16 / sizeof(cache_t); const int num_queries_per_kv = num_heads / num_kv_heads; static_assert(BLOCK_SIZE == 16); @@ -269,15 +300,18 @@ struct paged_attention_v1_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = + const cache_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; float* __restrict__ head_block_logits = thread_block_logits + block_idx * BLOCK_SIZE; - reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + reduceQKBlockKernel::call(q_vec_ptr, k_block_cache_ptr, + head_block_logits, scale, + block_idx == block_num - 1 + ? 
last_block_token_num + : BLOCK_SIZE); } // Compute softmax @@ -303,18 +337,18 @@ struct paged_attention_v1_impl { const int64_t physical_block_idx = seq_block_table[block_idx]; const float* __restrict__ prob_vec_ptr = thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = + const cache_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; reduceValueBlock( + head_elem_num_per_partition, cache_t>( prob_vec_ptr, v_block_cache_ptr, accums); if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = + const cache_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -340,14 +374,21 @@ struct paged_attention_v1_impl { } }; -#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v1_impl::call( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ - num_heads); - -template +#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_impl::call(out_ptr, query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, scale, \ + block_tables_ptr, \ + seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, \ + kv_block_stride, \ + kv_head_stride, num_seqs, \ + num_heads); + +template void paged_attention_v1_impl_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, @@ -369,8 +410,8 @@ void paged_attention_v1_impl_launcher( T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); @@ -402,19 +443,19 @@ void paged_attention_v1_impl_launcher( } } -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ +#define CALL_V1_KERNEL_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v1_impl_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes); -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } } // namespace @@ -428,33 +469,48 @@ void paged_attention_v1( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); - TORCH_CHECK(blocksparse_vert_stride <= 1, - "CPU backend 
does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) - CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) - }); + if (kv_cache_dtype == "auto") { + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "paged_attention_v1_impl", [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) + CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t, scalar_t, false); + CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) + }); + } else if (kv_cache_dtype == "fp8") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(float, cpu_fp8, true); + } else if (query.dtype() == at::ScalarType::Half) { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(c10::BFloat16, cpu_fp8, true); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } } // Paged attention v2 namespace { -template +template struct paged_attention_v2_impl { static void call( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, // max_num_blocks_per_seq] @@ -463,7 +519,7 @@ struct paged_attention_v2_impl { const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads, const int max_num_partitions) { - constexpr int x = 16 / sizeof(scalar_t); + constexpr int x = 16 / sizeof(cache_t); const int num_queries_per_kv = num_heads / num_kv_heads; static_assert(BLOCK_SIZE == 16); @@ -501,15 +557,18 @@ struct paged_attention_v2_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t* __restrict__ k_block_cache_ptr = + const cache_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; float* __restrict__ head_block_logits = logits + block_idx * BLOCK_SIZE; - 
reduceQKBlockKernel::call( - q_vec_ptr, k_block_cache_ptr, head_block_logits, scale, - block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE); + reduceQKBlockKernel::call(q_vec_ptr, k_block_cache_ptr, + head_block_logits, scale, + block_idx == block_num - 1 + ? last_block_token_num + : BLOCK_SIZE); } std::pair max_and_sum; @@ -552,18 +611,18 @@ struct paged_attention_v2_impl { const int64_t physical_block_idx = seq_block_table[block_idx]; const float* __restrict__ prob_vec_ptr = logits + block_idx * BLOCK_SIZE; - const scalar_t* __restrict__ v_block_cache_ptr = + const cache_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; reduceValueBlock( + head_elem_num_per_partition, cache_t>( prob_vec_ptr, v_block_cache_ptr, accums); if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t* __restrict__ next_v_block_cache_ptr = + const cache_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -651,15 +710,21 @@ struct paged_attention_v2_impl { } }; -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ - max_num_partitions); - -template +#define LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call(out_ptr, exp_sums_ptr, \ + max_logits_ptr, tmp_out_ptr, \ + query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, \ + num_seqs, num_heads, \ + max_num_partitions); + +template void paged_attention_v2_impl_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -686,32 +751,32 @@ void paged_attention_v2_impl_launcher( float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, 64, BLOCK_SIZE); break; case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, 80, BLOCK_SIZE); break; case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, 96, BLOCK_SIZE); break; case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, 112, BLOCK_SIZE); break; case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, 128, BLOCK_SIZE); break; case 192: - 
LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, 192, BLOCK_SIZE); break; case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + LAUNCH_V2_ATTENTION_KERNEL(T, CACHE_T, 256, BLOCK_SIZE); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -719,16 +784,16 @@ void paged_attention_v2_impl_launcher( } } -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ +#define CALL_V2_KERNEL_LAUNCHER(T, CACHE_T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ alibi_slopes); -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T, CACHE_T) \ switch (block_size) { \ case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ + CALL_V2_KERNEL_LAUNCHER(T, CACHE_T, 16); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -749,10 +814,19 @@ void paged_attention_v2( TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); - VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", - [&] { - CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) - CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); - CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) - }); + if (kv_cache_dtype == "auto") { + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "paged_attention_v2_impl", [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) + CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t, scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) + }); + } else if (kv_cache_dtype == "fp8") { + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "paged_attention_v2_impl", [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) + CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t, cpu_fp8); + CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) + }); + } } diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2b5c3bd6ee70b..3efe37ccc1f14 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -32,13 +32,33 @@ void copy_blocks_cpu_impl(std::vector const& key_caches, } } -template -void reshape_and_cache_cpu_impl( - const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, - scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const int64_t* __restrict__ slot_mapping, const int num_tokens, - const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x) { +template +cache_t assign_cache_value(const scalar_t* src) { + return *src; +} + +template <> +uint8_t assign_cache_value(const float* src) { + uint8_t res = cast_fp32x1_to_fp8x1(*src); + return res; +} + +template <> +uint8_t assign_cache_value(const int16_t* src) { + uint8_t res = cast_bf16x1_to_fp8x1(*src); + return res; +} + +template +void reshape_and_cache_cpu_impl(const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + cache_t* __restrict__ key_cache, + cache_t* __restrict__ value_cache, + const int64_t* __restrict__ slot_mapping, + const int num_tokens, const int key_stride, + const int value_stride, const int num_heads, + const int head_size, const int block_size, + const int kv_cache_stride, const int x) { const int block_elem_num = num_heads * head_size * block_size; #pragma omp parallel for collapse(2) @@ -53,19 +73,20 @@ void reshape_and_cache_cpu_impl( const scalar_t* src_value_head_ptr = 
value + src_value_head_idx; const int64_t block_index = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - scalar_t* target_key_head_ptr = key_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; - scalar_t* target_value_head_ptr = value_cache + - block_elem_num * block_index + - head_idx * block_size * head_size; + cache_t* target_key_head_ptr = key_cache + + kv_cache_stride * block_index + + head_idx * block_size * head_size; + cache_t* target_value_head_ptr = value_cache + + kv_cache_stride * block_index + + head_idx * block_size * head_size; for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { const int64_t target_offset = src_key_idx * block_size + block_offset * x; for (int i = 0; i < x; ++i) { target_key_head_ptr[target_offset + i] = - src_key_head_ptr[src_key_idx + i]; + assign_cache_value(src_key_head_ptr + + src_key_idx + i); } } @@ -74,7 +95,8 @@ void reshape_and_cache_cpu_impl( const int64_t target_offset = src_value_idx * block_size + block_offset; target_value_head_ptr[target_offset] = - src_value_head_ptr[src_value_idx]; + assign_cache_value(src_value_head_ptr + + src_value_idx); } } } @@ -104,6 +126,17 @@ void copy_blocks(std::vector const& key_caches, }); } +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ + CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) \ + reshape_and_cache_cpu_impl( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), num_tokens, key_stride, value_stride, \ + num_heads, head_size, block_size, kv_cache_stride, x); \ + CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) + void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, @@ -115,20 +148,30 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, int head_size = key.size(2); int block_size = key_cache.size(3); int x = key_cache.size(4); + int kv_cache_stride = key_cache.stride(0); int key_stride = key.stride(0); int value_stride = value.stride(0); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { - CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) - reshape_and_cache_cpu_impl( - key.data_ptr(), value.data_ptr(), - key_cache.data_ptr(), value_cache.data_ptr(), - slot_mapping.data_ptr(), num_tokens, key_stride, - value_stride, num_heads, head_size, block_size, x); - CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) - }); + if (kv_cache_dtype == "auto") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, float, false); + } else if (key.dtype() == at::ScalarType::Half) { + TORCH_CHECK(false, "Unsupported data type: Half"); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(int16_t, int16_t, false); + } + } else if (kv_cache_dtype == "fp8") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, uint8_t, true); + } else if (key.dtype() == at::ScalarType::Half) { + TORCH_CHECK(false, "Unsupported data type: Half"); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(int16_t, uint8_t, true); + } + } else { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } } void swap_blocks(torch::Tensor& src, torch::Tensor& dst, diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 0213be09105ed..eb57c930522b7 
100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -1,4 +1,3 @@ - #ifndef CPU_TYPES_HPP #define CPU_TYPES_HPP diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index f50620a5287d4..9d65537c88754 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -5,6 +5,10 @@ #include #include +#include "fp8_utils.h" + +typedef uint8_t cpu_fp8; + #ifndef __AVX2__ static_assert(false, "AVX2 must be supported for the current implementation."); #endif @@ -50,6 +54,19 @@ template struct Vec { struct FP32Vec8; struct FP32Vec16; +struct FP8Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m128 reg; + cpu_fp8 values[VEC_ELEM_NUM]; + }; + __m128 reg; + + explicit FP8Vec16() : reg(_mm_set1_ps(0)) {} + explicit FP8Vec16(const cpu_fp8 *ptr) : reg((__m128)_mm_loadu_epi8(ptr)) {} + +}; + #ifdef __AVX512FP16__ struct FP16Vec8 : public Vec { constexpr static int VEC_ELEM_NUM = 8; @@ -279,6 +296,8 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + explicit FP32Vec16(const FP8Vec16 &data) : reg(cast_fp8x16_to_fp32x16((__m128)data.reg)) {} + explicit FP32Vec16(const FP32Vec4 &data) : reg((__m512)_mm512_inserti32x4( _mm512_inserti32x4( diff --git a/csrc/cpu/fp8_utils.h b/csrc/cpu/fp8_utils.h new file mode 100644 index 0000000000000..16d8609c64f0c --- /dev/null +++ b/csrc/cpu/fp8_utils.h @@ -0,0 +1,256 @@ +#ifndef FP8_UTILS_H +#define FP8_UTILS_H + +#include +#include +#include +#include + +static inline __m512i _mm512_cvte5m2_fp16(__m256i a) { + return _mm512_slli_epi16(_mm512_cvtepi8_epi16(a), 8); +} + +static inline __m256i _mm256_cvte5m2_fp16(__m128i a) { + return _mm256_slli_epi16(_mm256_cvtepi8_epi16(a), 8); +} + +static inline __m256i _mm256_cvt2fp16_e5m2(__m256i a, __m256i b) { + const __m512i vnaninf = _mm512_set1_epi16(0x7c00), + vrneadd = _mm512_set1_epi16(0x007f); + const __m512i vfixup = _mm512_set1_epi16(0x0001), + vfixupmask = _mm512_set1_epi16(0x0100); + /* b: lower half, a : upper half */ + const __m512i a_ = _mm512_inserti64x4( + _mm512_inserti64x4(_mm512_setzero_si512(), b, 0), a, 1); + const __mmask32 maska1_ = _mm512_cmp_epi16_mask(_mm512_and_si512(a_, vnaninf), + vnaninf, _MM_CMPINT_NE); + const __mmask32 maska2_ = _mm512_cmp_epi16_mask( + _mm512_and_si512(a_, vfixupmask), vfixupmask, _MM_CMPINT_EQ); + __m512i a_rne_ = _mm512_mask_add_epi16( + a_, maska1_, a_, + _mm512_mask_add_epi16(vrneadd, maska2_, vrneadd, vfixup)); + return _mm512_cvtepi16_epi8(_mm512_srli_epi16(a_rne_, 8)); +} + +static inline __m256i _mm256_cvt2fp16_e5m2_noINF(__m256i a, __m256i b) { + const __m512i vnaninf = _mm512_set1_epi16(0x7c00); + const __m512i vrneadd = _mm512_set1_epi16(0x007f); + const __m512i vfixup = _mm512_set1_epi16(0x0001); + const __m512i vfixupmask = _mm512_set1_epi16(0x0100); + /* use a non-standard exponent offset = 16, */ + const __m512i vExp_fp16 = _mm512_set1_epi16(0x000F); + const __m512i vExp_e5m2 = _mm512_set1_epi16(0x0010); + const __m512i vsMant = _mm512_set1_epi16(0x83FF); + /* Exponent Offset = 16, reclaim inf/NaN */ + const __m512i vsatuval = + _mm512_set1_epi16(0x7F00); /* 2^15*1.11 a.k.a 57344.0, largest value */ + const __m512i vinfval = _mm512_set1_epi16(0x8000); /* -0.0 as INF */ + const __m512i a_ = _mm512_inserti64x4( + _mm512_inserti64x4(_mm512_setzero_si512(), b, 0), a, 1); + const __mmask32 maska1_ = _mm512_cmp_epi16_mask(_mm512_and_si512(a_, vnaninf), + vnaninf, _MM_CMPINT_NE); + const __mmask32 maska2_ = _mm512_cmp_epi16_mask( + 
+      _mm512_and_si512(a_, vfixupmask), vfixupmask, _MM_CMPINT_EQ);
+  const __mmask32 maska3_ =
+      _mm512_cmp_epi16_mask(_mm512_and_si512(a_, _mm512_set1_epi16(0x7FFF)),
+                            vsatuval, _MM_CMPINT_NLE);
+  __m512i vExp_ = _mm512_sub_epi16(
+      _mm512_srli_epi16(_mm512_and_si512(a_, vnaninf), 10), vExp_fp16);
+  vExp_ = _mm512_slli_epi16(_mm512_add_epi16(vExp_, vExp_e5m2), 10);
+  __m512i a_rne_ = _mm512_or_si512(vExp_, _mm512_and_si512(a_, vsMant));
+  a_rne_ = _mm512_mask_add_epi16(
+      a_rne_, maska1_, a_rne_,
+      _mm512_mask_add_epi16(vrneadd, maska2_, vrneadd, vfixup));
+  a_rne_ = _mm512_mask_mov_epi16(
+      a_rne_, maska3_,
+      _mm512_or_si512(_mm512_and_si512(a_rne_, vinfval), vsatuval));
+  a_rne_ = _mm512_mask_mov_epi16(a_rne_, ~maska1_, vinfval);
+  return _mm512_cvtepi16_epi8(_mm512_srli_epi16(a_rne_, 8));
+}
+
+static inline __m512i _mm512_cvte5m2_noinf_fp16(__m256i a) {
+  const __m512i vExp_fp16 = _mm512_set1_epi16(0x000F);
+  const __m512i vExp_e5m2 = _mm512_set1_epi16(0x0010);
+  const __m512i vsMant = _mm512_set1_epi16(0x83FF);
+  const __m512i vnaninf = _mm512_set1_epi16(0x8000); /* -0.0 as INF */
+  const __m512i vinfval = _mm512_set1_epi16(0x7c00);
+  __m512i a_ = _mm512_slli_epi16(_mm512_cvtepi8_epi16(a), 8);
+  const __mmask32 mask1_ = _mm512_cmp_epi16_mask(a_, vnaninf, _MM_CMPINT_EQ);
+  __m512i vExp_ = _mm512_sub_epi16(
+      _mm512_srli_epi16(_mm512_and_si512(a_, vinfval), 10), vExp_e5m2);
+  vExp_ = _mm512_slli_epi16(_mm512_add_epi16(vExp_, vExp_fp16), 10);
+  a_ = _mm512_or_si512(vExp_, _mm512_and_si512(a_, vsMant));
+  return _mm512_mask_mov_epi16(a_, mask1_, vinfval);
+}
+
+static inline void cvt_fp16_e5m2_noINF_rne_intrinsic(
+    const short* __restrict__ in, unsigned char* out, int size) {
+#pragma omp parallel for
+  for (int i = 0; i < size; i += 32) {
+    __m256i bh_ = _mm256_lddqu_si256((__m256i*)&in[i]);
+    __m256i ah_ = _mm256_lddqu_si256((__m256i*)&in[i + 16]);
+    _mm256_storeu_si256((__m256i*)&out[i], _mm256_cvt2fp16_e5m2(ah_, bh_));
+  }
+}
+
+static inline void cvt_fp32_e5m2_noinf_rne_intrinsic(
+    const float* __restrict__ in, float* out, int size, float scale) {
+#pragma omp parallel for
+  for (int i = 0; i < size; i += 32) {
+    __m512 b = _mm512_loadu_ps(&in[i]);
+    __m512 a = _mm512_loadu_ps(&in[i + 16]);
+    __m256i ah_ =
+        _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+    __m256i bh_ =
+        _mm512_cvtps_ph(b, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+    __m512i a_rne_ =
+        _mm512_cvte5m2_noinf_fp16(_mm256_cvt2fp16_e5m2_noINF(ah_, bh_));
+    bh_ = _mm512_extracti64x4_epi64(a_rne_, 0);
+    ah_ = _mm512_extracti64x4_epi64(a_rne_, 1);
+    b = _mm512_cvtph_ps(bh_);
+    a = _mm512_cvtph_ps(ah_);
+    _mm512_storeu_ps(&out[i], b);
+    _mm512_storeu_ps(&out[i + 16], a);
+  }
+}
+
+static inline __m512i cast_fp8x32_to_fp16x32(__m256i a) {
+  return _mm512_cvte5m2_fp16(a);
+}
+
+static inline __m256i cast_fp16x16x2_to_fp8x32(__m256i a, __m256i b) {
+  return _mm256_cvt2fp16_e5m2(a, b);
+}
+
+static inline void cast_fp8xn_to_fp16xn(const char* __restrict__ in,
+                                        unsigned short* out, int n) {
+#pragma omp parallel for
+  for (int i = 0; i < n; i += 32) {
+    __m256i a = _mm256_loadu_si256((const __m256i*)&in[i]);
+    __m512i b = cast_fp8x32_to_fp16x32(a);
+    __m512i* out_p = (__m512i*)(&out[i]);
+    *out_p = b;
+  }
+}
+
+static inline void cast_fp16xn_to_fp8xn(const short* __restrict__ in,
+                                        unsigned char* out, int n) {
+  cvt_fp16_e5m2_noINF_rne_intrinsic(in, out, n);
+}
+
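+// Quantize a buffer of FP32 values to E5M2 precision: the input is rounded to
+// FP16, then to E5M2 (no-inf encoding, round-to-nearest-even), and expanded
+// back to FP32 into `out`, so `out` holds FP32 values carrying only E5M2
+// precision. The `scale` argument of the underlying routine is unused here.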
+static inline void cast_fp32xn_to_fp8xn(const float* __restrict__ in,
+                                        float* out, int n) {
+  cvt_fp32_e5m2_noinf_rne_intrinsic(in, out, n, 0);
+}
+
+static inline uint8_t cast_bf16x1_to_fp8x1(int16_t bf16bits) {
+  // BF16 shares the FP32 exponent bias (127); the target FP8 (E5M2) bias is 15
+  const int fp16Bias = 127;
+  const int fp8Bias = 15;
+  uint8_t sign = (bf16bits >> 15) & 0x01;
+  int8_t shift = (bf16bits >> 7) & 0xFF;
+  if (shift == (int8_t)0xFF) {
+    return (sign << 7) | 0x7F;
+  }
+  if (shift <= (int8_t)0x70 && shift >= (int8_t)0x91) {
+    return (sign << 7);
+  }
+
+  int8_t exponent = shift - fp16Bias;
+  uint16_t mantissa = bf16bits & 0x007F;
+
+  // Adjust the exponent and mantissa for FP8
+  exponent += fp8Bias;
+
+  // Handle special cases and rounding (not shown for brevity)
+  // Assemble the FP8 value (manual bit manipulation)
+  uint8_t fp8 =
+      (sign << 7) | ((exponent & 0x1F) << 2) | ((mantissa >> 5) & 0x03);
+
+  return fp8;
+}
+
+static inline uint8_t cast_fp32x1_to_fp8x1(float fp32) {
+  // Define the FP32 bias and the target FP8 bias
+  const int fp32Bias = 127;
+  const int fp8Bias = 15;
+
+  // Use intrinsics to extract the bits from the FP32 value
+  __m128 fp32Vector = _mm_set_ss(fp32);
+  int fp32Bits =
+      _mm_extract_ps(fp32Vector, 0);  // Extract the bits into an integer
+
+  // Extract sign, exponent, and mantissa from FP32
+  uint8_t sign = (fp32Bits >> 31) & 0x01;
+  int8_t shift = (fp32Bits >> 23) & 0xFF;
+  if (shift == (int8_t)0xFF) {
+    return (sign << 7) | 0x7F;
+  }
+  if (shift <= (int8_t)0x70 && shift >= (int8_t)0x91) {
+    return (sign << 7);
+  }
+  int8_t exponent = shift - fp32Bias;
+  uint32_t mantissa = fp32Bits & 0x007FFFFF;
+
+  // Adjust the exponent and mantissa for FP8
+  exponent += fp8Bias;
+
+  // Handle special cases and rounding (not shown for brevity)
+  // Assemble the FP8 value (manual bit manipulation)
+  uint8_t fp8 =
+      (sign << 7) | ((exponent & 0x1F) << 2) | ((mantissa >> 21) & 0x03);
+
+  return fp8;
+}
+
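+// Decode path: mirrors the encoders above. The 5-bit exponent is re-biased
+// from 15 (E5M2) back to 127 (FP32) and the 2 mantissa bits become the top
+// bits of the FP32 mantissa; an all-zero exponent field is flushed to signed
+// zero and an all-ones exponent field maps to infinity.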
+static inline uint32_t cast_fp8x1_to_fp32x1(uint8_t fp8) {
+  uint8_t sign = (fp8 >> 7) & 0x01;
+  // Handle special cases (e.g., zero, infinity)
+  if ((fp8 & 0x7C) == 0) {
+    // Zero or subnormal (treated as zero): return the signed-zero bit pattern
+    return static_cast<uint32_t>(sign) << 31;
+  } else if ((fp8 & 0x7C) == 0x7C) {
+    // Infinity: return the signed-infinity bit pattern
+    return (static_cast<uint32_t>(sign) << 31) | 0x7F800000u;
+  }
+
+  // Define the FP8 bias and the target FP32 bias
+  const int fp8Bias = 15;
+  const int fp32Bias = 127;
+
+  // Extract sign, exponent, and mantissa from FP8
+  int exponent = ((fp8 >> 2) & 0x1F) - fp8Bias;
+  uint8_t mantissa = fp8 & 0x03;
+
+  // Adjust the exponent and mantissa for FP32
+  exponent += fp32Bias;
+
+  // Normalize the mantissa (the implicit leading 1 is added)
+  uint32_t mantissaFP32 = static_cast<uint32_t>(mantissa) << 21;
+
+  // Assemble the FP32 value
+  uint32_t fp32Bits = (static_cast<uint32_t>(sign) << 31) |
+                      (static_cast<uint32_t>(exponent) << 23) | mantissaFP32;
+
+  return fp32Bits;
+}
+
+static inline float cast_fp8x1_to_fp32x1_f(uint8_t fp8) {
+  uint32_t fp32_i = cast_fp8x1_to_fp32x1(fp8);
+  float fp32 = *(float*)(&fp32_i);
+  return fp32;
+}
+
+static inline __m256 cast_fp8x16_to_fp16x16(__m128 fp8x16) {
+  return (__m256)_mm256_cvte5m2_fp16((__m128i)fp8x16);
+}
+
+static inline __m512 cast_fp8x16_to_fp32x16(__m128 fp8x16) {
+  __m512 res{0};
+  // fp8x16 -> fp16x16 -> fp32x16
+  __m256 fp16x16 = cast_fp8x16_to_fp16x16(fp8x16);
+  res = _mm512_cvtph_ps((__m256i)fp16x16);
+  return res;
+}
+
+#endif
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 63f8466da9316..3ef3f33601630 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -132,10 +132,6 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError(
-                "Torch SDPA backend does not support FP8 KV cache. "
-                "Please use xFormers backend instead.")

    def forward(
        self,
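For reference, the scalar E5M2 layout used by cast_fp32x1_to_fp8x1 and cast_fp8x1_to_fp32x1 can be exercised with a small standalone program. This is a minimal sketch, not part of the patch: it assumes the same 1-5-2 bit layout with exponent bias 15, flushes subnormals to signed zero, and truncates the mantissa instead of rounding; the helper names fp32_to_e5m2 and e5m2_to_fp32 are illustrative only.

// Standalone E5M2 round-trip sketch (illustrative, not part of the patch).
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <initializer_list>

static uint8_t fp32_to_e5m2(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint8_t sign = (bits >> 31) & 0x1;
  uint32_t raw_exp = (bits >> 23) & 0xFF;
  uint32_t mant = bits & 0x7FFFFF;
  if (raw_exp == 0xFF) return (sign << 7) | 0x7F;  // Inf/NaN -> NaN-like code
  int exp = static_cast<int>(raw_exp) - 127;       // unbiased exponent
  if (exp < -14) return sign << 7;                 // underflow -> signed zero
  if (exp > 15) return (sign << 7) | 0x7B;         // overflow -> largest finite (57344)
  return (sign << 7) | (static_cast<uint8_t>(exp + 15) << 2) |
         static_cast<uint8_t>(mant >> 21);         // keep the top 2 mantissa bits
}

static float e5m2_to_fp32(uint8_t v) {
  uint8_t sign = (v >> 7) & 0x1;
  uint32_t exp = (v >> 2) & 0x1F;
  uint32_t mant = v & 0x3;
  uint32_t bits;
  if (exp == 0) {
    bits = static_cast<uint32_t>(sign) << 31;      // zero / subnormal -> signed zero
  } else if (exp == 0x1F) {
    bits = (static_cast<uint32_t>(sign) << 31) | 0x7F800000u | (mant << 21);
  } else {
    bits = (static_cast<uint32_t>(sign) << 31) | ((exp - 15 + 127) << 23) |
           (mant << 21);
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  for (float x : {0.0f, -1.5f, 3.14159f, 255.0f, 1e-6f, 70000.0f}) {
    uint8_t q = fp32_to_e5m2(x);
    std::printf("%g -> 0x%02x -> %g\n", x, static_cast<unsigned>(q),
                e5m2_to_fp32(q));
  }
  return 0;
}

The vectorized helpers in fp8_utils.h target the same format but round to nearest even through an FP16 intermediate, so their results can differ from this truncating sketch in the last mantissa bit.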