diff --git a/CMakeLists.txt b/CMakeLists.txt index f4b9c3ec9c14f..7ce63d31e666c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -189,6 +189,7 @@ set(VLLM_EXT_SRC "csrc/cache_kernels.cu" "csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v2.cu" + "csrc/attention/vertical_slash_index.cu" "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" @@ -549,8 +550,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR) else() FetchContent_Declare( vllm-flash-attn - GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c + GIT_REPOSITORY https://github.com/minminsun/flash-attention.git + GIT_TAG 260da6541a1d53a7562963bf7f6f8cfc04661ba3 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu new file mode 100644 index 0000000000000..2980cb0c4dbad --- /dev/null +++ b/csrc/attention/vertical_slash_index.cu @@ -0,0 +1,394 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#include + +#include + +__device__ void save_blocks(int* block_offset, int64_t range_start, + int64_t range_end, int64_t block_size, + int64_t& block_count, int64_t kv_seqlen) { + if (range_start >= kv_seqlen) { + return; + } + if (range_end > kv_seqlen) { + range_end = kv_seqlen; + } + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[block_count++] = idx; + } +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // if (head_idx != 0 || block_idx_m != 0) return; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = 
slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + // if (head_idx != 0 || block_idx_m != 0) return; + + bool slash_finished = false; + while (1) { + // if (v > 100) return; + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + // const int BLOCK_SIZE_M = 64; + // const int BLOCK_SIZE_N = 64; + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count, + block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, + BLOCK_SIZE_N, NNZ_V, NNZ_S, causal); +} + +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, 
NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + // assert(block_size_M == 64); + // assert(block_size_N == 64); + + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64( + q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(), + vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(), + block_count.data_ptr<int>(), block_offset.data_ptr<int>(), + column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size, + num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash, + causal); +} + +__global__ void convert_vertical_slash_indexes_kernel_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + const int* per_head_vertical_topkv, const int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // MergeHead: each head has its unique max topk NNZ_V,NNZ_S. 
(NNZ_V,NNZ_S + // above is buffer size, use to compute offset) + NNZ_S = per_head_slash_topkv[head_idx]; + NNZ_V = per_head_vertical_topkv[head_idx]; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + // if (v > 100) return; + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* per_head_vertical_topkv, int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const 
int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, + per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset, + column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, + NNZ_V, NNZ_S, causal); +} + +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, int64_t context_size, + int64_t block_size_M, int64_t block_size_N, bool causal) { + // assert(block_size_M == 64); + // assert(block_size_N == 64); + + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64_mergehead( + q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(), + vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(), + vertical_indices_count.data_ptr<int>(), + slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(), + block_offset.data_ptr<int>(), column_count.data_ptr<int>(), + column_index.data_ptr<int>(), batch_size, num_heads, num_rows, + block_size_M, block_size_N, nnz_vertical, nnz_slash, causal); +} diff --git a/csrc/ops.h b/csrc/ops.h index 9efd9b0c24700..7c1805d0519b1 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -50,6 +50,31 @@ void paged_attention_v2( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal); + +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, int64_t context_size, + int64_t block_size_M, int64_t block_size_N, bool causal); + void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); diff --git 
a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 956258c1001d3..507d928f56228 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -50,6 +50,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); + ops.def( + "convert_vertical_slash_indexes(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor kv_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes", torch::kCUDA, + &convert_vertical_slash_indexes); + + ops.def( + "convert_vertical_slash_indexes_mergehead(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor kv_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " Tensor vertical_indices_count, Tensor slash_indices_count, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, + &convert_vertical_slash_indexes_mergehead); + // Activation ops // Activation function used in SwiGLU. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); diff --git a/examples/offline_inference_qwen_1m.py b/examples/offline_inference_qwen_1m.py new file mode 100644 index 0000000000000..346f57d415261 --- /dev/null +++ b/examples/offline_inference_qwen_1m.py @@ -0,0 +1,36 @@ +import os + +from vllm import LLM, SamplingParams + +os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" +os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + +with open(os.path.expanduser("~/vllm/64k.txt")) as f: + prompt = f.read() + +# Sample prompts. +prompts = [ + prompt, +] +# Create a sampling params object. +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.7, + top_k=20, + detokenize=True, +) + +# Create an LLM. +llm = LLM(model=os.path.expanduser("~/models/qwen2.5-14b-1m-1231/"), + max_model_len=1048576, + tensor_parallel_size=4, + enforce_eager=True) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. 
+for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eb2f69df42624..dfe5cc1eb3f43 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -153,6 +153,101 @@ def paged_attention_rocm( kv_cache_dtype, k_scale, v_scale) +def convert_vertical_slash_indexes( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, context_size, + block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + +def convert_vertical_slash_indexes_mergehead( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + vertical_indices_count: torch. 
+ Tensor, # [N_HEADS] : different head use different number of indices + slash_indices_count: torch.Tensor, + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.empty(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.empty(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes_mergehead( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + # pos encoding ops def rotary_embedding( positions: torch.Tensor, diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py new file mode 100644 index 0000000000000..19ea1eb0b1ea6 --- /dev/null +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -0,0 +1,1508 @@ +"""Attention layer with Dual chunk flash attention and sparse attention. +""" +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + FlashAttentionMetadataBuilder) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache, sparse_attn_func) + + +class DualChunkFlashAttentionBackend(FlashAttentionBackend): + + @staticmethod + def get_name() -> str: + return "DUAL_CHUNK_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: + return DualChunkFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: + return DualChunkFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: + return DualChunkFlashAttentionMetadataBuilder + + +@dataclass +class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): + # Block size of the paged kv cache. + block_size: int = 16 + + # Original max position embeddings. + original_max_position_embeddings: int = 0 + + # Chunk size + chunk_size: int = 8192 + + # Local size + local_size: int = 1024 + + # Length scaling factor + scaling_factor: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for intra attention. + seq_lens_intra: Optional[torch.Tensor] = None + + # Max sequence length for intra attention. + max_seq_len_intra: Optional[int] = None + + # (batch_size, num_blocks). Block table for intra attention. 
+ block_tables_intra: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for succ attention. + seq_lens_succ: Optional[torch.Tensor] = None + + # Max sequence length for succ attention. + max_seq_len_succ: Optional[int] = None + + # (batch_size, num_blocks). Block table for succ attention. + block_tables_succ: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for inter attention. + seq_lens_inter: Optional[torch.Tensor] = None + + # Max sequence length for inter attention. + max_seq_len_inter: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "DualChunkFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + prefill_metadata = super().prefill_metadata + if prefill_metadata is None: + return None + + prefill_metadata = DualChunkFlashAttentionMetadata( + **prefill_metadata.asdict_zerocopy()) + + if self.original_max_position_embeddings > 0: + assert prefill_metadata.orig_seq_lens_tensor is not None + prefill_metadata.scaling_factor = ( + 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / + self.original_max_position_embeddings) + + 1.0).clip(min=1) + + self._cached_prefill_metadata = prefill_metadata + return prefill_metadata + + @property + def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + + decode_metadata = super().decode_metadata + if decode_metadata is None: + return None + + decode_metadata = DualChunkFlashAttentionMetadata( + **decode_metadata.asdict_zerocopy()) + + assert decode_metadata.orig_seq_lens_tensor is not None + assert decode_metadata.block_tables is not None + + cache_seq_lens = decode_metadata.orig_seq_lens_tensor + chunk_len = self.chunk_size - self.local_size + chunk_num_curr = (cache_seq_lens - 1) // chunk_len + batch_size = decode_metadata.num_decode_tokens + + if self.original_max_position_embeddings > 0: + decode_metadata.scaling_factor = (0.1 * torch.log( + cache_seq_lens / self.original_max_position_embeddings) + + 1.0).clip(min=1) + + seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len + max_seq_len_intra = seq_lens_intra.max().item() + decode_metadata.seq_lens_intra = seq_lens_intra + decode_metadata.max_seq_len_intra = max_seq_len_intra + + block_tables_intra = torch.zeros( + batch_size, + (max_seq_len_intra - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + st = chunk_num_curr[i] * chunk_len // self.block_size + ed = min( + st + (max_seq_len_intra - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_intra[i, :ed - + st] = decode_metadata.block_tables[i, st:ed] + decode_metadata.block_tables_intra = block_tables_intra + + seq_lens_succ = (chunk_num_curr - + (chunk_num_curr - 1).clip(min=0) * chunk_len) + max_seq_len_succ = seq_lens_succ.max().item() + decode_metadata.seq_lens_succ = seq_lens_succ + decode_metadata.max_seq_len_succ = max_seq_len_succ + if max_seq_len_succ: + block_tables_succ = torch.zeros( + batch_size, + (max_seq_len_succ - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + 
device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + st = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len // + self.block_size) + ed = min( + st + (max_seq_len_succ - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_succ[i, :ed - + st] = decode_metadata.block_tables[i, st:ed] + decode_metadata.block_tables_succ = block_tables_succ + + seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + max_seq_len_inter = seq_lens_inter.max().item() + decode_metadata.seq_lens_inter = seq_lens_inter + decode_metadata.max_seq_len_inter = max_seq_len_inter + + self._cached_decode_metadata = decode_metadata + return decode_metadata + + +class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + attn_metadata = super().build(seq_lens, query_lens, + cuda_graph_pad_size, batch_size) + attn_metadata = DualChunkFlashAttentionMetadata( + **attn_metadata.asdict_zerocopy()) + + attn_metadata.block_size = self.runner.block_size + dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, + "dual_chunk_attention_config", {}) + attn_metadata.original_max_position_embeddings = \ + dual_chunk_attn_config.get("original_max_position_embeddings", 0) + attn_metadata.chunk_size = dual_chunk_attn_config.get( + "chunk_size", 8192) + attn_metadata.local_size = dual_chunk_attn_config.get( + "local_size", 1024) + + return attn_metadata + + +class DualChunkFlashAttentionImpl(FlashAttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + The prompts might have different lengths, while the generation tokens + always have length 1. + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. 
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, + prefix: str = "", + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + DualChunkFlashAttentionBackend.get_supported_head_sizes()) + + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + assert dual_chunk_attention_config is not None + self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) + self.local_size = dual_chunk_attention_config.get("local_size", 1024) + self.original_max_position_embeddings = dual_chunk_attention_config.get( + "original_max_position_embeddings", 0) + self.sparse_attention_config = dual_chunk_attention_config.get( + "sparse_attention_config", None) + self.sparse_attention_enabled = dual_chunk_attention_config.get( + "sparse_attention_enabled", self.sparse_attention_config + is not None) + self.sparse_attention_threshold = dual_chunk_attention_config.get( + "sparse_attention_threshold", 32768) + self.sparse_attention_last_q = dual_chunk_attention_config.get( + "sparse_attention_last_q", 64) + self.dual_chunk_attention_config = dual_chunk_attention_config + + prefixes = prefix.split(".") + self.layer_idx = int(prefixes[prefixes.index("layers") + 1]) + + if self.sparse_attention_config: + self.sparse_attention_config = { + int(i): j + for i, j in self.sparse_attention_config[ + self.layer_idx].items() + } + start_head = self.num_heads * get_tensor_model_parallel_rank() + end_head = start_head + self.num_heads + self.sparse_attention_config = [ + self.sparse_attention_config[i] + for i in range(start_head, end_head) + ] + + if self.sparse_attention_enabled: + self.arange = torch.arange(self.sparse_attention_last_q, + device="cuda") + self.last_q_mask = (self.arange[None, None, :, None] >= + self.arange[None, None, None, :]) + + def forward( # type: ignore + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DualChunkFlashAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + query_succ: torch.Tensor = None, + query_inter: torch.Tensor = None, + query_succ_critical: torch.Tensor = None, + query_inter_critical: torch.Tensor = None, + ) -> torch.Tensor: + """Forward pass with DualChunkFlashAttention. 
+ Args: + query: shape = [num_tokens, num_heads * head_size] + query_succ: shape = [num_tokens, num_heads * head_size] + query_inter: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert ( + query_succ is not None and query_inter is not None + ), "query_succ and query_inter are required in Dual Chunk Attention." + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + query_succ = query_succ.view(-1, self.num_heads, self.head_size) + query_inter = query_inter.view(-1, self.num_heads, self.head_size) + query_succ_critical = query_succ_critical.view(-1, self.num_heads, + self.head_size) + query_inter_critical = query_inter_critical.view( + -1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.original_max_position_embeddings > 0: + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.scaling_factor is not None + assert prefill_meta.query_start_loc is not None + assert prefill_meta.orig_seq_lens is not None + current_start = 0 + query_start_loc_cpu = prefill_meta.query_start_loc.cpu() + for i, orig_seq_len in enumerate(prefill_meta.orig_seq_lens): + current_end = (current_start + + (query_start_loc_cpu[i + 1] - + query_start_loc_cpu[i]).item()) + key[current_start:current_end].mul_( + prefill_meta.scaling_factor[i]) + current_start = current_end + assert current_end <= attn_metadata.num_prefill_tokens + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + key[attn_metadata.num_prefill_tokens:].mul_( + scaling_factor.unsqueeze(-1).unsqueeze(-1)) + + if kv_cache is not None and kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + decode_query_succ = query_succ[num_prefill_tokens:] + decode_query_inter = query_inter[num_prefill_tokens:] + + # QKV for prefill. 
+ query = query[:num_prefill_tokens] + query_succ = query_succ[:num_prefill_tokens] + query_inter = query_inter[:num_prefill_tokens] + query_succ_critical = query_succ_critical[:num_prefill_tokens] + query_inter_critical = query_inter_critical[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention, called during the profiling run. + out = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + assert prefill_meta.orig_seq_lens is not None + output[:num_prefill_tokens] = ( + self._dual_chunk_flash_attn_prefill( + q=query, + q_succ=query_succ, + q_inter=query_inter, + q_succ_critical=query_succ_critical, + q_inter_critical=query_inter_critical, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + orig_seq_lens=prefill_meta.orig_seq_lens, + scaling_factor=prefill_meta.scaling_factor, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1), + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + chunk_size=self.chunk_size, + local_size=self.local_size, + )) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output[num_prefill_tokens:] = ( + self._dual_chunk_flash_attn_decoding( + decode_query.unsqueeze(1), + decode_query_succ.unsqueeze(1), + decode_query_inter.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + chunk_size=self.chunk_size, + local_size=self.local_size, + original_max_position_embeddings=self. + original_max_position_embeddings, + decode_meta=decode_meta, + ).squeeze(1)) + # Reshape the output tensor. 
+ return output.view(num_tokens, hidden_size) + + # def _complete_decode_metadata( + # self, + # attn_metadata: FlashAttentionMetadata, + # kv_cache: torch.Tensor, + # ): + # decode_meta = attn_metadata.decode_metadata + # if decode_meta is None or kv_cache.numel() == 0: + # return + # if hasattr(decode_meta, "mscale"): + # return + + # cache_seqlens = decode_meta.seq_lens_tensor + # block_table = decode_meta.block_tables + # chunk_len = self.chunk_size - self.local_size + # chunk_num_curr = (cache_seqlens - 1) // chunk_len + # batch_size = block_table.shape[0] + # block_size = kv_cache[1].shape[1] + + # if self.original_max_position_embeddings > 0: + # mscale = (0.1 * torch.log( + # cache_seqlens / self.original_max_position_embeddings) + + # 1.0).clip(min=1) + # seq_lens_intra = cache_seqlens - chunk_num_curr * chunk_len + # max_seq_len_intra = seq_lens_intra.max().item() + # block_table_intra = torch.zeros( + # batch_size, + # (max_seq_len_intra - 1) // block_size + 1, + # dtype=block_table.dtype, + # device=block_table.device, + # ) + # for i in range(batch_size): + # st = chunk_num_curr[i] * chunk_len // block_size + # ed = min( + # st + (max_seq_len_intra - 1) // block_size + 1, + # (cache_seqlens[i] - 1) // block_size + 1, + # ) + # block_table_intra[i, :ed - st] = block_table[i, st:ed] + # seq_lens_succ = (chunk_num_curr - + # (chunk_num_curr - 1).clip(min=0)) * chunk_len + # max_seq_len_succ = seq_lens_succ.max().item() + # if max_seq_len_succ: + # block_table_succ = torch.zeros( + # batch_size, + # (max_seq_len_succ - 1) // block_size + 1, + # dtype=block_table.dtype, + # device=block_table.device, + # ) + # for i in range(batch_size): + # st = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len // + # block_size) + # ed = min( + # st + (max_seq_len_succ - 1) // block_size + 1, + # (cache_seqlens[i] - 1) // block_size + 1, + # ) + # block_table_succ[i, :ed - st] = block_table[i, st:ed] + # seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + # max_seq_len_inter = seq_lens_inter.max().item() + + # if self.original_max_position_embeddings > 0: + # decode_meta.mscale = mscale + # decode_meta.seq_lens_intra = seq_lens_intra + # decode_meta.block_table_intra = block_table_intra + # decode_meta.seq_lens_succ = seq_lens_succ + # decode_meta.max_seq_len_succ = max_seq_len_succ + # if max_seq_len_succ: + # decode_meta.block_table_succ = block_table_succ + # decode_meta.seq_lens_inter = seq_lens_inter + # decode_meta.max_seq_len_inter = max_seq_len_inter + + def _dual_chunk_flash_attn_prefill( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + orig_seq_lens: List[int], + scaling_factor: torch.Tensor, + softmax_scale: float, + causal: Optional[bool] = True, + window_size: Tuple[int, int] = (-1, -1), + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + chunk_size: int = 8192, + local_size: int = 1024, + ): + if alibi_slopes is not None: + raise ValueError( + "Dual Chunk Attention does not support alibi_slopes") + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + if window_size != (-1, -1): + raise ValueError( + "Dual Chunk Attention does not support window_size") + + cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() + cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() + all_outputs = [] + + for i in range(0, len(cu_seqlens_q_cpu) - 1): + qs = cu_seqlens_q_cpu[i] + qe = cu_seqlens_q_cpu[i:i + 2][-1] + ks = cu_seqlens_k_cpu[i] + ke = 
cu_seqlens_k_cpu[i:i + 2][-1] + + current_q = q[qs:qe] + current_q_succ = q_succ[qs:qe] + current_q_inter = q_inter[qs:qe] + current_q_succ_critical = q_succ_critical[qs:qe] + current_q_inter_critical = q_inter_critical[qs:qe] + + if block_table is None: + current_k = k[ks:ke] + current_v = v[ks:ke] + current_block_table = None + current_orig_seq_lens = orig_seq_lens[i:i + 1] + else: + current_block_table = block_table[i:i + 1] + current_orig_seq_lens = orig_seq_lens[i:i + 1] + current_k = k + current_v = v + sparse_attn_enabled = ( + self.sparse_attention_enabled + and current_orig_seq_lens[0] > self.sparse_attention_threshold) + + if current_q.shape[0] == 0: + continue + + if current_k.shape[0] == 0: + all_outputs.append( + torch.zeros( + (current_q.shape[0], current_q.shape[1], v.shape[2]), + device=q.device, + dtype=q.dtype, + )) + continue + + current_output = torch.empty_like(current_q) + group_size = int(current_q.size(-2) / current_k.size(-2)) + + if sparse_attn_enabled: + num_device_q_heads = current_q.size(-2) + heads_vertical_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + heads_slash_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + for head_id in range(current_q.size(-2)): + ( + ty, + vertical_size, + slash_size, + _, + ) = self.sparse_attention_config[head_id] + assert ty == "vertical_and_slash", "only support slash mode" + + if vertical_size == 30: + vertical_size += 100 + heads_vertical_size[head_id] = vertical_size + heads_slash_size[head_id] = slash_size + + current_output = self._dual_chunk_flash_attn_prefill_func( + current_q, # allheads + current_q_succ, + current_q_inter, + current_q_succ_critical, + current_q_inter_critical, + current_k, + current_v, + current_block_table, + softmax_scale, + chunk_size, + local_size, + current_orig_seq_lens, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + heads_vertical_size=heads_vertical_size, + heads_slash_size=heads_slash_size, + group_size=group_size) + else: + for head_id in range(current_q.size(-2)): + # (seq_len, num_heads, head_size) + current_q_head = current_q[:, head_id, :].unsqueeze(1) + current_q_succ_head = \ + current_q_succ[:, head_id, :].unsqueeze(1) + current_q_inter_head = \ + current_q_inter[:, head_id, :].unsqueeze(1) + current_q_succ_head_critical = \ + current_q_succ_critical[:, head_id, :].unsqueeze(1) + current_q_inter_head_critical = \ + current_q_inter_critical[:, head_id, :].unsqueeze(1) + if block_table is not None: + current_k_head = current_k[..., head_id // + group_size, :].unsqueeze(2) + current_v_head = current_v[..., head_id // + group_size, :].unsqueeze(2) + + else: + current_k_head = current_k[:, head_id, :].unsqueeze(1) + current_v_head = current_v[:, head_id, :].unsqueeze(1) + + current_out = self._dual_chunk_flash_attn_prefill_func( + current_q_head, + current_q_succ_head, + current_q_inter_head, + current_q_succ_head_critical, + current_q_inter_head_critical, + current_k_head, + current_v_head, + current_block_table, + softmax_scale, + chunk_size, + local_size, + current_orig_seq_lens, + scaling_factor.item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + ) + current_output[:, head_id:head_id + 1, :] = current_out + all_outputs.append(current_output) + return torch.cat(all_outputs, dim=0) + + def _dual_chunk_flash_attn_prefill_func( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + block_table, + softmax_scale: float, + chunk_size: int, + local_size: int, + 
current_orig_seq_lens: List[int], + scaling_factor: float, + k_length: int, + sparse_attn_enabled: Optional[bool] = True, + heads_vertical_size=None, + heads_slash_size=None, + group_size=None, + ): + flash_results = [] + chunk_len = chunk_size - local_size + + if block_table is not None: + block_size = v.shape[1] + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + else: + block_size = 1 + + if self.original_max_position_embeddings > 0: + softmax_scale = softmax_scale * scaling_factor + + begin = k_length - q.shape[0] + while begin < k_length: + flash_per_chunk = [] + + prev_chunk_end_pos = (begin // chunk_len) * chunk_len + next_chunk_end_pos = prev_chunk_end_pos + chunk_len + end = min(next_chunk_end_pos, k_length) + qbegin = begin - (k_length - q.shape[0]) + qend = end - (k_length - q.shape[0]) + + qk_chunks = [] + q_states_intra = q[qbegin:qend] + # choose critical token + if block_table is not None: + block_tables_intra = _get_block(block_table, block_size, + prev_chunk_end_pos, end) + k_states_intra = k[block_tables_intra[0]].view( + -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] + v_states_intra = v[block_tables_intra[0]].view( + -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] + else: + block_tables_intra = None + k_states_intra = k[prev_chunk_end_pos:end] + v_states_intra = v[prev_chunk_end_pos:end] + + if sparse_attn_enabled: + last_q_size = min(qend - qbegin, self.sparse_attention_last_q) + _, num_device_k_heads, head_dim = k_states_intra.shape + k_states_intra = (k_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + v_states_intra = (v_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + qk_chunks.append( + (q_states_intra.transpose(0, 1)[:, -last_q_size:] * + softmax_scale) @ k_states_intra.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len >= 0: + q_states_succ = q_succ[qbegin:qend] + q_states_succ_critical = q_succ_critical[qbegin:qend] + if block_table is not None: + block_tables_succ = _get_block( + block_table, block_size, + prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) + k_states_succ = k[block_tables_succ[0]].view( + -1, *k.shape[-2:])[:chunk_len] + v_states_succ = v[block_tables_succ[0]].view( + -1, *v.shape[-2:])[:chunk_len] + else: + k_states_succ = k[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + v_states_succ = v[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + + if sparse_attn_enabled: + k_states_succ = (k_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_succ = (v_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_succ_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_succ.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + q_states_inter = q_inter[qbegin:qend] + q_states_inter_critical = q_inter_critical[qbegin:qend] + if block_table is not None: + block_tables_inter = _get_block( + block_table, block_size, 0, + prev_chunk_end_pos - chunk_len) + k_states_inter = k[block_tables_inter[0]].view( + -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + v_states_inter = v[block_tables_inter[0]].view( + -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + else: + k_states_inter = k[:prev_chunk_end_pos - chunk_len] + v_states_inter = v[:prev_chunk_end_pos - 
chunk_len] + + if sparse_attn_enabled: + k_states_inter = (k_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_inter = (v_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_inter_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_inter.permute(1, 2, 0)) + + if sparse_attn_enabled: + reversed_qk = qk_chunks[::-1] + qk = torch.cat(reversed_qk, dim=-1) + + qk[:, :, -last_q_size:] = torch.where( + self.last_q_mask[..., -last_q_size:, + -last_q_size:].to(qk.device), + qk[:, :, -last_q_size:], -torch.inf) + qk = F.softmax(qk, dim=-1, dtype=torch.float32) + + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + + int32_max = 2147483647 # avoid sort + int32_min = -2147483648 + n_heads = qk.size()[0] + max_slash_topk = torch.max(heads_slash_size).item() + max_vertical_topk = torch.max(heads_vertical_size).item() + # store each head's slash topk, vertical topk + vertical = vertical.reshape((n_heads, -1)) + # prevent out of range when prompt size < max_vertical_topk + max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) + vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, + -1).indices + slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), + dtype=torch.int64, + device=qk.device) + for head_i in range(n_heads): + # (nqheads=1, lastq, k_len) + head_score = qk[head_i:head_i + 1, :, :] + slash_scores = _sum_all_diagonal_matrix(head_score) + if head_score.size(1) != 1: + # drop right up corner + slash_scores = slash_scores[..., :-last_q_size + 1] + slash_scores[..., -100:] = torch.inf + + head_slash_size = heads_slash_size[head_i] + head_slash_size = min(head_slash_size, vertical.size(-1)) + slash_topk = torch.topk(slash_scores, head_slash_size, + -1).indices + #(nheads, max_topk) + slash_topk_buffer[head_i, :head_slash_size] = slash_topk + + # reset heads topk + heads_slash_size[head_i] = head_slash_size + heads_vertical_size[head_i] = min( + heads_vertical_size[head_i], max_vertical_topk) + + # store + vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + succ_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + inter_vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + inter_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + + vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + + for head_i in range(n_heads): + vertical_topk = vertical_topk_buffer[ + head_i, 
:heads_vertical_size[head_i]] + # intra + intra_vertical_indices = vertical_topk[ + vertical_topk >= + prev_chunk_end_pos] - prev_chunk_end_pos + if intra_vertical_indices.nelement() == 0: + intra_vertical_indices = torch.cat([ + intra_vertical_indices, + torch.arange(0, + k_states_intra.size(0), + max(1, + k_states_intra.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + slash_topk = slash_topk_buffer[ + head_i, :heads_slash_size[head_i]] + intra_slash_indices = ( + (qk.size(-1) - 1) - + slash_topk[slash_topk >= prev_chunk_end_pos]) + # fill buffer + v_count = intra_vertical_indices.nelement() + s_count = intra_slash_indices.nelement() + vertical_size_buffer[head_i] = v_count + slash_sizes_buffer[head_i] = s_count + vertical_buffer[head_i, :v_count].copy_( + intra_vertical_indices) + slash_buffer[head_i, :s_count].copy_(intra_slash_indices) + # succ + if prev_chunk_end_pos - chunk_len >= 0: + succ_vertical_indices = vertical_topk[ + (vertical_topk < prev_chunk_end_pos) + & (vertical_topk >= prev_chunk_end_pos - + chunk_len)] - (prev_chunk_end_pos - chunk_len) + # TODO: support no vertical + if succ_vertical_indices.nelement() == 0: + succ_vertical_indices = torch.cat([ + succ_vertical_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + succ_slash_indices = ( + (prev_chunk_end_pos + (qend - qbegin) - 1) - + slash_topk[((slash_topk >= + (prev_chunk_end_pos - chunk_len)) & + (slash_topk < (prev_chunk_end_pos + + (qend - qbegin))))]) + if succ_slash_indices.nelement() == 0: + succ_slash_indices = torch.cat([ + succ_slash_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = succ_vertical_indices.nelement() + s_count = succ_slash_indices.nelement() + succ_vertical_size_buffer[head_i] = v_count + succ_slash_sizes_buffer[head_i] = s_count + succ_vertical_buffer[head_i, :v_count].copy_( + succ_vertical_indices) + succ_slash_buffer[head_i, :s_count].copy_( + succ_slash_indices) + + if prev_chunk_end_pos - 2 * chunk_len >= 0: + inter_vertical_indices = vertical_topk[ + vertical_topk < prev_chunk_end_pos - chunk_len] + + if inter_vertical_indices.nelement() == 0: + inter_vertical_indices = torch.cat([ + inter_vertical_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + inter_slash_indices = ( + (prev_chunk_end_pos - chunk_len + + (qend - qbegin) - 1) - + slash_topk[slash_topk < + (prev_chunk_end_pos - chunk_len + + (qend - qbegin))]) + if inter_slash_indices.nelement() == 0: + inter_slash_indices = torch.cat([ + inter_slash_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = inter_vertical_indices.nelement() + s_count = inter_slash_indices.nelement() + inter_vertical_size_buffer[head_i] = v_count + inter_slash_sizes_buffer[head_i] = s_count + inter_vertical_buffer[head_i, :v_count].copy_( + inter_vertical_indices) + inter_slash_buffer[head_i, :s_count].copy_( + inter_slash_indices) + else: + intra_vertical_indices, intra_slash_indices = None, None + succ_vertical_indices, succ_slash_indices = None, None + inter_vertical_indices, inter_slash_indices = None, None + + if 
sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + stage="intra", + vertical_indices=vertical_buffer, + slash_indices=slash_buffer, + vertical_indices_count=vertical_size_buffer, + slash_indices_count=slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + stage="intra", + vertical_indices=intra_vertical_indices, + slash_indices=intra_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + stage="succ", + vertical_indices=succ_vertical_buffer, + slash_indices=succ_slash_buffer, + vertical_indices_count=succ_vertical_size_buffer, + slash_indices_count=succ_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + stage="succ", + vertical_indices=succ_vertical_indices, + slash_indices=succ_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + stage="inter", + vertical_indices=inter_vertical_buffer, + slash_indices=inter_slash_buffer, + vertical_indices_count=inter_vertical_size_buffer, + slash_indices_count=inter_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + stage="inter", + vertical_indices=inter_vertical_indices, + slash_indices=inter_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + flash_results.append(flash_per_chunk) + begin = end + + attn_output = self._merge_attn_outputs(flash_results) + del flash_results + return attn_output + + def _do_flash_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + softmax_scale: float, + causal: bool = True, + block_table: torch.Tensor = None, + max_seqlen_k: Optional[int] = None, + stage: str = "intra", + vertical_indices: Optional[torch.Tensor] = None, + slash_indices: Optional[torch.Tensor] = None, + vertical_indices_count: Optional[torch.Tensor] = None, + slash_indices_count: Optional[torch.Tensor] = None, + mergehead_softmax_scale: Optional[float] = None, + sparse_attn_enabled: Optional[bool] = False, + ): + if max_seqlen_k is None: + max_seqlen_k = key_states.shape[0] + + q_len = query_states.shape[0] + q_heads = query_states.shape[1] + h_dim = query_states.shape[-1] + + if sparse_attn_enabled: + assert slash_indices is not None + if stage == "intra": + assert causal + else: + assert not causal + + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = 
value_states.unsqueeze(0).transpose(1, 2) + + q = query_states + k = key_states + v = value_states + + if (vertical_indices_count is not None and \ + slash_indices_count is not None): + assert mergehead_softmax_scale is not None + + res, s_lse = _vertical_slash_sparse_attention( + q, + k, + v, + vertical_indices, + slash_indices, + mergehead_softmax_scale, + causal=causal, + stage=stage, + vertical_indices_count=vertical_indices_count, + slash_indices_count=slash_indices_count) + res = res.view(q_heads, q_len, + h_dim).transpose(0, 1) # (qlen,nhead,h_dim) + s_lse = s_lse.view( + q_heads, q_len, + 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) + else: + res, s_lse = _vertical_slash_sparse_attention(q, + k, + v, + vertical_indices, + slash_indices, + softmax_scale, + causal=causal, + stage=stage) + res = res.view(q_len, q_heads, h_dim) + s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() + return res, s_lse + + output, softmax_lse = flash_attn_varlen_func( + q=query_states.bfloat16(), + k=key_states.bfloat16(), + v=value_states.bfloat16(), + softmax_scale=softmax_scale, + cu_seqlens_q=torch.tensor([0, query_states.shape[0]], + dtype=torch.int32, + device=query_states.device), + max_seqlen_q=query_states.shape[0], + cu_seqlens_k=torch.tensor([0, max_seqlen_k], + dtype=torch.int32, + device=query_states.device), + max_seqlen_k=max_seqlen_k, + causal=causal, + block_table=block_table, + return_softmax_lse=True, + ) + softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, + 2).float() + return output, softmax_lse + + def _merge_attn_outputs( + self, + flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], + return_lse: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs_all = [] + logits_all = [] + + for flash_per_chunk in flash_results: + if len(flash_per_chunk) == 1: + attn_outputs_all.append(flash_per_chunk[0][0]) + if return_lse: + logits_all.append(flash_per_chunk[0][1]) + continue + + attn_outputs = torch.stack([ + flash_attn_output[0] for flash_attn_output in flash_per_chunk + ]) + logits = torch.stack([ + flash_attn_output[1] for flash_attn_output in flash_per_chunk + ]) + logits = logits.to(torch.float32) + + if return_lse: + max_val = torch.max(logits, dim=0).values + diff = torch.abs(logits[0] - logits[1]) + log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) + logits_all.append(log_sum_exp) + + max_logits = torch.max(logits, dim=0).values + stable_logits = logits - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) + attn_outputs_all.append(attn_outputs.sum(dim=0)) + + if return_lse: + return (torch.cat(attn_outputs_all, + dim=0), torch.cat(logits_all, dim=-1)) + else: + return torch.cat(attn_outputs_all, dim=0) + + def _dual_chunk_flash_attn_decoding( + self, + query: torch.Tensor, + query_succ: torch.Tensor, + query_inter: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + causal: bool, + alibi_slopes: Optional[torch.Tensor], + chunk_size: int, + local_size: int, + original_max_position_embeddings: int, + decode_meta: DualChunkFlashAttentionMetadata, + ): + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + + block_size = value_cache.shape[1] + chunk_len = chunk_size - local_size + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible 
by block_size.") + if original_max_position_embeddings > 0: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + query = (query * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype + ) # possible for numerical issue, need to fused in the kernel + query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + outputs_list = [] + softmax_lses_list = [] + + # intra-attention + intra_output, intra_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query, + key_cache, + value_cache, + decode_meta.block_tables_intra, + decode_meta.seq_lens_intra, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(intra_output) + softmax_lses_list.append(intra_softmax_lse) + + # succ-attention + if decode_meta.max_seq_len_succ: + succ_output, succ_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_succ, + key_cache, + value_cache, + decode_meta.block_tables_succ, + decode_meta.seq_lens_succ, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(succ_output) + softmax_lses_list.append(succ_softmax_lse) + + # inter-attention + if decode_meta.max_seq_len_inter: + inter_output, inter_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_inter, + key_cache, + value_cache, + block_table[:, :decode_meta.max_seq_len_inter], + decode_meta.seq_lens_inter, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(inter_output) + softmax_lses_list.append(inter_softmax_lse) + outputs = torch.stack(outputs_list, dim=0) + del outputs_list + softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) + del softmax_lses_list + max_logits = torch.max(softmax_lses, dim=0).values + stable_logits = softmax_lses - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + outputs *= lse_s.unsqueeze(-1).transpose(2, 3) + return outputs.sum(0) + + def _dual_chunk_flash_attn_decoding_with_exp_sums( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + ): + out, softmax_lse = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + softmax_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + return_softmax_lse=True, + ) + mask = (cache_seqlens == 0) + out[mask] = 0 + softmax_lse[mask] = -float("inf") + return out, softmax_lse + + +def _vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + softmax_scale: float, + causal: bool = True, + stage: str = "intra", + block_size_M: int = 64, + block_size_N: int = 64, + vertical_indices_count: torch.Tensor = None, # [N_HEADS,] + slash_indices_count: torch.Tensor = None, +): + if stage == "intra": + assert causal + else: + assert not causal + + batch_size, num_heads, context_size, head_dim = query.shape + _, _, kv_seq_len, _ = key.shape + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 
2**math.ceil(math.log2(head_dim)) - head_dim + query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = v_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + q_seqlens = torch.tensor([context_size], + dtype=torch.int32, + device=query.device) + kv_seqlens = torch.tensor([kv_seq_len], + dtype=torch.int32, + device=query.device) + + if vertical_indices_count is not None and slash_indices_count is not None: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes_mergehead( + q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, + causal) + else: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, + s_idx, context_size, + block_size_M, block_size_N, + causal) + + q = query.transpose(1, 2).contiguous() + k = key.transpose(1, 2).contiguous() + v = value.transpose(1, 2).contiguous() + out, lse = sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + causal=causal, + softmax_scale=softmax_scale, + return_softmax_lse=True, + ) + out = out.transpose(1, 2).contiguous() + softmax_lse = lse.reshape(*lse.shape, 1) + return (out[..., :context_size, :head_dim], + softmax_lse[..., :context_size, :]) + + +def _sum_all_diagonal_matrix(mat: torch.tensor): + h, n, m = mat.shape + # Zero matrix used for padding + zero_mat = torch.zeros((h, n, n), device=mat.device) + # pads the matrix on left and right + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) + # Change the strides + mat_strided = mat_padded.as_strided((1, n, n + m), + (n * (2 * n + m), 2 * n + m + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 1) + return sum_diags[:, 1:] # drop left bottom corner + + +def _get_block(block_table: torch.Tensor, block_size: int, begin: int, + end: int): + begin_block = begin // block_size + end_block = (end - 1) // block_size + 1 + return block_table[:, begin_block:end_block] diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 23ea244f07dfe..8982a97fd029e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -103,6 +103,11 @@ class FlashAttentionMetadata(AttentionMetadata): # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The orig sequence length per sequence. + orig_seq_lens: Optional[List[int]] + # orig_seq_lens stored as a tensor. + orig_seq_lens_tensor: Optional[torch.Tensor] + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
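# --- Illustrative sketch (not part of the patch) ---------------------------
# The `_sum_all_diagonal_matrix` helper added above scores "slash" patterns by
# summing attention scores along each diagonal (constant key_pos - query_pos)
# of the last-queries x keys block, using zero padding plus an `as_strided`
# view instead of an explicit loop. A minimal check of that equivalence on a
# tiny made-up score block (toy sizes, assumed shapes only):
import torch

h, n, m = 1, 3, 5  # heads, last-query rows, key columns (toy sizes)
mat = torch.arange(h * n * m, dtype=torch.float32).reshape(h, n, m)

# Reference: explicit per-offset diagonal sums, offsets -(n-1) .. m-1.
explicit = torch.stack(
    [mat[0].diagonal(offset=o).sum() for o in range(-(n - 1), m)])

# Strided version, mirroring the helper: pad n zero columns on both sides,
# then re-stride so each output column gathers one diagonal before summing
# over the row dimension.
zero = torch.zeros((h, n, n))
padded = torch.cat((zero, mat, zero), dim=-1)
strided = padded.as_strided((h, n, n + m),
                            (n * (2 * n + m), 2 * n + m + 1, 1))
diag_sums = strided.sum(dim=1)[:, 1:]  # drop the all-padding first column

assert torch.allclose(diag_sums[0], explicit)
# ---------------------------------------------------------------------------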
# |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -210,6 +215,10 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: self.seq_lens[:self.num_prefills]) seq_lens_tensor = (None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills]) + orig_seq_lens = (None if self.orig_seq_lens is None else + self.orig_seq_lens[:self.num_prefills]) + orig_seq_lens_tensor = (None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[:self.num_prefills]) seq_start_loc = (None if self.seq_start_loc is None else self.seq_start_loc[:self.num_prefills + 1]) context_lens_tensor = (None if self.context_lens_tensor is None else @@ -226,6 +235,8 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: multi_modal_placeholder_index_maps, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + orig_seq_lens=orig_seq_lens, + orig_seq_lens_tensor=orig_seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_query_len=0, @@ -259,6 +270,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: self.slot_mapping[self.num_prefill_tokens:]) seq_lens_tensor = (None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:]) + orig_seq_lens_tensor = (None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) @@ -270,6 +283,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=seq_lens_tensor, + orig_seq_lens=None, + orig_seq_lens_tensor=orig_seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=self.max_query_len, max_prefill_seq_len=0, @@ -372,6 +387,7 @@ class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.orig_seq_lens: List[int] = [] self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] @@ -422,6 +438,7 @@ def _add_seq_group( else: self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) + self.orig_seq_lens.append(seq_len) # Compute block table. 
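# --- Illustrative sketch (not part of the patch) ---------------------------
# Every attention path earlier in this patch (_do_flash_attn,
# _vertical_slash_sparse_attention, the decode helpers) returns a log-sum-exp
# alongside its output so that _merge_attn_outputs can combine per-chunk
# partial results exactly. A small self-contained check of that identity,
# with made-up tensors:
import torch

torch.manual_seed(0)
d = 16
q = torch.randn(1, d)
k1, v1 = torch.randn(8, d), torch.randn(8, d)  # keys/values of one chunk
k2, v2 = torch.randn(8, d), torch.randn(8, d)  # keys/values of another chunk
scale = d ** -0.5

def attn_with_lse(q, k, v):
    # Attention over one key block, returning the output and its log-sum-exp.
    s = (q @ k.t()) * scale
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

o1, lse1 = attn_with_lse(q, k1, v1)
o2, lse2 = attn_with_lse(q, k2, v2)

# Weight each partial output by softmax over the LSEs (the max-subtraction in
# the patch is the numerically stable form of the same softmax).
w = torch.softmax(torch.stack([lse1, lse2]), dim=0)
merged = w[0] * o1 + w[1] * o2

# Reference: ordinary attention over the concatenated keys/values.
full, _ = attn_with_lse(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(merged, full, atol=1e-5)
# ---------------------------------------------------------------------------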
# TODO(sang): Combine chunked prefill and prefix caching by @@ -447,7 +464,7 @@ def _add_seq_group( context_len, self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, + curr_seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) def _get_graph_runner_block_tables( @@ -530,6 +547,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) + orig_seq_lens_tensor = async_tensor_h2d(self.orig_seq_lens, torch.int, + device, self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, @@ -549,8 +568,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + orig_seq_lens=self.orig_seq_lens, multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens_tensor=seq_lens_tensor, + orig_seq_lens_tensor=orig_seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 56cc43430301f..561a71ccc1e17 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -318,6 +318,7 @@ def graph_capture_get_metadata_for_batch( multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], + orig_seq_lens=None, max_query_len=1, max_decode_query_len=1, max_prefill_seq_len=0, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f1b3598e60b54..78c24fd06b6ef 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,5 +1,5 @@ """Attention layer.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -42,6 +42,7 @@ def __init__( per_layer_sliding_window: Optional[int] = None, prefix: str = "", attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -95,13 +96,18 @@ def __init__( block_size, is_attention_free, blocksparse_params is not None) impl_cls = attn_backend.get_impl_cls() - self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type) + self.impl = impl_cls( + num_heads, head_size, scale, num_kv_heads, alibi_slopes, + sliding_window, kv_cache_dtype, blocksparse_params, + logits_soft_cap, attn_type, **{ + "dual_chunk_attention_config": dual_chunk_attention_config, + "prefix": prefix, + } if dual_chunk_attention_config is not None else {}) self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dual_chunk_attention_config = dual_chunk_attention_config # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # torch.compile works by registering the attention as one giant @@ -129,12 +135,26 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + query_succ_and_inter: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor, + torch.Tensor]] = None, ) -> 
torch.Tensor: + if self.dual_chunk_attention_config: + assert query_succ_and_inter is not None + dca_kwargs = { + "query_succ": query_succ_and_inter[0], + "query_inter": query_succ_and_inter[1], + "query_succ_critical": query_succ_and_inter[2], + "query_inter_critical": query_succ_and_inter[3], + } if query_succ_and_inter else {} + else: + dca_kwargs = {} + if self.use_direct_call: return self.impl.forward(query, key, value, kv_cache, attn_metadata, self._k_scale, - self._v_scale) + self._v_scale, **dca_kwargs) elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -148,11 +168,13 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) torch.ops.vllm.unified_attention_with_output( - query, key, value, output, kv_cache, self.layer_name) + query, key, value, output, kv_cache, self.layer_name, + **dca_kwargs) return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, self.layer_name) + kv_cache, self.layer_name, + **dca_kwargs) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -232,12 +254,22 @@ def unified_attention( value: torch.Tensor, kv_cache: torch.Tensor, layer_name: str, + query_succ: torch.Tensor = None, + query_inter: torch.Tensor = None, + query_succ_critical: torch.Tensor = None, + query_inter_critical: torch.Tensor = None, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.dynamic_forward_context self = forward_context.static_forward_context[layer_name] + dca_kwargs = { + "query_succ": query_succ, + "query_inter": query_inter, + "query_succ_critical": query_succ_critical, + "query_inter_critical": query_inter_critical, + } if self.dual_chunk_attention_config else {} return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._k_scale, self._v_scale) + self._k_scale, self._v_scale, **dca_kwargs) def unified_attention_fake( @@ -246,6 +278,10 @@ def unified_attention_fake( value: torch.Tensor, kv_cache: torch.Tensor, layer_name: str, + query_succ: torch.Tensor = None, + query_inter: torch.Tensor = None, + query_succ_critical: torch.Tensor = None, + query_inter_critical: torch.Tensor = None, ) -> torch.Tensor: return torch.empty_like(query).contiguous() @@ -266,10 +302,20 @@ def unified_attention_with_output( output: torch.Tensor, kv_cache: torch.Tensor, layer_name: str, + query_succ: torch.Tensor = None, + query_inter: torch.Tensor = None, + query_succ_critical: torch.Tensor = None, + query_inter_critical: torch.Tensor = None, ) -> None: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.dynamic_forward_context self = forward_context.static_forward_context[layer_name] + dca_kwargs = { + "query_succ": query_succ, + "query_inter": query_inter, + "query_succ_critical": query_succ_critical, + "query_inter_critical": query_inter_critical, + } if self.dual_chunk_attention_config else {} self.impl.forward(query, key, value, @@ -277,7 +323,8 @@ def unified_attention_with_output( attn_metadata, self._k_scale, self._v_scale, - output=output) + output=output, + **dca_kwargs) def unified_attention_with_output_fake( @@ -287,6 +334,10 @@ def unified_attention_with_output_fake( output: torch.Tensor, kv_cache: torch.Tensor, layer_name: str, + query_succ: torch.Tensor = None, + query_inter: torch.Tensor = None, + query_succ_critical: torch.Tensor = None, + query_inter_critical: torch.Tensor = None, ) -> None: return diff --git 
a/vllm/attention/selector.py b/vllm/attention/selector.py index d263839705690..e529447924298 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -125,6 +125,10 @@ def _cached_get_attn_backend( from vllm.v1.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend as FlashAttentionBackendV1) return FlashAttentionBackendV1 + if backend == _Backend.DUAL_CHUNK_FLASH_ATTN: + from vllm.attention.backends.dual_chunk_flash_attn import ( # noqa: F401 + DualChunkFlashAttentionBackend) + return DualChunkFlashAttentionBackend if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 diff --git a/vllm/config.py b/vllm/config.py index 8b824a1fca511..160cc58c09c0c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -632,6 +632,23 @@ def _verify_bnb_config(self) -> None: "fallback to the eager mode.") self.enforce_eager = True + def verify_dual_chunk_attention_config( + self, + load_config: "LoadConfig", + ) -> None: + if hasattr(self.hf_config, "dual_chunk_attention_config"): + # Try loading the sparse attention config + from vllm.model_executor.model_loader.weight_utils import ( + get_sparse_attention_config) + sparse_attn_config = get_sparse_attention_config(self, load_config) + if sparse_attn_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_config"] = sparse_attn_config + if "sparse_attention_enabled" not in \ + self.hf_config.dual_chunk_attention_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled"] = True + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -3133,6 +3150,8 @@ def __post_init__(self): self.speculative_config, self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_dual_chunk_attention_config( + self.load_config) if self.cache_config is not None: self.cache_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 3fcd81a3c4213..13942ad8059cd 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -931,6 +931,175 @@ def get_next_input_positions( ] +@CustomOp.register("dual_chunk_rotary_embedding") +class DualChunkRotaryEmbedding(CustomOp): + """Rotary positional embedding for Dual Chunk Attention.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + chunk_size: int, + local_size: int, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.chunk_size = chunk_size + self.local_size = local_size + self.dtype = dtype + self.device = torch.device(f"cuda:{torch.cuda.current_device()}") + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, + q_inter_cache) = self._compute_cos_sin_cache() + + self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) + self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) + self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) + self.register_buffer("cos_sin_qc_no_clamp_cache", + qc_no_clamp_cache, + persistent=False) + self.register_buffer("cos_sin_q_inter_cache", + q_inter_cache, + 
persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + chunk_len = self.chunk_size - self.local_size + q_t = torch.arange(chunk_len, dtype=torch.float) + qc_t = (torch.arange(chunk_len, dtype=torch.float) + + chunk_len).clamp(max=self.chunk_size) + k_t = torch.arange(self.max_position_embeddings, + dtype=torch.float) % chunk_len + + # count from chunk_len, no clamp(self.chunk_size) restriction + qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len + # count from self.chunk_size for q_inter's rope + q_inter_t = torch.arange(chunk_len, + dtype=torch.float) + self.chunk_size + + q_freqs = torch.einsum("i,j -> ij", q_t, inv_freq) + qc_freqs = torch.einsum("i,j -> ij", qc_t, inv_freq) + k_freqs = torch.einsum("i,j -> ij", k_t, inv_freq) + qc_no_clamp_freqs = torch.einsum("i,j -> ij", qc_no_clamp_t, inv_freq) + q_inter_freqs = torch.einsum("i,j -> ij", q_inter_t, inv_freq) + q_cos = q_freqs.cos() + q_sin = q_freqs.sin() + qc_cos = qc_freqs.cos() + qc_sin = qc_freqs.sin() + k_cos = k_freqs.cos() + k_sin = k_freqs.sin() + qc_no_clamp_cos = qc_no_clamp_freqs.cos() + qc_no_clamp_sin = qc_no_clamp_freqs.sin() + q_inter_cos = q_inter_freqs.cos() + q_inter_sin = q_inter_freqs.sin() + q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + else: + query_pass = None + key_pass = None + + positions_with_offsets = (torch.add(positions, offsets) + if offsets is not None else positions) + key = self._apply_rotary_embedding( + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) + chunk_len = self.chunk_size - 
self.local_size + query = self._apply_rotary_embedding( + self.cos_sin_q_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_succ = self._apply_rotary_embedding( + self.cos_sin_qc_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter = self._apply_rotary_embedding( + self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), + query_rot, query_pass) + query_succ_critical = self._apply_rotary_embedding( + self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter_critical = self._apply_rotary_embedding( + self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + + return (query, key, query_succ, query_inter, query_succ_critical, + query_inter_critical) + + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin + + if self.rotary_dim < self.head_size: + hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) + else: + hidden = hidden_rot + hidden = hidden.flatten(-2) + return hidden.squeeze(0) + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" + return s + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -943,6 +1112,7 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() @@ -955,14 +1125,35 @@ def get_rope( rope_scaling_args = tuple(rope_scaling_tuple.items()) else: rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + if partial_rotary_factor < 1.0: rotary_dim = int(rotary_dim * partial_rotary_factor) key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args, dtype) + rope_scaling_args, dual_chunk_attention_args, dtype) if key in _ROPE_DICT: return _ROPE_DICT[key] - if rope_scaling is None: + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + **extra_kwargs) + elif rope_scaling is None: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype) else: diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 8aa0c98df70d2..9b5731c7e7002 100644 --- 
a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -212,6 +212,39 @@ def get_quant_config(model_config: ModelConfig, return quant_cls.from_config(config) +def get_sparse_attention_config( + model_config: ModelConfig, + load_config: LoadConfig, + sparse_attention_config_filename: str = "sparse_attention_config.json", +) -> Dict[str, Any]: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + config_file = os.path.join(hf_folder, sparse_attention_config_filename) + if not os.path.exists(config_file): + return {} + + # Load the sparse attention config. + with open(config_file) as f: + config = json.load(f) + logger.info("Loaded sparse attention config from %s", config_file) + + return config + + def download_weights_from_hf( model_name_or_path: str, cache_dir: Optional[str], diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 01745b5fd53e1..4c92074d0b752 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -98,17 +98,20 @@ def forward(self, x): class Qwen2Attention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[Dict[str, + Any]] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -130,6 +133,7 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -154,15 +158,20 @@ def __init__(self, max_position=max_position, base=self.rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=attn_type) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + 
num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type, + **{ + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}) def forward( self, @@ -173,8 +182,17 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + rotary_embedded = self.rotary_emb(positions, q, k) + q, k = rotary_embedded[:2] + query_succ_and_inter = rotary_embedded[2:] + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata=attn_metadata, + query_succ_and_inter=query_succ_and_inter) + output, _ = self.o_proj(attn_output) return output @@ -193,6 +211,9 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) # By default, Qwen2 uses causal attention as it is a decoder-only model. # You can override the HF config with `is_causal=False` to enable @@ -214,6 +235,7 @@ def __init__( rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn", attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index ba70243c6533d..6822edcd1f1f8 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -169,6 +169,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -192,6 +193,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -215,14 +217,19 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}) def forward( self, @@ -233,8 +240,17 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + + rotary_embedded = self.rotary_emb(positions, q, k) + q, k = rotary_embedded[:2] + query_succ_and_inter = rotary_embedded[2:] + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata=attn_metadata, + query_succ_and_inter=query_succ_and_inter) + output, _ = self.o_proj(attn_output) return output @@ -252,6 +268,9 @@ def __init__( 
self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Qwen2MoeAttention( @@ -264,6 +283,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, ) # Note: Qwen/Qwen2-57B-A14B-Instruct does not have diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index ddccaa2ce0148..e14c0523fdb02 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -25,6 +25,7 @@ def in_wsl() -> bool: class _Backend(enum.Enum): FLASH_ATTN = enum.auto() FLASH_ATTN_VLLM_V1 = enum.auto() + DUAL_CHUNK_FLASH_ATTN = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1c6d1bbee78ee..8e3f53d24446c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -490,6 +490,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # already computed) and sequence length (total number of tokens). seq_len = seq_data.get_len() + orig_seq_len = seq_data.get_len() if inter_data.is_prompt: context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) @@ -504,7 +505,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len - inter_data.orig_seq_lens[seq_idx] = seq_len + inter_data.orig_seq_lens[seq_idx] = orig_seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
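# --- Illustrative sketch (not part of the patch) ---------------------------
# The _compute_lens change above separates the per-step (possibly chunked)
# sequence length from the full prompt length. With made-up numbers for a
# chunked-prefill step:
full_prompt_len = 9000      # seq_data.get_len()
context_len = 4096          # tokens already computed in earlier steps
token_chunk_size = 4096     # tokens scheduled for this step

seq_len = min(full_prompt_len, context_len + token_chunk_size)  # -> 8192
orig_seq_len = full_prompt_len                                  # -> 9000

# Before this patch, orig_seq_lens[seq_idx] was set to the capped seq_len;
# now it keeps the full length, which the dual chunk attention metadata
# (orig_seq_lens / orig_seq_lens_tensor above) presumably needs in order to
# place the current tokens relative to the whole prompt rather than to the
# current prefill chunk.
# ---------------------------------------------------------------------------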