From a091e2da3e3fcb4c63c8206839d7240a2a2a176a Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 16 Sep 2024 17:47:19 +0200 Subject: [PATCH 001/116] [Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032) Co-authored-by: Dipika --- csrc/moe/marlin_moe_ops.cu | 537 +++++++++++++----- csrc/moe/marlin_moe_ops.h | 7 +- csrc/moe/torch_bindings.cpp | 8 +- tests/kernels/test_moe.py | 18 +- tests/weight_loading/models-large.txt | 3 +- .../run_model_weight_loading_test.sh | 0 vllm/_custom_ops.py | 2 +- .../layers/fused_moe/fused_marlin_moe.py | 44 +- .../layers/fused_moe/fused_moe.py | 2 +- .../compressed_tensors_moe.py | 8 +- .../layers/quantization/gptq_marlin.py | 1 + vllm/model_executor/model_loader/utils.py | 8 +- 12 files changed, 453 insertions(+), 185 deletions(-) mode change 100644 => 100755 tests/weight_loading/run_model_weight_loading_test.sh diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 92184f43c9eb0..666d87eb92595 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,8 @@ #include +#include "core/scalar_type.hpp" + template inline std::string str(T x) { return std::to_string(x); @@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) { return frag_b; } +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
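The fast 8-bit path above leans on fp16 bit-pattern arithmetic: a byte b placed in the low byte of the pattern 0x6400 reads back as the fp16 value 1024 + b, and subtracting the magic constant 0x6480 (1152.0, i.e. 1024 + 128) recovers b - 128, which matches the uint8b128 zero point of 128. A small standalone Python sketch of that arithmetic (illustrative only, not part of the patch; the packed test value is made up):

import struct

def fp16_from_bits(bits: int) -> float:
    # Reinterpret a 16-bit pattern as an IEEE half-precision value.
    return struct.unpack("<e", struct.pack("<H", bits))[0]

MAGIC = 0x6480  # fp16 bit pattern for 1152.0 == 1024 + 128

def dequant_u8b128(packed: int) -> list:
    # Unpack four uint8 values from one int32 and apply the bias trick:
    # OR-ing a byte b into the mantissa of 0x6400 yields exactly 1024.0 + b,
    # so subtracting 1152.0 gives b - 128 (the weight with zero point 128).
    out = []
    for i in range(4):
        b = (packed >> (8 * i)) & 0xFF
        biased = fp16_from_bits(0x6400 | b)          # exactly 1024.0 + b
        out.append(biased - fp16_from_bits(MAGIC))   # exactly b - 128
    return out

assert fp16_from_bits(0x6400) == 1024.0
assert fp16_from_bits(MAGIC) == 1152.0
assert dequant_u8b128(0xFF800001) == [-127.0, -128.0, 0.0, 127.0]

The CUDA code above does the same thing in registers: prmt pairs each quantized byte with the 0x64 exponent byte supplied by start_byte_for_fp16, and __hsub2 subtracts the two 1152.0 halves of I8s_TO_F16s_MAGIC_NUM in one instruction per half2.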
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -840,10 +902,19 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - FragB frag_b0 = dequant(b_quant); + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -855,8 +926,6 @@ __device__ inline void MarlinMoESingle( } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -881,13 +950,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1035,8 +1104,10 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1088,9 +1159,9 @@ __device__ inline void MarlinMoESingle( // Start global fetch and register load pipelines. 
auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - __syncthreads(); + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); #pragma unroll for (int i = 0; i < stages - 1; i++) { @@ -1166,28 +1237,70 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } } } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1227,7 +1340,8 @@ __device__ inline void MarlinMoESingle( } } -template 4) { + if (max_block > cfg_max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * max_block - pad) / 64; - par = min((16 * max_block - pad) / 64, max_par); - prob_m = 64 * par; - m_block_ctr += 4 * (par - 1); - max_block = 4; + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; } if (max_block == 1) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } 
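The comment above on 8-bit channelwise scaling explains why the scale is folded in while the accumulators are still fp32: with group_blocks == -1 the unscaled int8-range dot products can exceed fp16's maximum of 65504 before the (typically small) per-channel scale shrinks them, so converting to fp16 first would overflow. A toy numpy illustration of that ordering (not from the patch; the magnitudes are invented):

import numpy as np

K = 4096                                   # reduction length of one expert GEMM
acc_fp32 = np.float32(K * 100.0)           # fp32 partial sum of unscaled products
scale = np.float32(0.01)                   # small channelwise scale

naive = np.float16(acc_fp32) * np.float16(scale)   # cast first: 409600 -> inf
safe = np.float16(acc_fp32 * scale)                # scale in fp32 first: 4096.0

print(naive, safe)                         # inf 4096.0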
else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1342,7 +1457,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par); \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1423,6 +1543,11 @@ typedef struct { int num_threads; } thread_config_t; +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + thread_config_t small_batch_thread_configs[] = { // Ordered by priority @@ -1443,8 +1568,77 @@ thread_config_t large_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = ceildiv(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * STAGES; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = ceildiv(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * STAGES; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ 
-1472,64 +1666,88 @@ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, 
false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1537,26 +1755,42 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); } + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int num_bits = q_type.size_bits(); + // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks 
= ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1590,11 +1824,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, } } - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - int tot_m = prob_m; const int* topk_ids_ptr = (const int*)topk_ids; @@ -1611,10 +1840,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } + int pack_factor = 32 / q_type.size_bits(); + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1636,19 +1868,22 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, A_ptr = a_tmp_ptr; } - int max_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { - // Define kernel configurations - + int tot_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < tot_m_blocks; + m_block += 4 * exec_cfg.max_m_blocks) { // make it max possible value - int thread_m_blocks = 4; + int thread_m_blocks = exec_cfg.max_m_blocks; if (false) { } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) + CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1670,9 +1905,15 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + + int pack_factor = 32 / b_q_type->size_bits(); + int max_par = 4; int dev = a.get_device(); @@ -1733,8 +1974,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 43d264e0770d6..adee8399a4d6f 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,11 +2,14 @@ #include +#include "core/scalar_type.hpp" + torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8a0e625b43fa1..cd65a8ee92b94 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); + "g_idx, Tensor! perm, Tensor! 
workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2250cf1598b8b..8072cf09e5b65 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, @@ -148,6 +149,7 @@ def test_fused_marlin_moe( topk: int, group_size: int, act_order: bool, + num_bits: int, ): torch.manual_seed(7) @@ -161,13 +163,12 @@ def test_fused_marlin_moe( if group_size in (k, n): return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - for i in range(w2.shape[0]): - w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) w_ref1_l = [] qweight1_l = [] @@ -240,6 +241,7 @@ def test_fused_marlin_moe( topk_ids, w1_scale=scales1, w2_scale=scales2, + num_bits=num_bits, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -254,7 +256,8 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -def test_marlin_moe_mmm( +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_single_marlin_moe_multiply( m: int, n: int, k: int, @@ -262,6 +265,7 @@ def test_marlin_moe_mmm( topk: int, group_size: int, act_order: bool, + num_bits: int, ): if topk > e: return @@ -273,7 +277,8 @@ def test_marlin_moe_mmm( if group_size == k: return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 @@ -308,7 +313,8 @@ def test_marlin_moe_mmm( g_idx, sort_indices, topk, - renormalize=False) + renormalize=False, + num_bits=num_bits) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index fe76705746766..2f5c6c5a117f3 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -1,3 +1,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main -gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh old mode 100644 new mode 100755 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ed08878f14875..74b3b69606c67 100644 --- 
a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 200a6148978aa..866b18d725a8c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,18 +7,21 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) +from vllm.scalar_type import scalar_types def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + num_bits: int = 8, +) -> torch.Tensor: """ This function computes the multiplication of hidden_states with expert weights used in Marlin MoE, using weights w and top-k gating mechanism. @@ -36,6 +39,7 @@ def single_marlin_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -48,10 +52,11 @@ def single_marlin_moe( assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w.shape[0] - N = w.shape[2] // 2 + N = w.shape[2] // (num_bits // 2) topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) @@ -76,10 +81,13 @@ def single_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, - False) + g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, + block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -98,6 +106,7 @@ def fused_marlin_moe( override_config: Optional[Dict[str, Any]] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -122,6 +131,7 @@ def fused_marlin_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. 
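The shape changes above all follow from the pack factor 32 // num_bits: Marlin tiles K in groups of 16 rows, and each output column of a tile packs its 16 quantized values into 16 // (32 // num_bits) == num_bits // 2 int32 words, so the last dimension of the repacked expert weights grows from size_n * 2 for 4-bit to size_n * 4 for 8-bit. A standalone sketch of that arithmetic (illustrative only; the example dimensions are invented):

def marlin_moe_packed_shape(num_experts: int, size_k: int, size_n: int,
                            num_bits: int) -> tuple:
    # Shape produced by gptq_marlin_moe_repack for a stack of expert weights:
    # K is tiled in groups of 16 rows, and each output column of a tile stores
    # its 16 quantized values in num_bits // 2 int32 words, hence the widened
    # last dimension for 8-bit weights.
    assert num_bits in (4, 8)
    assert size_k % 16 == 0
    return (num_experts, size_k // 16, size_n * (num_bits // 2))

# 8 experts with K=4096, N=14336 (made-up sizes):
print(marlin_moe_packed_shape(8, 4096, 14336, 4))   # (8, 256, 28672)
print(marlin_moe_packed_shape(8, 4096, 14336, 8))   # (8, 256, 57344)

# single_marlin_moe / fused_marlin_moe invert the same relation to recover the
# logical N (and K for w2) from the packed last dimension:
for num_bits in (4, 8):
    packed = marlin_moe_packed_shape(8, 4096, 14336, num_bits)
    assert packed[2] // (num_bits // 2) == 14336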
@@ -131,13 +141,14 @@ def fused_marlin_moe( 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w1.shape[0] @@ -165,6 +176,9 @@ def fused_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, @@ -181,6 +195,7 @@ def fused_marlin_moe( g_idx1, perm1, workspace, + scalar_type, M, 2 * N, K, @@ -204,6 +219,7 @@ def fused_marlin_moe( g_idx2, perm2, workspace, + scalar_type, M, K, N, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a0cb4337f9dee..3e01112eaa14d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids.to(torch.int32) + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 49c29c2775cb6..7dee2fca81153 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -38,10 +40,11 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits == 4): + and self.num_bits in WNA16_SUPPORTED_BITS): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for 4 bits") + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -292,4 +295,5 @@ def apply( topk_ids, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3617a32f80fc1..cc699f5b4554f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -611,4 +611,5 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, 
+ num_bits=self.quant_config.quant_type.size_bits, ).to(orig_dtype) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0052489d99dc4..2bfe6ea09bd62 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,13 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors"] - # for gptq_marlin, only run fused MoE for int4 - if model_config.quantization == "gptq_marlin": - hf_quant_config = getattr(model_config.hf_config, - "quantization_config", None) - if hf_quant_config and hf_quant_config.get("bits") == 4: - mixtral_supported.append("gptq_marlin") + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From 837c1968f9f1fdd9d894b2071d605ca1bdc97942 Mon Sep 17 00:00:00 2001 From: lewtun Date: Mon, 16 Sep 2024 17:55:26 +0200 Subject: [PATCH 002/116] [Frontend] Expose revision arg in OpenAI server (#8501) --- vllm/entrypoints/openai/api_server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d8704d5e24964..7c1f307e06619 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -69,8 +69,10 @@ def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str]) -> bool: + quantization: Optional[str], + revision: Optional[str]) -> bool: return ModelConfig(model=model_name, + revision=revision, tokenizer=model_name, tokenizer_mode="auto", trust_remote_code=trust_remote_code, @@ -130,7 +132,7 @@ async def build_async_engine_client_from_engine_args( # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. 
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization) + engine_args.quantization, engine_args.revision) or disable_frontend_multiprocessing): engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) From acd5511b6d0e196b273b6250201115b5c5cfd1ca Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 16 Sep 2024 17:33:46 +0100 Subject: [PATCH 003/116] [BugFix] Fix clean shutdown issues (#8492) --- tests/async_engine/test_async_llm_engine.py | 10 +- vllm/engine/async_llm_engine.py | 70 +++++--- vllm/engine/llm_engine.py | 21 ++- vllm/entrypoints/launcher.py | 21 ++- vllm/entrypoints/openai/api_server.py | 181 ++++++++++++-------- vllm/entrypoints/openai/rpc/server.py | 8 +- vllm/executor/multiproc_gpu_executor.py | 14 -- vllm/executor/multiproc_worker_utils.py | 5 +- vllm/executor/ray_tpu_executor.py | 2 + vllm/scripts.py | 4 +- vllm/utils.py | 15 ++ 11 files changed, 215 insertions(+), 136 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index a093a2b29278a..6cae76f74603d 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -26,6 +26,11 @@ class RequestOutput: finished: bool = False +@dataclass +class MockModelConfig: + use_async_output_proc = True + + class MockEngine: def __init__(self): @@ -35,6 +40,7 @@ def __init__(self): self.request_id = None # Ugly, remove dependency when possible self.parallel_config = ParallelConfig(1, 1, False) + self.model_config = MockModelConfig() async def step_async(self, virtual_engine): # PP size is 1, ignore virtual engine @@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): - engine = MockAsyncLLMEngine(worker_use_ray=False) + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 @@ -113,7 +119,7 @@ async def test_new_requests_event(): assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 - engine = MockAsyncLLMEngine(worker_use_ray=True) + engine = MockAsyncLLMEngine() assert engine.get_model_config() is not None assert engine.get_tokenizer() is not None assert engine.get_decoding_config() is not None diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8a07ce1c965e1..410e6ffaa2d50 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,10 @@ import asyncio import time +import weakref from functools import partial from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) +from weakref import ReferenceType import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, @@ -26,6 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext +from vllm.utils import weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -450,9 +453,6 @@ class AsyncLLMEngine: method yields the outputs from the :class:`LLMEngine` to the caller. Args: - worker_use_ray: Whether to use Ray for model workers. Required for - distributed execution. Should be the same as - `parallel_config.worker_use_ray`. log_requests: Whether to log the requests. 
start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. @@ -463,23 +463,22 @@ class AsyncLLMEngine: _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine def __init__(self, - worker_use_ray: bool, *args, log_requests: bool = True, start_engine_loop: bool = True, **kwargs) -> None: - self.worker_use_ray = worker_use_ray self.log_requests = log_requests self.engine = self._engine_class(*args, **kwargs) # This ensures quick processing of request outputs # so the append to asyncio queues is not delayed, # especially for multi-step. - # - self.use_process_request_outputs_callback = True + self.use_process_request_outputs_callback = ( + self.engine.model_config.use_async_output_proc) + if self.use_process_request_outputs_callback: self.engine.process_request_outputs_callback = \ - self.process_request_outputs + weak_bind(self.process_request_outputs) self.background_loop: Optional[asyncio.Future] = None # We need to keep a reference to unshielded @@ -492,6 +491,11 @@ def __init__(self, # Lazy initialized fields self._request_tracker: RequestTracker + def __del__(self): + if rt := getattr(self, "request_tracker", None): + # Wake up engine loop so that it will exit cleanly + rt.new_requests_event.set() + @classmethod def _get_executor_cls( cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: @@ -502,15 +506,12 @@ def _get_executor_cls( raise TypeError( "distributed_executor_backend must be a subclass of " f"ExecutorAsyncBase. Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) executor_class = distributed_executor_backend elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "tpu": if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync executor_class = RayTPUExecutorAsync else: @@ -531,11 +532,9 @@ def _get_executor_cls( from vllm.executor.xpu_executor import XPUExecutorAsync executor_class = XPUExecutorAsync elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync executor_class = RayXPUExecutorAsync elif distributed_executor_backend == "mp": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.multiproc_xpu_executor import ( MultiprocessingXPUExecutorAsync) executor_class = MultiprocessingXPUExecutorAsync @@ -543,7 +542,6 @@ def _get_executor_cls( raise RuntimeError( "Not supported distributed execution model on XPU device.") elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync elif distributed_executor_backend == "mp": @@ -559,19 +557,23 @@ def _get_executor_cls( def from_engine_args( cls, engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. 
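The weak_bind callback above (and the weakref handed to run_engine_loop further down) exists so that long-lived callbacks and background tasks do not hold strong references to the engine, which is what previously prevented a clean shutdown. The patch's actual vllm.utils.weak_bind implementation is not shown in these hunks; the following is only an illustrative stand-in for the pattern:

import weakref

def weak_bind(bound_method):
    # Hypothetical sketch, NOT vLLM's implementation: wrap a bound method so
    # only a weak reference to its instance is kept. Once the instance is
    # garbage collected, calling the wrapper silently becomes a no-op.
    ref = weakref.WeakMethod(bound_method)

    def _bound(*args, **kwargs):
        method = ref()
        if method is not None:
            method(*args, **kwargs)

    return _bound

class Engine:
    def process(self, item):
        print("processing", item)

engine = Engine()
callback = weak_bind(engine.process)
callback("a")    # processing a
del engine       # last strong reference gone; the engine can be collected
callback("b")    # no-op instead of keeping the engine alive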
- engine_config = engine_args.create_engine_config() + if engine_config is None: + engine_config = engine_args.create_engine_config() executor_class = cls._get_executor_cls(engine_config) + if executor_class.uses_ray: + initialize_ray_cluster(engine_config.parallel_config) + # Create the async LLM engine. engine = cls( - executor_class.uses_ray, **engine_config.to_dict(), executor_class=executor_class, log_requests=not engine_args.disable_log_requests, @@ -628,7 +630,7 @@ def start_background_loop(self) -> None: self._request_tracker = RequestTracker() self._background_loop_unshielded = asyncio.get_event_loop( - ).create_task(self.run_engine_loop()) + ).create_task(self.run_engine_loop(weakref.ref(self))) self._background_loop_unshielded.add_done_callback( partial(_log_task_completion, error_callback=self._error_callback)) self.background_loop = asyncio.shield(self._background_loop_unshielded) @@ -698,9 +700,16 @@ def process_request_outputs(self, request_outputs) -> bool: async def _engine_abort(self, request_ids: Iterable[str]): self.engine.abort_request(request_ids) - async def run_engine_loop(self): + @staticmethod + async def run_engine_loop(engine_ref: ReferenceType): + """We use a weakref to the engine so that the running loop + doesn't prevent the engine being garbage collected.""" + engine: Optional["AsyncLLMEngine"] = engine_ref() + if not engine: + return + pipeline_parallel_size = \ - self.engine.parallel_config.pipeline_parallel_size + engine.engine.parallel_config.pipeline_parallel_size has_requests_in_progress = [False] * pipeline_parallel_size while True: if not any(has_requests_in_progress): @@ -711,11 +720,21 @@ async def run_engine_loop(self): # timeout, and unblocks the RPC thread in the workers so that # they can process any other queued control plane messages, # such as add/remove lora adapters. - await self.engine.stop_remote_worker_execution_loop_async() - await self._request_tracker.wait_for_new_requests() + await engine.engine.stop_remote_worker_execution_loop_async() + request_tracker = engine._request_tracker + # Allow engine to be garbage collected while + # waiting for new requests + del engine + await asyncio.sleep(0) + if engine_ref() is None: + return + await request_tracker.wait_for_new_requests() + engine = engine_ref() + if not engine: + return logger.debug("Got new requests!") requests_in_progress = [ - asyncio.create_task(self.engine_step(ve)) + asyncio.create_task(engine.engine_step(ve)) for ve in range(pipeline_parallel_size) ] has_requests_in_progress = [True] * pipeline_parallel_size @@ -733,19 +752,20 @@ async def run_engine_loop(self): result = task.result() virtual_engine = requests_in_progress.index(task) has_unfinished_requests = ( - self.engine.has_unfinished_requests_for_virtual_engine( + engine.engine. + has_unfinished_requests_for_virtual_engine( virtual_engine)) if result or has_unfinished_requests: requests_in_progress[virtual_engine] = ( asyncio.create_task( - self.engine_step(virtual_engine))) + engine.engine_step(virtual_engine))) has_requests_in_progress[virtual_engine] = True else: has_requests_in_progress[virtual_engine] = False except asyncio.TimeoutError as exc: logger.error( "Engine iteration timed out. 
This should never happen!") - self.set_errored(exc) + engine.set_errored(exc) raise await asyncio.sleep(0) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dfdbc22ef00e1..8b5009b2c6668 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,8 +1,8 @@ -import functools import time from collections import deque from contextlib import contextmanager from dataclasses import dataclass +from functools import partial from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence @@ -51,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device +from vllm.utils import Counter, Device, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -382,11 +382,16 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] - self.async_callbacks = [ - functools.partial(self._process_model_outputs, - ctx=self.scheduler_contexts[v_id]) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] # Currently used by AsyncLLMEngine to ensure quick append # of request outputs to asyncio queues @@ -869,8 +874,8 @@ def has_unfinished_requests_for_virtual_engine( """ return self.scheduler[virtual_engine].has_unfinished_seqs() + @staticmethod def _process_sequence_group_outputs( - self, seq_group: SequenceGroup, outputs: List[EmbeddingSequenceGroupOutput], ) -> None: diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 3598872b65bb0..47d227010c075 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -1,21 +1,20 @@ import asyncio import signal from http import HTTPStatus -from typing import Any +from typing import Any, Optional import uvicorn -from fastapi import FastAPI, Response +from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.protocol import AsyncEngineClient from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, engine: AsyncEngineClient, +async def serve_http(app: FastAPI, limit_concurrency: Optional[int], **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: @@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, # Set concurrency limits in uvicorn if running in multiprocessing mode # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if engine.limit_concurrency is not None: + if limit_concurrency is not None: logger.info( "Launching Uvicorn with --limit_concurrency %s. 
To avoid this " "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", engine.limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency + "--disable-frontend-multiprocessing", limit_concurrency) + uvicorn_kwargs["limit_concurrency"] = limit_concurrency config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) - _add_shutdown_handlers(app, server, engine) + _add_shutdown_handlers(app, server) loop = asyncio.get_running_loop() @@ -68,15 +67,15 @@ async def dummy_shutdown() -> None: return server.shutdown() -def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server, - engine: AsyncEngineClient) -> None: +def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """Adds handlers for fatal errors that should crash the server""" @app.exception_handler(RuntimeError) - async def runtime_error_handler(_, __): + async def runtime_error_handler(request: Request, __): """On generic runtime error, check to see if the engine has died. It probably has, in which case the server will no longer be able to handle requests. Trigger a graceful shutdown with a SIGTERM.""" + engine = request.app.state.engine_client if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored and not engine.is_running): logger.fatal("AsyncLLMEngine has failed, terminating server " diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7c1f307e06619..b50fc6a265f8d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,16 +4,20 @@ import multiprocessing import os import re +import signal import tempfile from argparse import Namespace from contextlib import asynccontextmanager +from functools import partial from http import HTTPStatus from typing import AsyncIterator, Optional, Set +import uvloop from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State from starlette.routing import Mount from typing_extensions import assert_never @@ -54,12 +58,6 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds -async_engine_client: AsyncEngineClient -engine_args: AsyncEngineArgs -openai_serving_chat: OpenAIServingChat -openai_serving_completion: OpenAIServingCompletion -openai_serving_embedding: OpenAIServingEmbedding -openai_serving_tokenization: OpenAIServingTokenization prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) @@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool, @asynccontextmanager async def lifespan(app: FastAPI): - - async def _force_log(): - while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() - - if not engine_args.disable_log_stats: - task = asyncio.create_task(_force_log()) - _running_tasks.add(task) - task.add_done_callback(_running_tasks.remove) - - yield + try: + if app.state.log_stats: + async_engine_client = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10) + await async_engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del 
app.state @asynccontextmanager @@ -103,16 +111,10 @@ async def build_async_engine_client( # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit - global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) - # Backend itself still global for the silly lil' health handler - global async_engine_client - async with build_async_engine_client_from_engine_args( engine_args, args.disable_frontend_multiprocessing) as engine: - - async_engine_client = engine # type: ignore[assignment] yield engine @@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args( if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, engine_args.quantization, engine_args.revision) or disable_frontend_multiprocessing): - engine_client = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - try: - yield engine_client - finally: - engine_client.shutdown_background_loop() + engine_config = engine_args.create_engine_config() + uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), + "uses_ray", False) + + build_engine = partial(AsyncLLMEngine.from_engine_args, + engine_args=engine_args, + engine_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER) + if uses_ray: + # Must run in main thread with ray for its signal handlers to work + engine_client = build_engine() + else: + engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_engine) + + yield engine_client return # Otherwise, use the multiprocessing AsyncLLMEngine. @@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI): app.routes.append(metrics_route) +def chat(request: Request) -> OpenAIServingChat: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> OpenAIServingCompletion: + return request.app.state.openai_serving_completion + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def embedding(request: Request) -> OpenAIServingEmbedding: + return request.app.state.openai_serving_embedding + + +def engine_client(request: Request) -> AsyncEngineClient: + return request.app.state.engine_client + + @router.get("/health") -async def health() -> Response: +async def health(raw_request: Request) -> Response: """Health check.""" - await async_engine_client.check_health() + await engine_client(raw_request).check_health() return Response(status_code=200) @router.post("/tokenize") -async def tokenize(request: TokenizeRequest): - generator = await openai_serving_tokenization.create_tokenize(request) +async def tokenize(request: TokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest): @router.post("/detokenize") -async def detokenize(request: DetokenizeRequest): - generator = await openai_serving_tokenization.create_detokenize(request) +async def detokenize(request: DetokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest): @router.get("/v1/models") -async def show_available_models(): - models = 
await openai_serving_completion.show_available_models() +async def show_available_models(raw_request: Request): + models = await completion(raw_request).show_available_models() return JSONResponse(content=models.model_dump()) @@ -288,7 +320,7 @@ async def show_version(): async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): - generator = await openai_serving_chat.create_chat_completion( + generator = await chat(raw_request).create_chat_completion( request, raw_request) if isinstance(generator, ErrorResponse): @@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): - generator = await openai_serving_completion.create_completion( + generator = await completion(raw_request).create_completion( request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): - generator = await openai_serving_embedding.create_embedding( + generator = await embedding(raw_request).create_embedding( request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -333,16 +365,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): "used for local development!") @router.post("/start_profile") - async def start_profile(): + async def start_profile(raw_request: Request): logger.info("Starting profiler...") - await async_engine_client.start_profile() + await engine_client(raw_request).start_profile() logger.info("Profiler started.") return Response(status_code=200) @router.post("/stop_profile") - async def stop_profile(): + async def stop_profile(raw_request: Request): logger.info("Stopping profiler...") - await async_engine_client.stop_profile() + await engine_client(raw_request).stop_profile() logger.info("Profiler stopped.") return Response(status_code=200) @@ -353,13 +385,14 @@ async def stop_profile(): "This should ONLY be used for local development!") @router.post("/v1/load_lora_adapter") - async def load_lora_adapter(request: LoadLoraAdapterRequest): - response = await openai_serving_chat.load_lora_adapter(request) + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).load_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) - response = await openai_serving_completion.load_lora_adapter(request) + response = await completion(raw_request).load_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) @@ -367,13 +400,14 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest): return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") - async def unload_lora_adapter(request: UnloadLoraAdapterRequest): - response = await openai_serving_chat.unload_lora_adapter(request) + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).unload_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) 
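# Editorial aside (not part of the patch): the handlers above no longer reach
# for module-level globals; they resolve their dependencies from
# request.app.state through small accessor helpers. A minimal, self-contained
# FastAPI sketch of the same pattern (names here are hypothetical, not vLLM's):

from contextlib import asynccontextmanager

from fastapi import FastAPI, Request


class DummyClient:
    """Stand-in for an engine client, for illustration only."""

    async def check_health(self) -> None:
        return None


@asynccontextmanager
async def example_lifespan(app: FastAPI):
    # Attach per-application singletons to app.state instead of globals.
    app.state.engine_client = DummyClient()
    try:
        yield
    finally:
        # Drop the reference so the client can be garbage collected.
        del app.state.engine_client


example_app = FastAPI(lifespan=example_lifespan)


def example_engine_client(request: Request) -> DummyClient:
    # Accessors resolve state per request, mirroring chat()/completion() above.
    return request.app.state.engine_client


@example_app.get("/health")
async def example_health(raw_request: Request) -> dict:
    await example_engine_client(raw_request).check_health()
    return {"status": "ok"}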
- response = await openai_serving_completion.unload_lora_adapter(request) + response = await completion(raw_request).unload_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) @@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - err = openai_serving_chat.create_error_response(message=str(exc)) + chat = app.state.openai_serving_chat + err = chat.create_error_response(message=str(exc)) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -430,30 +465,26 @@ async def authentication(request: Request, call_next): return app -async def init_app( +def init_app_state( async_engine_client: AsyncEngineClient, + model_config: ModelConfig, + state: State, args: Namespace, -) -> FastAPI: - app = build_app(args) - +) -> None: if args.served_model_name is not None: served_model_names = args.served_model_name else: served_model_names = [args.model] - model_config = await async_engine_client.get_model_config() - if args.disable_log_requests: request_logger = None else: request_logger = RequestLogger(max_log_len=args.max_log_len) - global openai_serving_chat - global openai_serving_completion - global openai_serving_embedding - global openai_serving_tokenization + state.engine_client = async_engine_client + state.log_stats = not args.disable_log_stats - openai_serving_chat = OpenAIServingChat( + state.openai_serving_chat = OpenAIServingChat( async_engine_client, model_config, served_model_names, @@ -465,7 +496,7 @@ async def init_app( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser) - openai_serving_completion = OpenAIServingCompletion( + state.openai_serving_completion = OpenAIServingCompletion( async_engine_client, model_config, served_model_names, @@ -474,13 +505,13 @@ async def init_app( request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) - openai_serving_embedding = OpenAIServingEmbedding( + state.openai_serving_embedding = OpenAIServingEmbedding( async_engine_client, model_config, served_model_names, request_logger=request_logger, ) - openai_serving_tokenization = OpenAIServingTokenization( + state.openai_serving_tokenization = OpenAIServingTokenization( async_engine_client, model_config, served_model_names, @@ -488,25 +519,31 @@ async def init_app( request_logger=request_logger, chat_template=args.chat_template, ) - app.root_path = args.root_path - - return app async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + async with build_async_engine_client(args) as async_engine_client: # If None, creation of the client failed and we exit. 
if async_engine_client is None: return - app = await init_app(async_engine_client, args) + app = build_app(args) + + model_config = await async_engine_client.get_model_config() + init_app_state(async_engine_client, model_config, app.state, args) shutdown_task = await serve_http( app, - engine=async_engine_client, + limit_concurrency=async_engine_client.limit_concurrency, host=args.host, port=args.port, log_level=args.uvicorn_log_level, @@ -530,4 +567,4 @@ async def run_server(args, **uvicorn_kwargs) -> None: parser = make_arg_parser(parser) args = parser.parse_args() - asyncio.run(run_server(args)) + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index bebc2faedb680..460ff0636b6e9 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -46,7 +46,6 @@ def cleanup(self): """Cleanup all resources.""" self.socket.close() self.context.destroy() - self.engine.shutdown_background_loop() # Clear the engine reference so that it can be GC'ed. del self.engine @@ -233,5 +232,12 @@ def signal_handler() -> None: def run_rpc_server(async_engine_args: AsyncEngineArgs, usage_context: UsageContext, rpc_path: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("AsyncEngineRPCServer terminated") + + signal.signal(signal.SIGTERM, signal_handler) + server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) uvloop.run(run_server(server)) diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 9c6d4051eb3f8..cc535e99a06ef 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -1,8 +1,5 @@ import asyncio import os -import signal -import threading -import weakref from functools import partial from typing import Any, List, Optional @@ -108,17 +105,6 @@ def _init_executor(self) -> None: # Set up signal handlers to shutdown the executor cleanly # sometimes gc does not work well - # Use weakref to avoid holding a reference to self - ref = weakref.ref(self) - - def shutdown(signum, frame): - if executor := ref(): - executor.shutdown() - - if threading.current_thread() is threading.main_thread(): - signal.signal(signal.SIGINT, shutdown) - signal.signal(signal.SIGTERM, shutdown) - self.driver_worker = self._create_worker( distributed_init_method=distributed_init_method) self._run_workers("init_device") diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index 28c8e8699f083..aa2a16c04d08d 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -120,7 +120,8 @@ def run(self) -> None: logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode) # Cleanup any remaining workers - logger.info("Killing local vLLM worker processes") + if logger: + logger.info("Killing local vLLM worker processes") for worker in self.workers: worker.kill_worker() # Must be done after worker task queues are all closed @@ -221,6 +222,8 @@ def _run_worker_process( try: executor = getattr(worker, method) output = executor(*args, **kwargs) + except KeyboardInterrupt: + break except BaseException as e: tb = traceback.format_exc() logger.error( diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 732b69d6e5954..d02fecb46f007 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py 
@@ -26,6 +26,8 @@ class RayTPUExecutor(TPUExecutor): + uses_ray: bool = True + def __init__(self, *args, **kwargs): # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. diff --git a/vllm/scripts.py b/vllm/scripts.py index e557961a335bf..231a18e99f3d7 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -1,11 +1,11 @@ # The CLI entrypoint to vLLM. import argparse -import asyncio import os import signal import sys from typing import List, Optional +import uvloop from openai import OpenAI from openai.types.chat import ChatCompletionMessageParam @@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None: # EngineArgs expects the model name to be passed as --model. args.model = args.model_tag - asyncio.run(run_server(args)) + uvloop.run(run_server(args)) def interactive_cli(args: argparse.Namespace) -> None: diff --git a/vllm/utils.py b/vllm/utils.py index aba243071b69a..014fc16a17c1f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -12,6 +12,7 @@ import threading import uuid import warnings +import weakref from asyncio import FIRST_COMPLETED, ensure_future from functools import lru_cache, partial, wraps from platform import uname @@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) +def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: + """Make an instance method that weakly references + its associated instance and no-ops once that + instance is collected.""" + ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined] + unbound = bound_method.__func__ # type: ignore[attr-defined] + + def weak_bound(*args, **kwargs) -> None: + if inst := ref(): + unbound(inst, *args, **kwargs) + + return weak_bound + + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: From 781e3b9a4281babf5576f8dc22445af69814d8f6 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Mon, 16 Sep 2024 18:15:57 +0000 Subject: [PATCH 004/116] [Bugfix][Kernel] Fix build for sm_60 in GGUF kernel (#8506) --- csrc/quantization/gguf/vecdotq.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index ff339753bcbb5..d5af345a6b26f 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -1671,6 +1671,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const int qs_packed = get_int_b2(bq1->qs, iqs); @@ -1697,10 +1698,12 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); const float2 ds = __half22float2(bq8_1[iqs].ds); return d1q * (ds.x*sumi + ds.y*delta); +#endif } static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 const block_iq1_m * bq1 = (const block_iq1_m *) vbq; @@ -1741,6 +1744,7 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const int sc0 = 2*((tmp >> 0) & 0x07) + 1; const int sc1 = 2*((tmp >> 3) & 0x07) + 1; return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * 
sc1); +#endif } static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, From 5d73ae49d65394f8dbe46accd921fb21e8247b82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 16 Sep 2024 14:52:40 -0400 Subject: [PATCH 005/116] [Kernel] AQ AZP 3/4: Asymmetric quantization kernels (#7270) --- csrc/cpu/quant.cpp | 9 +- csrc/cpu/torch_bindings.cpp | 9 +- csrc/ops.h | 6 +- .../compressed_tensors/int8_quant_kernels.cu | 173 ++++++++++++++++-- csrc/torch_bindings.cpp | 8 +- tests/kernels/test_int8_quant.py | 158 ++++++++++++++-- vllm/_custom_ops.py | 29 ++- .../model_executor/layers/quantization/qqq.py | 2 +- .../layers/quantization/utils/w8a8_utils.py | 2 +- 9 files changed, 339 insertions(+), 57 deletions(-) diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 0cfc19097fded..2d7abe6145fee 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major // static-per-tensor quantization. void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale) { + const torch::Tensor& scale, + c10::optional const& azp) { CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; @@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale // [..., 1] -) { + torch::Tensor& scale, // [..., 1] + c10::optional const& azp) { CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b45da1b386b5b..ab697e3e6aef7 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #ifdef __AVX512F__ // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " - "()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " - "()"); + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? 
azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); // W8A8 GEMM, supporting symmetric per-tensor or per-row/column diff --git a/csrc/ops.h b/csrc/ops.h index 5333b22c536d6..681ab4b898ca3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& scale); + torch::Tensor const& scale, + c10::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales); + torch::Tensor& scales, + c10::optional const& azp); torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 616fc149760e5..aec9fa002f96e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -14,12 +14,17 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - static const float i8_min = + static constexpr auto i8_min = static_cast(std::numeric_limits::min()); - static const float i8_max = + static constexpr auto i8_max = static_cast(std::numeric_limits::max()); - // round + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. float dst = std::nearbyint(x); + // saturate dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); @@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #endif } +static inline __device__ int32_t float_to_int32_rn(float x) { +#ifdef USE_ROCM + // int32_max is not exactly representable as float. + // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. + static constexpr auto i32_min = std::numeric_limits::min(); + static constexpr auto i32_min_f = static_cast(i32_min); + static constexpr auto i32_max = std::numeric_limits::max(); + static constexpr auto i32_max_f = static_cast(i32_max); + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. + float dst = std::nearbyint(x); + + // saturate on the higher end. + if (dst >= i32_max_f) { + return i32_max; + } + // saturate on the lower end. 
+ if (dst <= i32_min_f) { + return i32_min; + } + + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int8_t int32_to_int8(int32_t x) { +#ifdef USE_ROCM + static constexpr auto i8_min = + static_cast(std::numeric_limits::min()); + static constexpr auto i8_max = + static_cast(std::numeric_limits::max()); + + // saturate + int32_t dst = std::clamp(x, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); + return reinterpret_cast(dst); +#endif +} + namespace vllm { template @@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel( } } +template +__global__ void static_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, azp_type const* azp_ptr, + const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + azp_type const azp = *azp_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); + out[token_idx * hidden_size + i] = quant_val; + } +} + template __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, @@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel( } } +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, azp_type* azp, const int hidden_size) { + int const token_idx = blockIdx.x; + + // Scan for the min and max value for this token + float max_val = std::numeric_limits::min(); + float min_val = std::numeric_limits::max(); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[token_idx * hidden_size + i]); + max_val = std::max(max_val, val); + min_val = std::min(min_val, val); + } + + // Reduce the max and min values across the block + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); + __syncthreads(); // Make sure min doesn't mess with max shared memory + min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); + + __shared__ scale_type scale_sh; + __shared__ azp_type azp_sh; + + // Compute the scale and zero point and store them, only on the first thread + if (threadIdx.x == 0) { + float const scale_val = (max_val - min_val) / 255.0f; + // Use rounding to even (same as torch.round) + auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); + auto const azp_val = static_cast(azp_float); + + // Store the scale and azp into shared and global + scale[token_idx] = scale_sh = scale_val; + azp[token_idx] = azp_sh = azp_val; + } + + // Wait for the scale and azp to be computed + __syncthreads(); + + float const scale_val = scale_sh; + azp_type const azp_val = azp_sh; + + // Quantize the values + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = + int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); + out[token_idx * hidden_size + i] = quant_val; + } +} + } // namespace 
vllm void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale) { + torch::Tensor const& scale, + c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scale.data_ptr(), hidden_size); + if (!azp) { + vllm::static_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), hidden_size); + } else { + vllm::static_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { + torch::Tensor& scales, c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), hidden_size); + if (!azp) { + vllm::dynamic_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), hidden_size); + } else { + vllm::dynamic_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 51afeacfdc0ad..d7f7547fbef55 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -336,14 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " - "()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " - "()"); + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? 
azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); } diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index a82ecb026482e..e93cb535d715a 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -13,14 +13,28 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] -def opcheck_int8_quant(output, input, scale=None): - if scale is not None: - opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale)) +def opcheck_int8_quant_static(output, input, scale, azp=None): + if azp is None: + opcheck(torch.ops._C.static_scaled_int8_quant, + (output, input, scale, None)) else: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale)) + opcheck(torch.ops._C.static_scaled_int8_quant, + (output, input, scale, azp)) + + +def opcheck_int8_quant_dynamic(output, input, symmetric=True): + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + if symmetric: + opcheck(torch.ops._C.dynamic_scaled_int8_quant, + (output, input, scale, None)) + else: + azp = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.int32) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, + (output, input, scale, azp)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -38,14 +52,56 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, # reference ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) # kernel - ops_out, ops_scales = scaled_int8_quant(x) + ops_out, ops_scales, _ = scaled_int8_quant(x) torch.testing.assert_close(ops_scales, ref_scales) - torch.testing.assert_close( - ops_out, ref_out, atol=1, - rtol=0.0) # big atol to account for rounding errors + # big atol to account for rounding errors + torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0) - opcheck_int8_quant(ops_out, x) + opcheck_int8_quant_dynamic(ops_out, x) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 + + x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) + x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) + + # calculate scale and azp, and adjust the range + scales = (x_token_max - x_token_min) / torch.tensor(255.0) + azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to( + torch.int32) + + torch_out = ((x / scales).round() + azps).clamp( + int8_traits.min, int8_traits.max).to(torch.int8) + assert torch_out.min() >= int8_traits.min and torch_out.max( + ) <= int8_traits.max + + ops_out = torch.empty_like(x, dtype=torch.int8) + scales_out = torch.empty_like(scales, dtype=torch.float32) + azp_out = torch.empty_like(azps, dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out) + + if (not torch.allclose(scales_out, scales)): + print(torch.argmax(torch.abs(scales_out - scales))) + torch.testing.assert_close(scales_out, scales) + # big atol to account for rounding 
errors + torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0) + # if AZP is off by 1, after rounding-to-even, the output may be off by 2 + torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0) + + opcheck_int8_quant_dynamic(ops_out, x, False) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -62,14 +118,76 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - scale = torch.tensor([scale], dtype=torch.float32, device="cuda") + scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") + + out1 = (x / scale_arg).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2, _, _ = scaled_int8_quant(x, scale_arg) + + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) + + opcheck_int8_quant_static(out2, x, scale_arg) - out1 = (x / scale).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - out2, _ = scaled_int8_quant(x, scale) - torch.testing.assert_close( - out1, out2, atol=1, - rtol=0.0) # big atol to account for rounding errors +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE[2:]) # Reduce test time +@pytest.mark.parametrize("azp", [-255, 54]) +@torch.inference_mode() +def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int, + scale: float, azp: int) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 + + out1 = ((x / scale).round() + azp).clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") + azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda") + + torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg) + + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) + + opcheck_int8_quant_static(out2, x, scale_arg, azp_arg) + + +@pytest.mark.parametrize("is_max", [True, False]) +@torch.inference_mode() +def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None: + # Test that the saturating cast works correctly for values near i32 max/min + + from numpy import inf, nextafter + + int32_traits = torch.iinfo(torch.int32) + val = float(int32_traits.max if is_max else int32_traits.min) + + x_vals = [[ + nextafter(val, inf), val + 1, val, val - 1, + nextafter(val, -inf) + ]] + x = torch.tensor(x_vals, dtype=torch.float32, device="cuda") + + # The calculation in the kernel is: cast(cast(x / scale) + azp) + # where cast is a saturating cast to type T. + # Scale is set to 1.0 so that the input values are the ones that are cast. + # AZP is set to 0 to make sure the int8 saturating cast is tested as well. 
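# Editorial aside (not part of the patch): the asymmetric ("AZP") scheme these
# tests exercise can be written in a few lines of plain PyTorch, which is a
# useful mental model for the CUDA kernels. A hedged reference sketch (the
# helper name is ours, not vLLM's):

import torch


def ref_dynamic_int8_azp_quant(x: torch.Tensor):
    """Per-token asymmetric int8 quantization, reference only."""
    x = x.to(torch.float32)
    x_max = x.max(dim=-1, keepdim=True).values
    x_min = x.min(dim=-1, keepdim=True).values
    # Guard against constant rows where max == min.
    scale = ((x_max - x_min) / 255.0).clamp_min(torch.finfo(torch.float32).tiny)
    # torch.round rounds half to even, matching the kernels' nearbyint/cvt.rni.
    azp = torch.round(-128.0 - x_min / scale).to(torch.int32)
    q = torch.clamp(torch.round(x / scale) + azp, -128, 127).to(torch.int8)
    return q, scale, azp


q_ref, scale_ref, azp_ref = ref_dynamic_int8_azp_quant(torch.randn(4, 16) * 100 - 30)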
+ scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda") + azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda") + + int8_traits = torch.iinfo(torch.int8) + val_i8 = int8_traits.max if is_max else int8_traits.min + expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda") - opcheck_int8_quant(out2, x, scale) + out = torch.empty_like(expected) + torch.ops._C.static_scaled_int8_quant(out, x, scale, azp) + torch.testing.assert_close(expected, out, atol=0, rtol=0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 74b3b69606c67..d5b3d7bc6dd5a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -684,32 +684,43 @@ def scaled_fp8_quant( # int8 def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ - Quantize the input tensor to int8 and return the quantized tensor and scale. + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. Args: input: The input tensor to be quantized to int8. scale: Optional scaling factor for the int8 quantization. When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales. + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - torch.ops._C.static_scaled_int8_quant(output, input, scale) - return output, scale + assert symmetric == ( + azp is + None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, None # dynamic-per-token quantization. 
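# Editorial aside (not part of the patch): with the widened signature above,
# call sites unpack a 3-tuple and opt into asymmetric quantization explicitly.
# A hedged usage sketch (requires a CUDA build of vLLM; shapes are illustrative):

import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# dynamic, symmetric: per-token scales, azp comes back as None
q, scales, azp = ops.scaled_int8_quant(x)

# dynamic, asymmetric: per-token scales and zero points
q, scales, azp = ops.scaled_int8_quant(x, symmetric=False)

# static, symmetric: caller supplies a scalar scale
s = torch.tensor([0.02], dtype=torch.float32, device="cuda")
q, s_out, _ = ops.scaled_int8_quant(x, scale=s)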
input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) - return output, input_scales + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp # qqq ops diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py index c3434214a1cde..5bc3737520865 100644 --- a/vllm/model_executor/layers/quantization/qqq.py +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -260,7 +260,7 @@ def apply( size_k = x_2d.shape[1] size_n = s_ch.shape[1] - x_int8, s_tok = ops.scaled_int8_quant(x_2d) + x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index a54e3cae73b14..887ee6605560c 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -188,7 +188,7 @@ def apply_int8_linear( # ops.scaled_int8_quant supports both dynamic and static quant. # * dynamic, layer.input_scale is None and x_scale computed from x. # * static, layer.input_scale is scalar and x_scale is input_scale. - x_q, x_scale = ops.scaled_int8_quant(input, input_scale) + x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale) return ops.cutlass_scaled_mm(x_q, weight, From 2759a43a26e4eecb7ff7d741c2b6da0d544462ad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Sep 2024 12:10:23 -0700 Subject: [PATCH 006/116] [doc] update doc on testing and debugging (#8514) --- docs/source/getting_started/debugging.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 31ecca1332e5d..81287762d3c0a 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -98,6 +98,13 @@ Here are some common issues that can cause hangs: If the script runs successfully, you should see the message ``sanity check is successful!``. + Note that multi-node environment is more complicated than single-node. If you see errors such as ``torch.distributed.DistNetworkError``, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments: + + - In the first node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py``. + - In the second node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py``. + + Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup. The difference is that you need to execute different commands (with different ``--node-rank``) on different nodes. + If the problem persists, feel free to `open an issue on GitHub `_, with a detailed description of the issue, your environment, and the logs. 
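For reference, the ``test.py`` sanity-check script referenced in the commands above can be as small as the sketch below (an illustration only, assuming NCCL and one GPU per process; the script actually shipped with the docs may differ)::

    import os

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    data = torch.ones(1, device="cuda")
    dist.all_reduce(data)  # sums the ones contributed by every rank
    assert data.item() == dist.get_world_size()
    print("sanity check is successful!")
    dist.destroy_process_group()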
Some known issues: From 47f5e03b5b9fc719b7e5ee00cbd6d1e79627f105 Mon Sep 17 00:00:00 2001 From: Kevin Lin <42618777+kevin314@users.noreply.github.com> Date: Mon, 16 Sep 2024 15:56:28 -0500 Subject: [PATCH 007/116] [Bugfix] Bind api server port before starting engine (#8491) --- vllm/entrypoints/openai/api_server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b50fc6a265f8d..3d1d832986c1e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,6 +5,7 @@ import os import re import signal +import socket import tempfile from argparse import Namespace from contextlib import asynccontextmanager @@ -525,6 +526,9 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + temp_socket.bind(("", args.port)) + def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing raise KeyboardInterrupt("terminated") @@ -541,6 +545,8 @@ def signal_handler(*_) -> None: model_config = await async_engine_client.get_model_config() init_app_state(async_engine_client, model_config, app.state, args) + temp_socket.close() + shutdown_task = await serve_http( app, limit_concurrency=async_engine_client.limit_concurrency, From 5478c4b41f60995b92b9997306b2e0702055341f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 16 Sep 2024 14:30:02 -0700 Subject: [PATCH 008/116] [perf bench] set timeout to debug hanging (#8516) --- .buildkite/nightly-benchmarks/benchmark-pipeline.yaml | 3 +-- .buildkite/nightly-benchmarks/scripts/wait-for-image.sh | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 2b70e2da5d87c..eec2a51e2f8fd 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -8,8 +8,7 @@ steps: containers: - image: badouralix/curl-jq command: - - sh - - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - wait - label: "A100" agents: diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh index c785e6a0da628..f16862907def1 100644 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -2,9 +2,11 @@ TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +TIMEOUT_SECONDS=10 + retries=0 while [ $retries -lt 1000 ]; do - if [ $(curl -s -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then + if [ $(curl -s --max-time $TIMEOUT_SECONDS -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then exit 0 fi From 5ce45eb54d3fb870f1fb6865c67aac05ec9bf555 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 16 Sep 2024 15:11:27 -0700 Subject: [PATCH 009/116] [misc] small qol fixes for release process (#8517) --- Dockerfile | 2 ++ setup.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 5484be5bc5785..620f549cf3955 100644 --- a/Dockerfile +++ 
b/Dockerfile @@ -82,6 +82,7 @@ ENV BUILDKITE_COMMIT=${buildkite_commit} ARG USE_SCCACHE ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 +ARG SCCACHE_S3_NO_CREDENTIALS=0 # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$USE_SCCACHE" = "1" ]; then \ @@ -92,6 +93,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ + && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ && export SCCACHE_IDLE_TIMEOUT=0 \ && export CMAKE_BUILD_TYPE=Release \ && sccache --show-stats \ diff --git a/setup.py b/setup.py index 8930ea7239dc9..7da9115440433 100644 --- a/setup.py +++ b/setup.py @@ -371,7 +371,9 @@ def get_vllm_version() -> str: cuda_version = str(get_nvcc_cuda_version()) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] - version += f"+cu{cuda_version_str}" + # skip this for source tarball, required for pypi + if "sdist" not in sys.argv: + version += f"+cu{cuda_version_str}" elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() From cca61642e0484212e6cd78b35b4789afed8d19c6 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 16 Sep 2024 18:01:45 -0600 Subject: [PATCH 010/116] [Bugfix] Fix 3.12 builds on main (#8510) Signed-off-by: Joe Runde --- Dockerfile | 4 ---- requirements-common.txt | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 620f549cf3955..001068b4b36ca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -182,10 +182,6 @@ FROM vllm-base AS test ADD . /vllm-workspace/ # install development dependencies (for testing) -# A newer setuptools is required for installing some test dependencies from source that do not publish python 3.12 wheels -# This installation must complete before the test dependencies are collected and installed. -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install "setuptools>=74.1.1" RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt diff --git a/requirements-common.txt b/requirements-common.txt index ad950d0313454..ad53395307ec5 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -29,4 +29,5 @@ importlib_metadata mistral_common >= 1.4.0 pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 +setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. 
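Editor's note (not part of any patch in this series): the conditional requirement added above, setuptools>=74.1.1; python_version > '3.11', relies on a PEP 508 environment marker, so it only applies on Python 3.12 and newer. As a small illustration (using the third-party packaging library, not code from vLLM), such a marker can be parsed and evaluated against the running interpreter:

    from packaging.requirements import Requirement

    req = Requirement("setuptools>=74.1.1; python_version > '3.11'")
    print(req.name, req.specifier)   # setuptools >=74.1.1
    print(req.marker.evaluate())     # True only when running on Python 3.12+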
From 546034b466bf11f12936791312981b9982850eb0 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 16 Sep 2024 20:04:48 -0700 Subject: [PATCH 011/116] [refactor] remove triton based sampler (#8524) --- tests/kernels/test_rand.py | 52 --- tests/kernels/test_sampler.py | 209 ----------- vllm/model_executor/layers/ops/__init__.py | 0 vllm/model_executor/layers/ops/rand.py | 157 -------- vllm/model_executor/layers/ops/sample.py | 394 --------------------- vllm/model_executor/layers/sampler.py | 97 +---- vllm/model_executor/sampling_metadata.py | 211 +++-------- vllm/triton_utils/sample.py | 13 - vllm/utils.py | 37 +- 9 files changed, 75 insertions(+), 1095 deletions(-) delete mode 100644 tests/kernels/test_rand.py delete mode 100644 tests/kernels/test_sampler.py delete mode 100644 vllm/model_executor/layers/ops/__init__.py delete mode 100644 vllm/model_executor/layers/ops/rand.py delete mode 100644 vllm/model_executor/layers/ops/sample.py delete mode 100644 vllm/triton_utils/sample.py diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py deleted file mode 100644 index a4242d22eb489..0000000000000 --- a/tests/kernels/test_rand.py +++ /dev/null @@ -1,52 +0,0 @@ -import random - -import pytest -import torch - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.model_executor.utils import set_random_seed - - -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_3d", [True, False]) -def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): - device = "cuda" - for seed in range(512): - set_random_seed(seed) - rows = random.randint(1, 512) - cols = random.randint(1, 64000) - if use_3d: - third_dim = random.randint(2, 10) - dims = [rows, third_dim, cols] - else: - dims = [rows, cols] - seeds = torch.randint(torch.iinfo(torch.long).min, - torch.iinfo(torch.long).max, (rows, ), - device=device) - - # Test that the same seed produces the same output - out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out2) - # del to save memory - del out2 - - out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out3) - # del to save memory - del out3 - - # Initialize out tensor with garbage to ensure that it is overwritten - out_with_tensor = seeded_uniform( - *dims, - out=torch.full( - (*dims, ), - -1, - dtype=dtype, - device=device, - ), - seeds=seeds, - dtype=dtype, - ) - torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py deleted file mode 100644 index 03844aba20f8a..0000000000000 --- a/tests/kernels/test_sampler.py +++ /dev/null @@ -1,209 +0,0 @@ -import gc -from unittest.mock import patch - -import pytest -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.sample import (_sample_triton, - _uniform_to_exponential, - sample) -from vllm.model_executor.sampling_metadata import SamplingTensors -from vllm.model_executor.utils import set_random_seed -from vllm.triton_utils.libentry import LibEntry -from vllm.triton_utils.sample import (MAX_TRITON_N_COLS, - get_num_triton_sampler_splits) - -SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size -MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 - - -@pytest.fixture(autouse=True) -def _cleanup(): - yield - gc.collect() - torch.cuda.empty_cache() - - -@triton.jit -def 
_uniform_to_exponential_kernel(input, output, n: tl.constexpr): - idx = tl.arange(0, n) - x = tl.load(input + idx) - y = _uniform_to_exponential(x) - tl.store(output + idx, y) - - -def test_uniform_to_exponential(): - """Test that we can convert uniform to exponential without div by 0.""" - input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], - dtype=torch.float32, - device="cuda") - output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") - _uniform_to_exponential_kernel[(1, )](input, output, 2) - assert torch.all(torch.isfinite(output)) - assert torch.all(output > 0) - assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -@pytest.mark.parametrize("save_logprobs", [True, False]) -def test_sample_decoding_only(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size, - save_logprobs): - set_random_seed(seed) - bs = 8 - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = (torch.rand( - (1, bs), device="cuda") < 0.5).expand(n_splits, bs) - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, bs), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, bs), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, bs), - device="cuda").mul_(random_sampling_mask) - #The current _sample_triton does not utilize the - # libentry decoration. The purpose of adding this patch is to test - # the correctness of libentry. 
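# Editorial aside (not part of the patch): the _uniform_to_exponential helper in
# the kernels being removed points at the standard "exponential race" trick for
# categorical sampling: with E_i ~ Exp(1), argmax_i p_i / E_i picks index i with
# probability p_i / sum(p). A pure-PyTorch reference of that identity (our own
# sketch, independent of the deleted Triton code):

import torch


def exponential_race_sample(probs: torch.Tensor) -> torch.Tensor:
    # uniform -> exponential via -log(u); the clamp keeps u away from exactly 0
    u = torch.rand_like(probs).clamp_min(torch.finfo(probs.dtype).tiny)
    exponential = -torch.log(u)
    return torch.argmax(probs / exponential, dim=-1)


example_probs = torch.tensor([[0.1, 0.2, 0.7]])
print(exponential_race_sample(example_probs))  # returns 2 with probability ~0.7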
- with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - _save_modified_probs=True) - assert sampled_tokens.shape == (bs, max_best_of) - for i in range(bs): - assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) - request_uses_random_sampling = random_sampling_mask[0, i] - if modify_greedy_probs and not request_uses_random_sampling: - # If we are modifying greedy probs and the request is greedy, - # we want to make sure the probs tensor is modified in place - torch.testing.assert_close( - probs[i][sampled_tokens[i]], - torch.full_like(probs[i][sampled_tokens[i]], 1.0)) - assert torch.sum(probs[i]) == 1.0 - torch.testing.assert_close( - sampled_modified_probs[i][0], - torch.full_like(sampled_modified_probs[i][0], 1.0)) - elif request_uses_random_sampling: - # If the request is random, we want to make sure - # sampled_modified_probs tensor has noise added - # (and thus is different from probs tensor) - assert not torch.allclose(sampled_modified_probs[i][0], - probs[i][sampled_tokens[i]]) - elif not request_uses_random_sampling: - # If the request is greedy and we are not modifying greedy probs, - # we want to make sure sampled_modified_probs tensor is the same as - # the probs tensor. - torch.testing.assert_close(sampled_modified_probs[i], - probs[i][sampled_tokens[i]]) - - if save_logprobs: - assert sampled_logprobs.shape == (bs, max_best_of) - for i in range(bs): - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[i][ - sampled_tokens[i, best_of]]) - else: - assert sampled_logprobs is None - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -def test_sample_prompt_logprobs(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size): - - set_random_seed(seed) - prompt_sizes = [16, 32, 64, 128] * 2 - samples = 8 - bs = samples + sum(prompt_sizes) - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.tensor(prompt_sizes, - dtype=torch.long, - device="cuda").cumsum_(0) - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = torch.rand( - (n_splits, samples), device="cuda") < 0.5 - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, samples), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, samples), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, samples), - device="cuda").mul_(random_sampling_mask) - #ditto - with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, _ = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=True) - assert sampled_tokens.shape 
== (samples, max_best_of) - assert sampled_logprobs.shape == (samples, max_best_of) - for i, t in enumerate(sample_indices): - assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] - [sampled_tokens[i, best_of]]) - - -@pytest.mark.parametrize("seed", list(range(16))) -def test_get_sequence_seeds(seed): - """Ensure that we get a different child seed from base - seed + extra entropy""" - starting_seed = seed - seq_seed = None - extra_entropy = 1 - for i in range(512): - new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, - i, - seeds_to_generate=1, - is_greedy=False)[0] - new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( - starting_seed, - i, - extra_entropy, - seeds_to_generate=1, - is_greedy=False)[0] - assert new_seq_seed_extra_entropy != new_seq_seed - assert seq_seed != new_seq_seed - seq_seed = new_seq_seed diff --git a/vllm/model_executor/layers/ops/__init__.py b/vllm/model_executor/layers/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py deleted file mode 100644 index 4a429e329567d..0000000000000 --- a/vllm/model_executor/layers/ops/rand.py +++ /dev/null @@ -1,157 +0,0 @@ -from typing import Optional, Union - -import torch -import triton -import triton.language as tl - - -def seeded_uniform( - *size, - seeds: torch.Tensor, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str]] = None, - pin_memory: Optional[bool] = False, -) -> torch.Tensor: - """Similar to torch.rand, but allows for seeds to be set per row. - - seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. - If it is 3d, the additional seeds needed will be derived automatically - in a deterministic fashion: - [ - row 0: [columns_with_seed_0], [columns_with_seed0^1], ... - ] - """ - n_dims = len(size) - - if n_dims > 3: - raise ValueError("seeded_uniform only supports up to 3D tensors") - - if out is None: - out = torch.empty(*size, - dtype=dtype, - device=device, - pin_memory=pin_memory) - elif out.shape != size: - raise ValueError("shape of out and size must be the same") - - if n_dims == 3: - n_rows, n_3d, n_cols = out.shape - stride_row = out.stride(0) - stride_3d = out.stride(1) - elif n_dims == 2: - n_rows, n_cols = out.shape - n_3d = 1 - stride_row = out.stride(0) - stride_3d = 1 - else: - n_cols = out.shape[0] - n_rows = 1 - n_3d = 1 - stride_row = 1 - stride_3d = 1 - - if seeds.ndim != 1: - raise ValueError("seeds must be a 1D tensor") - - if seeds.numel() != n_rows: - raise ValueError( - "seeds must have the same number of elements as out has rows") - - # The philox PRNG Triton uses generates 4 random numbers at once. - # Therefore, the most efficient use of it is to divide the - # block size by 4, and then save the generated random numbers to - # each of the 4 slices of the tensor. - full_block_size = triton.next_power_of_2(n_cols) - philox_block_size = max(full_block_size // 4, 1) - n_slices = full_block_size // philox_block_size - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. 
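# Editorial aside (not part of the patch): seeded_uniform, being deleted here,
# exists to make per-row random draws reproducible from per-row seeds. The
# semantics (not the performance) can be emulated in plain PyTorch with one
# generator per row; a hedged reference sketch:

import torch


def seeded_uniform_reference(n_rows: int, n_cols: int,
                             seeds: torch.Tensor) -> torch.Tensor:
    out = torch.empty(n_rows, n_cols)
    for row, seed in enumerate(seeds.tolist()):
        # Keep the seed non-negative for Generator.manual_seed.
        gen = torch.Generator().manual_seed(int(seed) % (2**63))
        out[row] = torch.rand(n_cols, generator=gen)
    return out


example_seeds = torch.tensor([1, 2, 1])
ref = seeded_uniform_reference(3, 4, example_seeds)
assert torch.equal(ref[0], ref[2])  # identical seeds give identical rows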
- if philox_block_size >= 8192: - num_warps = 32 - elif philox_block_size >= 4096: - num_warps = 16 - elif philox_block_size >= 2048: - num_warps = 8 - - _seeded_uniform_triton[(n_rows, n_3d)]( - out, - seeds, - stride_row, - stride_3d, - seeds.stride(0), - n_rows, - n_3d, - n_cols, - n_slices=n_slices, - num_warps=num_warps, - block_size=philox_block_size, - ) - return out - - -@triton.jit -def _seeded_uniform_triton( - out_ptr: torch.Tensor, - seed_ptr: torch.Tensor, - out_row_stride: int, - out_3d_stride: int, - seed_row_stride: int, - n_rows: int, - n_3d: int, - n_cols: int, - n_slices: tl.constexpr, - block_size: tl.constexpr, -): - """ - Generate a random float32 number in [0, 1) for each element in the output - tensor. The random numbers in a row generated using the seed for that row. - - Args: - out_ptr: The output tensor. - seed_ptr: The per-row seeds to use for random number generation. - out_row_stride: The stride between rows of the output tensor. - out_3d_stride: The stride between 3D slices of the output tensor. - seed_row_stride: The stride between rows of the seed tensor. - n_rows: The number of rows in the output tensor. - n_3d: The size of second dimension of the output tensor, - if output tensor is 3D. - n_cols: The number of columns in the output tensor. - n_slices: The number of philox outputs to use. - """ - tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") - - # Get the row index. - row_idx = tl.program_id(axis=0) - three_d_idx = tl.program_id(axis=1) - - philox_offsets = tl.arange(0, block_size) - # Get the seed for the current element. - seed = tl.load(seed_ptr + row_idx * seed_row_stride) - if three_d_idx > 0: - seed ^= three_d_idx - # Generate random numbers in [0, 1). - out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) - - output_row_start_ptr = (out_ptr + row_idx * out_row_stride + - three_d_idx * out_3d_stride) - out1_offsets = philox_offsets - tl.store(output_row_start_ptr + out1_offsets, - out1, - mask=out1_offsets < n_cols) - if n_slices > 1: - out2_offsets = tl.arange(block_size, block_size * 2) - tl.store(output_row_start_ptr + out2_offsets, - out2, - mask=out2_offsets < n_cols) - if n_slices > 2: - out3_offsets = tl.arange(block_size * 2, block_size * 3) - tl.store(output_row_start_ptr + out3_offsets, - out3, - mask=out3_offsets < n_cols) - if n_slices > 3: - out4_offsets = tl.arange(block_size * 3, block_size * 4) - tl.store(output_row_start_ptr + out4_offsets, - out4, - mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py deleted file mode 100644 index fb88a05daf482..0000000000000 --- a/vllm/model_executor/layers/ops/sample.py +++ /dev/null @@ -1,394 +0,0 @@ -from typing import Optional, Tuple - -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.triton_utils.sample import get_num_triton_sampler_splits - -_EPS: tl.constexpr = 1e-6 - - -def _multi_split_sample( - probs: torch.Tensor, - seeds: torch.Tensor, - n_splits: int, - sampled_tokens_size: Tuple[int, int], - sampled_logprobs_size: Tuple[int, int], - sample_indices: torch.Tensor, - logprobs: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, -): - """Sample tokens where vocab size is split into multiple parts - (too large for Triton otherwise).""" - assert seeds.ndim == 2 and seeds.shape[0] == n_splits - split_probs = probs.tensor_split(n_splits, 1) - split_logprobs = 
logprobs.tensor_split(n_splits, 1) - sampled_tokens_tmp = [ - torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) - for _ in range(n_splits) - ] - sampled_logprobs_tmp = [ - torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - # We are purposefuly using sampled_tokens_size as we need to always - # save modified probs in this case. - sampled_modified_probs_tmp = [ - torch.empty(sampled_tokens_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - for i in range(n_splits): - n_samples = sample_indices.shape[0] - n_cols = split_probs[i].shape[1] - n_best = sampled_tokens_tmp[i].shape[1] - uniform_noise = seeded_uniform(n_samples, - n_best, - n_cols, - seeds=seeds[i].flatten(), - device=split_probs[i].device, - dtype=split_probs[i].dtype) - # TODO(yard1): See if we can remove the contiguous() calls. - # Will need kernel support. - _sample( - split_probs[i].contiguous(), - split_logprobs[i].contiguous(), - sample_indices, - sampled_tokens_tmp[i], - sampled_logprobs_tmp[i], - sampled_modified_probs_tmp[i], - seeds[i], - uniform_noise, - modify_greedy_probs=False, - save_logprobs=save_logprobs, - save_modified_probs=True, - ) - if i > 0: - # Add offset to sampled tokens - sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) - sampled_tokens = torch.stack(sampled_tokens_tmp) - sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) - # Reduce the results from the splits. - sampled_modified_probs, indices = torch.max(sampled_modified_probs, - dim=0, - keepdim=True) - sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) - if save_logprobs: - sampled_logprobs = torch.stack(sampled_logprobs_tmp) - sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) - else: - sampled_logprobs = None - sampled_modified_probs = sampled_modified_probs.squeeze(0) - - if modify_greedy_probs: - # We need to modify the greedy probs for the sampled tokens. - # We can't do this in the kernel as we need to know the - # sampled tokens. - probs.fill_(0.0) - probs.scatter_(1, sampled_tokens, 1.0) - - return (sampled_tokens, sampled_logprobs, sampled_modified_probs) - - -def sample( - probs: torch.Tensor, - seeds: torch.Tensor, - *, - max_best_of: int = 1, - sample_indices: Optional[torch.Tensor] = None, - logprobs: Optional[torch.Tensor] = None, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, - _save_modified_probs: bool = False, # pylint: disable=invalid-name -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - """Sample tokens from probs. with per-sequence seeds. - - Can sample from a subset of sequences through sample_indices. - - Args: - probs: Probabilities to sample from. - shape = [batch_size, vocab_size] - seeds: Per-sequence seed values. - shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] - max_best_of: Number of samples to generate per sequence. - Sequence seed will be incremented by 1 each time. - sample_indices: Indices of sequences to sample from. - If not provided, will sample from all sequences. - shape = [n] - logprobs: Log-probabilities of the sampled tokens. - Only used for saving the logprobs if save_logprobs is True. - shape = [batch_size, vocab_size] - modify_greedy_probs: Whether to modify the greedy probabilities - for speculative sampling (sampled token = 1.0, - everything else = 0.0). - save_logprobs: Whether to save the log-probabilities of the - sampled tokens to a tensor. 
- _save_modified_probs: Whether to save the modified probabilities - (including gumbel noise) of the sampled tokens to a tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - This is exposed only for testing. - - Returns: - sampled_tokens: shape = [n, max_best_of] - sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None - sampled_modified_probs: shape = [n, max_best_of] - if save_modified_probs else None - """ - if sample_indices is None: - sample_indices = torch.arange(0, probs.shape[0], device=probs.device) - - sampled_tokens_size = (sample_indices.size(0), max_best_of) - if save_logprobs: - if logprobs is None: - raise ValueError( - "logprobs tensor must be provided if save_logprobs is True") - sampled_logprobs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_logprobs_size = (0, 0) - logprobs = probs - - assert logprobs is not None - if _save_modified_probs: - sampled_modified_probs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_modified_probs_size = (0, 0) - - # If the number of columns in probs is too large for Triton to handle, - # we split the tensor and sample from each split separately, and then - # do an argmax+gather to combine the results. - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if n_splits > 1: - (sampled_tokens, sampled_logprobs, - sampled_modified_probs) = _multi_split_sample( - probs, - seeds, - n_splits, - sampled_tokens_size, - sampled_logprobs_size, - sample_indices, - logprobs=logprobs, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs) - else: - sampled_tokens = torch.empty(sampled_tokens_size, - dtype=torch.long, - device=probs.device) - sampled_logprobs = torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) - sampled_modified_probs = torch.empty(sampled_modified_probs_size, - dtype=probs.dtype, - device=probs.device) - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - uniform_noise = seeded_uniform(n_samples, - max_best_of, - n_cols, - seeds=seeds.flatten(), - device=probs.device, - dtype=probs.dtype) - - _sample( - probs, - logprobs, - sample_indices, - sampled_tokens, - sampled_logprobs, - sampled_modified_probs, - seeds, - uniform_noise, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=_save_modified_probs, - ) - return (sampled_tokens, sampled_logprobs if save_logprobs else None, - sampled_modified_probs if _save_modified_probs else None) - - -def _sample(probs: torch.Tensor, - logprobs: torch.Tensor, - sample_indices: torch.Tensor, - output_samples: torch.Tensor, - output_logprobs: torch.Tensor, - output_modified_probs: torch.Tensor, - seeds: torch.Tensor, - uniform_noise: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = True, - save_modified_probs: bool = False) -> torch.Tensor: - """Sample tokens from probs. - - Args: - probs [batch_size, vocab_size]: probs to sample from. - logprobs [batch_size, vocab_size]: logprobs (used when - save_logprobsis True). - sample_indices [n]: Indices of the samples to use for each row of probs. - output_samples [n, n_best]: Output tensor to store samples in. - output_logprobs [n, n_best]: Output tensor to store logprobs in. - output_modified_probs [n, n_best]: Output tensor to store - probs of chosen tokens in (modified with noise). 
- seeds [n]: Seeds to use for sampling. If the seed is 0, we use - greedy sampling. Note this is ONLY used for determining - whether to use random sampling or not. The actual random - noise should be passed as uniform_noise. - uniform_noise [batch_size, n_best, vocab_size]: Uniform - noise to use for random sampling (will be converted - to exponential gumbel noise by the kernel). - modify_greedy_probs: If True, we modify the probs tensor in-place - to encode the sampling method used for each row. This is used - in speculative decoding. Only applies in greedy decoding. - save_logprobs: If True, we save the logprobs of the sampled tokens - in the output_logprobs tensor. - save_modified_probs: If True, we save the modified probs (with noise) - of the sampled tokens in the output_modified_probs tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - """ - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 - - # The block size is the smallest power of two greater than the number of - # columns in probs - block_size = triton.next_power_of_2(n_cols) - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. - if block_size >= 8192: - num_warps = 32 - elif block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - - # Enqueue kernel. The 1D launch grid is simple: we have one kernel - # instance per row of the probs matrix - _sample_triton[(n_samples, n_best)]( - sample_indices, - output_samples, - output_logprobs, - output_modified_probs, - probs, - logprobs, - seeds, - uniform_noise, - output_samples.stride(0), - probs.stride(0), - uniform_noise.stride(0), - uniform_noise.stride(1) if n_best > 1 else 1, - n_samples, - n_cols, - n_best, - num_warps=num_warps, - block_size=block_size, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=save_modified_probs, - ) - return output_samples, output_logprobs, output_modified_probs - - -@triton.jit -def _uniform_to_exponential(uniform_noise): - """Convert uniform samples to exponential samples.""" - # tl.rand returns values in [0, 1), so we clamp lower bound - # to _EPS to avoid log(0) and thus division by 0 later - lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) - uniform_noise = tl.maximum(uniform_noise, lb) - # Use the inversion method to turn uniform samples - # into exponential samples - exponential_noise = -tl.log(uniform_noise) - return exponential_noise - - -@triton.jit -def _sample_triton( - sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, - output_logprobs_ptr: torch.Tensor, - output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, - logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, - uniform_noise_ptr: torch.Tensor, output_row_stride: int, - probs_row_stride: int, uniform_noise_row_stride: int, - uniform_noise_best_stride: int, n_samples: int, n_cols: int, - n_best: int, block_size: tl.constexpr, - modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, - save_modified_probs: tl.constexpr): - # The rows are independent, so we parallelize across those - sample_idx = tl.program_id(0) - best_idx = tl.program_id(1) - - # Load the row index from DRAM - row_idx = tl.load(sample_indices_ptr + sample_idx) - seed = tl.load(seeds_ptr + sample_idx) - uses_random_sampling = seed != 0 - - # 
The stride represents how much we need to increase the - # pointer to advance 1 row - row_start_ptr = probs_ptr + row_idx * probs_row_stride - - # The block size is the next power of two greater than n_cols, - # so we can fit each row in a single block - col_offsets = tl.arange(0, block_size) - - # Load the row into SRAM, using a mask since block_size may be > than n_cols - row = tl.load(row_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=float("-inf")) - - if uses_random_sampling: - uniform_noise_start_ptr = (uniform_noise_ptr + - sample_idx * uniform_noise_row_stride + - best_idx * uniform_noise_best_stride) - uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=0.5) - exponential_noise = _uniform_to_exponential(uniform_noise) - row /= exponential_noise - - sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) - # clamp sampled token to n_cols - 1 - # this should not be necessary, but we do it - # just in case - if sampled_token >= n_cols: - sampled_token = n_cols - 1 - # Write back output to DRAM - output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + - best_idx) - tl.store(output_row_start_ptr, sampled_token) - - if modify_greedy_probs: # noqa - if not uses_random_sampling: - # Set the probability of the sampled token to 1, all other - # tokens to zero. This is used in speculative decoding where - # the sampling method must be encoded within the sampled - # probability distributions. - row = tl.where(col_offsets == sampled_token, 1.0, 0.0) - tl.store(row_start_ptr + col_offsets, - row, - mask=col_offsets < n_cols) - - if save_modified_probs: - output_row_start_ptr = (output_modified_probs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_value) - - if save_logprobs: - # Load the row into SRAM, using a mask since block_size - # may be > than n_cols - sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + - sampled_token) - # Write back output to DRAM - output_row_start_ptr = (output_logprobs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c00da106734ae..487f5a3d2a441 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,12 +10,6 @@ import torch import torch.nn as nn -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import sample as sample_triton - import vllm.envs as envs from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, @@ -23,6 +17,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -740,7 +735,7 @@ def _sample_with_torch( ) -> SampleReturnType: '''Torch-oriented _sample() implementation. - Single-step scheduling: + Single-step scheduling: * Perform GPU-side sampling computation * Immediately Pythonize sampling result @@ -777,7 +772,7 @@ def _sample_with_torch( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. 
for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] + sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -863,88 +858,6 @@ def _sample_with_torch( ) -def _sample_with_triton_kernel( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, -) -> SampleResultType: - categorized_seq_group_ids: Dict[SamplingType, - List[int]] = {t: [] - for t in SamplingType} - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata: Dict[SamplingType, - Tuple[List[int], List[SequenceGroupToSample], - torch.Tensor, torch.Tensor]] = {} - max_best_of_in_batch = 1 - - # Counterintiutively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] - sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups, - sample_indices, - sampled_token_indices) - if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, - SamplingType.RANDOM_SEED): - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_best_of_in_batch = max(max_best_of_in_batch, - sampling_params.best_of) - elif sampling_type == SamplingType.BEAM: - beam_search_logprobs = logprobs[sample_indices] - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - sampled_tokens, _, _ = sample_triton( - probs=probs, - seeds=sampling_tensors.sampling_seeds, - max_best_of=max_best_of_in_batch, - sample_indices=sampling_tensors.sample_indices, - logprobs=logprobs, - # don't save logprobs because we have logic for that below - # TODO: use this instead of the CPU-based logic below - save_logprobs=False, - ) - - # GPU<->CPU sync happens in the loop below. - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups, sample_indices, - sampled_token_indices) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample( - seq_groups, sampled_tokens[sampled_token_indices][:, 0]) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample( - seq_groups, sampled_tokens[sampled_token_indices]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - return sample_results - - def _sample( probs: torch.Tensor, logprobs: torch.Tensor, @@ -974,10 +887,6 @@ def _sample( modify_greedy_probs=modify_greedy_probs, ) - # TODO: Enable once Triton kernel & associated code is faster. 
- # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, - # sampling_tensors) - def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index a085779bc61a7..97d36d31f2b11 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,4 +1,3 @@ -import random from array import array from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -8,15 +7,10 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) -from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) + is_pin_memory_available, make_tensor_with_pad) _SAMPLING_EPS = 1e-5 -_SEED_0_REPLACEMENT = 3403598558 -# Some triton sampler related code is guarded before it is ready. -_USE_TRITON_SAMPLER = False @dataclass @@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int): generator=None, is_prompt=True, prompt_logprob_indices=[], - sample_indices=[]) + sample_indices=[], + ) class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations - """ + """Used to cache SamplingMetadata objects between scheduler iterations""" def __init__(self): self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} @@ -124,12 +118,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling + reuse_sampling_tensors: Indicates if we want to reuse sampling tensors that are part of the sampler forward pass. Currently, it is mainly used for multi-step decode. - + """ def __init__( @@ -165,16 +159,19 @@ def prepare( num_prompts, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, device, generators, cache) - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) + selected_token_indices = async_tensor_h2d( + selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory, + ) categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory), 2, 2) + t: async_tensor_h2d( + seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory, + ) for t, seq_ids in categorized_sample_indices.items() } @@ -201,8 +198,8 @@ def _prepare_seq_groups( device: str, generators: Optional[Dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, -) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ - SamplingType, List[Tuple[int, int]]], int]: +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType, + List[int]], int, ]: """Prepare sequence groups and indices for sampling. 
Args: @@ -233,16 +230,13 @@ def _prepare_seq_groups( # Sampling type -> ( # indices to sample/prompt logprob within pruned output logits, # indices to sample within pruned logits) - categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = { + categorized_sample_indices: Dict[SamplingType, List[int]] = { t: [] for t in SamplingType } # Index of logits to compute logprob. Logits include both prompt logprob # and sample logprob indices. logit_idx = 0 - # Index to sample from a sample tensor. It is used by triton sample kernel. - # See `_sample_with_triton_kernel` for more details. - sample_idx = 0 # Total number of prompts from given sequence groups. num_prompts = 0 @@ -264,10 +258,10 @@ def _prepare_seq_groups( # If the current seq group is in decode stage, it is None. seq_len: Optional[int] = None query_len: Optional[int] = None - prompt_logprob_indices: List[int] = \ - sample_obj.prompt_logprob_indices if cache is not None else [] - sample_indices: List[int] = \ - sample_obj.sample_indices if cache is not None else [] + prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices + if cache is not None else []) + sample_indices: List[int] = (sample_obj.sample_indices + if cache is not None else []) do_sample = seq_group_metadata.do_sample if seq_group_metadata.is_prompt: @@ -333,11 +327,8 @@ def sample(logits): if do_sample: sample_indices.extend(range(logit_idx, logit_idx + sample_len)) categorized_sample_indices[sampling_params.sampling_type].extend( - list( - zip(range(logit_idx, logit_idx + sample_len), - range(sample_idx, sample_idx + sample_len)))) + list(range(logit_idx, logit_idx + sample_len))) logit_idx += sample_len - sample_idx += sample_len if cache is not None: sample_obj.sampling_params = sampling_params @@ -356,7 +347,8 @@ def sample(logits): generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices)) + sample_indices=list(sample_indices), + ) seq_groups.append(sample_obj) @@ -378,9 +370,6 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor - sampling_seeds: torch.Tensor - sample_indices: torch.Tensor - extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @@ -391,15 +380,7 @@ def from_sampling_metadata( vocab_size: int, device: torch.device, dtype: torch.dtype, - *, - extra_seeds_to_generate: int = 0, - extra_entropy: Optional[Tuple[int, ...]] = None ) -> Tuple["SamplingTensors", bool, bool, bool]: - """ - extra_seeds_to_generate: extra seeds to generate using the - user-defined seed for each sequence. - extra_entropy: extra entropy to use when generating seeds. - """ prompt_tokens: List[array] = [] output_tokens: List[array] = [] top_ks: List[int] = [] @@ -409,19 +390,10 @@ def from_sampling_metadata( presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] - sampling_seeds: List[int] = [] - sample_indices: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False - if _USE_TRITON_SAMPLER: - prompt_best_of: List[int] = [] - - # We need one base seed per Triton slice. 
- seeds_to_generate = (extra_seeds_to_generate + - get_num_triton_sampler_splits(vocab_size)) - assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -452,7 +424,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (is_prompt and sampling_params.prompt_logprobs is not None): + if is_prompt and sampling_params.prompt_logprobs is not None: # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -477,28 +449,6 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - if _USE_TRITON_SAMPLER: - if is_prompt: - prompt_best_of.append(sampling_params.best_of) - query_len = seq_group.query_len - assert query_len is not None - - seed = sampling_params.seed - is_greedy = sampling_params.sampling_type == SamplingType.GREEDY - - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - extra_entropy = extra_entropy or () - seq_seeds = cls._get_sequence_seeds( - seed, - seq_data.get_len(), - *extra_entropy, - seq_id, - seeds_to_generate=seeds_to_generate, - is_greedy=is_greedy) - sampling_seeds.append(seq_seeds) - sample_indices.extend(seq_group.sample_indices) - if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -518,23 +468,37 @@ def from_sampling_metadata( output_tokens.append(seq_data.output_token_ids_array) sampling_tensors = SamplingTensors.from_lists( - temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, sampling_seeds, - sample_indices, prompt_tokens, output_tokens, vocab_size, - extra_seeds_to_generate, device, dtype) + temperatures, + top_ps, + top_ks, + min_ps, + presence_penalties, + frequency_penalties, + repetition_penalties, + prompt_tokens, + output_tokens, + vocab_size, + device, + dtype, + ) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod - def from_lists(cls, temperatures: List[float], top_ps: List[float], - top_ks: List[int], min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - sampling_seeds: List[int], sample_indices: List[int], - prompt_tokens: List[array], output_tokens: List[array], - vocab_size: int, extra_seeds_to_generate: int, - device: torch.device, - dtype: torch.dtype) -> "SamplingTensors": + def from_lists( + cls, + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], + min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[array], + output_tokens: List[array], + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() @@ -603,34 +567,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.int, pin_memory=pin_memory, ) - sample_indices_t = torch.tensor( - sample_indices, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - # need to transpose and make contiguous to - # copy the tensor correctly. - # [batch_size, n_seeds] -> [n_seeds, batch_size] - sampling_seeds_t = torch.tensor( - sampling_seeds, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ).t().contiguous() - # Because the memory is pinned, we can do non-blocking # transfer to device. 
- # How many seeds the sample operation itself will need. - num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate - sampling_seeds_gpu = sampling_seeds_t.to(device=device, - non_blocking=True) - extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] - if not extra_seeds_gpu.numel(): - extra_seeds_gpu = None - sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] - return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -644,38 +583,4 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), prompt_tokens=prompt_t.to(device=device, non_blocking=True), output_tokens=output_t.to(device=device, non_blocking=True), - sampling_seeds=sampling_seeds_gpu, - sample_indices=sample_indices_t.to(device=device, - non_blocking=True), - extra_seeds=extra_seeds_gpu, ) - - @staticmethod - def _get_sequence_seeds( - seed: int, - *extra_entropy: int, - seeds_to_generate: int, - is_greedy: bool, - ): - """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" - if not is_greedy: - if seed is None: - randint_fn = random.randint - else: - generator = random.Random(str((seed, ) + extra_entropy)) - randint_fn = generator.randint - lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max - # If the user/random sets seed = 0 but request should - # have sampling, we need to change it to something - # else. We use a constant in that case. - # This way we don't need to create and load a bool - # matrix in the sampling kernel, which reduces CPU - # overhead and latency. - seq_seeds = [ - randint_fn(lo, hi) or _SEED_0_REPLACEMENT - for _ in range(seeds_to_generate) - ] - else: - # For the kernel, seed == 0 means greedy decoding. - seq_seeds = [0] * seeds_to_generate - return seq_seeds diff --git a/vllm/triton_utils/sample.py b/vllm/triton_utils/sample.py deleted file mode 100644 index 401e4d28a3c99..0000000000000 --- a/vllm/triton_utils/sample.py +++ /dev/null @@ -1,13 +0,0 @@ -import math - -# This is a hardcoded limit in Triton (max block size). -MAX_TRITON_N_COLS = 131072 - - -def get_num_triton_sampler_splits(n_cols: int) -> int: - """Get the number of splits to use for Triton sampling. - - Triton has a limit on the number of columns it can handle, so we need to - split the tensor and call the kernel multiple times if it's too large. - """ - return math.ceil(n_cols / MAX_TRITON_N_COLS) diff --git a/vllm/utils.py b/vllm/utils.py index 014fc16a17c1f..1cbd9d55c68b3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -270,7 +270,7 @@ def clear(self): class PyObjectCache: - """Used to cache python objects to avoid object allocations + """Used to cache python objects to avoid object allocations across scheduler iterations. """ @@ -289,7 +289,7 @@ def _grow_cache(self): self._obj_cache.append(self._obj_builder()) def get_object(self): - """Returns a pre-allocated cached object. If there is not enough + """Returns a pre-allocated cached object. If there is not enough objects, then the cache size will double. 
""" if self._index >= len(self._obj_cache): @@ -837,15 +837,6 @@ def async_tensor_h2d( return t.to(device=target_device, non_blocking=True) -def maybe_expand_dim(tensor: torch.Tensor, - target_dims: int, - size: int = 1) -> torch.Tensor: - """Expand the tensor to the target_dims.""" - if tensor.ndim < target_dims: - tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) - return tensor - - def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" return torch.tensor([], dtype=dtype).element_size() @@ -1070,7 +1061,7 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. - + This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" @@ -1136,10 +1127,10 @@ def parse_args(self, args=None, namespace=None): def _pull_args_from_config(args: List[str]) -> List[str]: """Method to pull arguments specified in the config file into the command-line args variable. - - The arguments in config file will be inserted between + + The arguments in config file will be inserted between the argument list. - + example: ```yaml port: 12323 @@ -1150,21 +1141,21 @@ def _pull_args_from_config(args: List[str]) -> List[str]: --config config.yaml -tp 2 $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--config', 'config.yaml', + "facebook/opt-12B", + '--config', 'config.yaml', '-tp', '2' ] $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--port', '12323', - '--tensor-parallel-size', '4', + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', '-tp', '2' ] ``` Please note how the config args are inserted after the sub command. - this way the order of priorities is maintained when these are args + this way the order of priorities is maintained when these are args parsed by super(). 
""" assert args.count( @@ -1190,7 +1181,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: @staticmethod def _load_config_file(file_path: str) -> List[str]: - """Loads a yaml file and returns the key value pairs as a + """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml port: 12323 @@ -1201,7 +1192,7 @@ def _load_config_file(file_path: str) -> List[str]: '--port': '12323', '--tensor-parallel-size': '4' ] - + """ extension: str = file_path.split('.')[-1] From 1c1bb388e0d35a2d10da5c5cda2edac57bf62591 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 16 Sep 2024 22:17:32 -0600 Subject: [PATCH 012/116] [Frontend] Improve Nullable kv Arg Parsing (#8525) Signed-off-by: Alex-Brooks --- tests/engine/test_arg_utils.py | 20 +++++++++++++++++++- vllm/engine/arg_utils.py | 28 +++++++++++++++++++++------- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 3208d6bb48bdc..8dd200b35d0f3 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -1,6 +1,8 @@ +from argparse import ArgumentTypeError + import pytest -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import EngineArgs, nullable_kvs from vllm.utils import FlexibleArgumentParser @@ -13,6 +15,10 @@ "image": 16, "video": 2 }), + ("Image=16, Video=2", { + "image": 16, + "video": 2 + }), ]) def test_limit_mm_per_prompt_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) @@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected): args = parser.parse_args(["--limit-mm-per-prompt", arg]) assert args.limit_mm_per_prompt == expected + + +@pytest.mark.parametrize( + ("arg"), + [ + "image", # Missing = + "image=4,image=5", # Conflicting values + "image=video=4" # Too many = in tokenized arg + ]) +def test_bad_nullable_kvs(arg): + with pytest.raises(ArgumentTypeError): + nullable_kvs(arg) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b5eba9ca3727a..35013eedea9c6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,22 +44,36 @@ def nullable_str(val: str): def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. 
+ """ if len(val) == 0: return None out_dict: Dict[str, int] = {} for item in val.split(","): - try: - key, value = item.split("=") - except TypeError as exc: - msg = "Each item should be in the form KEY=VALUE" - raise ValueError(msg) from exc + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts try: - out_dict[key] = int(value) + parsed_value = int(value) except ValueError as exc: msg = f"Failed to parse value of item {key}={value}" - raise ValueError(msg) from exc + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value return out_dict From ee2bceaaa67bd2f420f62a924da5834a7c1c862b Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 16 Sep 2024 22:22:45 -0700 Subject: [PATCH 013/116] [Misc][Bugfix] Disable guided decoding for mistral tokenizer (#8521) --- .../guided_decoding/__init__.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 7161e83952a3d..f4fe8a7307c04 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -6,6 +6,7 @@ from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import MistralTokenizer async def get_guided_decoding_logits_processor( @@ -15,12 +16,23 @@ async def get_guided_decoding_logits_processor( request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'outlines' is currently not supported " + "for Mistral tokenizer. Please consider contributing to the " + "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'lm-format-enforcer' is currently not " + "supported for Mistral tokenizer. Please consider contributing " + "to the 'lm-format-enforcer' project if you are interested " + "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( @@ -37,12 +49,23 @@ def get_local_guided_decoding_logits_processor( # request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'outlines' is currently not supported " + "for Mistral tokenizer. 
Please consider contributing to the " + "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_options, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'lm-format-enforcer' is currently not " + "supported for Mistral tokenizer. Please consider contributing " + "to the 'lm-format-enforcer' project if you are interested " + "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( From 99aa4eddaf929f57dac405b00db3f5286624ee8b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Sep 2024 22:57:57 -0700 Subject: [PATCH 014/116] [torch.compile] register allreduce operations as custom ops (#8526) --- .buildkite/test-pipeline.yaml | 10 +- csrc/custom_all_reduce.cu | 12 -- csrc/ops.h | 2 - csrc/torch_bindings.cpp | 5 - tests/compile/__init__.py | 0 tests/compile/test_full_graph.py | 15 ++- vllm/_custom_ops.py | 6 - .../device_communicators/custom_all_reduce.py | 21 +++- vllm/distributed/parallel_state.py | 116 +++++++++++++++--- 9 files changed, 137 insertions(+), 50 deletions(-) create mode 100644 tests/compile/__init__.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9b0cb6663a55b..9483adcc5d587 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -163,13 +163,6 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: torch compile integration test - source_file_dependencies: - - vllm/ - commands: - - pytest -v -s ./compile/test_full_graph.py - - pytest -v -s ./compile/test_wrapper.py - - label: Prefix Caching Test # 7min #mirror_hardwares: [amd] source_file_dependencies: @@ -348,7 +341,10 @@ steps: - vllm/executor/ - vllm/model_executor/models/ - tests/distributed/ + - vllm/compilation commands: + - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus # Avoid importing model tests that cause CUDA reinitialization error diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82a3563979f16..9b82bec44c3c6 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink) { - auto inp_size = inp.numel() * inp.element_size(); - // custom allreduce requires input byte size to be multiples of 16 - if (inp_size % 16 != 0) return false; - if (!_is_weak_contiguous(inp)) return false; - if (world_size == 2 || full_nvlink) return inp_size <= max_size; - // for 4 or more non 
NVLink-capable GPUs, custom allreduce provides little - // performance improvement over NCCL. - return false; -} - void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { auto fa = reinterpret_cast(_fa); diff --git a/csrc/ops.h b/csrc/ops.h index 681ab4b898ca3..ee89ad32cb025 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d7f7547fbef55..7009180a8687c 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { "bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - custom_ar.def( - "should_custom_ar(Tensor inp, int max_size, int world_size, " - "bool full_nvlink) -> bool"); - custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); - custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); diff --git a/tests/compile/__init__.py b/tests/compile/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 5452ce6be8110..6fc445539bbbe 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -2,9 +2,20 @@ import pytest +from vllm.utils import cuda_device_count_stateless + +from ..utils import fork_new_process_for_each_test + @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_full_graph(model): +@pytest.mark.parametrize("tp_size", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model, tp_size): + + # Skip the test if there are not enough CUDA devices. 
+ if cuda_device_count_stateless() < tp_size: + pytest.skip("Not enough CUDA devices for the test.") + # make sure these models can be captured in full graph mode if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" @@ -17,7 +28,7 @@ def test_full_graph(model): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, enforce_eager=True) + llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size) outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d5b3d7bc6dd5a..ac90895b11c37 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, offsets, rank, full_nvlink) -def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, - full_nvlink: bool) -> bool: - return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, - full_nvlink) - - def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 6229f1d6ec788..d239d645edc14 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool: return True +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] @@ -224,8 +230,19 @@ def register_graph_buffers(self): ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - return ops.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6755b20eec9bb..1c864bcd5d708 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -21,11 +21,12 @@ """ import contextlib import pickle +import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch @@ -69,6 +70,58 @@ def _split_tensor_dict( return metadata_list, tensor_list +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. 
+ Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + # looks like Python 3.8 does not understand `ReferenceType` + _groups[group.unique_name] = weakref.ref(group) # type: ignore + + +@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"]) +def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce(tensor) + + +@inplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> None: + return + + +@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) +def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce(tensor) + + +@outplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + class GroupCoordinator: """ PyTorch ProcessGroup wrapper for a group of processes. @@ -111,7 +164,11 @@ def __init__( use_custom_allreduce: bool, use_tpu_communicator: bool, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) self.rank = torch.distributed.get_rank() self.local_rank = local_rank @@ -149,28 +206,24 @@ def __init__( from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) - self.pynccl_comm: Optional[PyNcclCommunicator] + self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, ) - else: - self.pynccl_comm = None - self.ca_comm: Optional[CustomAllreduce] + self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, ) - else: - self.ca_comm = None from vllm.distributed.device_communicators.tpu_communicator import ( TpuCommunicator) - self.tpu_communicator: Optional[TpuCommunicator] + self.tpu_communicator: Optional[TpuCommunicator] = None if use_tpu_communicator and self.world_size > 1: self.tpu_communicator = TpuCommunicator(group=self.cpu_group) @@ -264,16 +317,46 @@ def graph_capture( def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. 
So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if self.tpu_communicator is not None and \ + not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self._all_reduce(input_) + + if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_): + return torch.ops.vllm.outplace_all_reduce( + input_, group_name=self.unique_name) + else: + torch.ops.vllm.inplace_all_reduce(input_, + group_name=self.unique_name) + return input_ + + def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + The actual all-reduce implementation. + NOTE: This operation will be applied in-place or out-of-place. Always assume this function modifies its input, but use the return value as the output. """ ca_comm = self.ca_comm - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # For TPUs, use TPU communicator. tpu_comm = self.tpu_communicator if tpu_comm is not None and not tpu_comm.disabled: @@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int, use_pynccl=False, use_custom_allreduce=False, use_tpu_communicator=False, + group_name="world", ) @@ -767,6 +851,7 @@ def init_model_parallel_group( backend: str, use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE @@ -778,6 +863,7 @@ def init_model_parallel_group( use_custom_allreduce=use_custom_allreduce, use_tpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, ) @@ -931,7 +1017,8 @@ def initialize_model_parallel( _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_message_queue_broadcaster=True) + use_message_queue_broadcaster=True, + group_name="tp") # Build the pipeline model-parallel groups. 
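The all-reduce dispatch above relies on `torch.library.custom_op` plus a registered fake (meta) implementation so that Dynamo can trace calls to the collective without executing them. A self-contained sketch of that registration pattern, using a made-up `demo::inplace_scale` op rather than vLLM's all-reduce, and assuming PyTorch 2.4+ where this API is available:

```python
# Standalone sketch (not vLLM code): the custom-op + fake-impl pattern from the
# hunk above, applied to a trivial in-place operation. Requires PyTorch >= 2.4.
import torch


@torch.library.custom_op("demo::inplace_scale", mutates_args=["tensor"])
def inplace_scale(tensor: torch.Tensor, factor: float) -> None:
    # Real implementation: mutate the input in place, return nothing.
    tensor.mul_(factor)


@inplace_scale.register_fake
def _(tensor: torch.Tensor, factor: float) -> None:
    # Fake implementation: only metadata is needed; no computation happens.
    return


def scale_twice(x: torch.Tensor) -> torch.Tensor:
    torch.ops.demo.inplace_scale(x, 2.0)
    torch.ops.demo.inplace_scale(x, 2.0)
    return x


# The op works eagerly, and torch.compile can trace through it because Dynamo
# only needs the registered schema and the fake implementation.
x = torch.ones(4)
print(scale_twice(x))                            # tensor([4., 4., 4., 4.])
print(torch.compile(scale_twice)(torch.ones(4)))
```

vLLM's `inplace_all_reduce` / `outplace_all_reduce` ops follow the same shape, with the extra step of looking the group coordinator up by name, since `self` cannot be passed across the custom-op boundary.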
num_pipeline_model_parallel_groups: int = (world_size // @@ -947,7 +1034,8 @@ def initialize_model_parallel( _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_custom_allreduce=False) + use_custom_allreduce=False, + group_name="pp") def ensure_model_parallel_initialized( From cbdb25225914a04d94e8830f4e739faca8ff3b9d Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Tue, 17 Sep 2024 00:06:26 -0700 Subject: [PATCH 015/116] [Misc] Limit to ray[adag] 2.35 to avoid backward incompatible change (#8509) Signed-off-by: Rui Qiao --- requirements-test.txt | 2 +- vllm/executor/ray_gpu_executor.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 16a883b81ce50..10d463de27be5 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,7 +14,7 @@ librosa # required for audio test opencv-python # required for video test peft requests -ray[adag]>=2.35 +ray[adag]==2.35 sentence-transformers # required for embedding soundfile # required for audio test compressed-tensors==0.4.0 # required for compressed-tensors diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b124fe2e08ea6..9433dce842b09 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -437,8 +437,10 @@ def _check_ray_adag_installation(self): required_version = version.parse("2.35") current_version = version.parse( pkg_resources.get_distribution("ray").version) - if current_version < required_version: - raise ValueError(f"Ray version {required_version} or greater is " + # TODO: update the constraint once we adapt to the backward + # incompatible API change from ray 2.36 + if current_version != required_version: + raise ValueError(f"Ray version {required_version} is " f"required, but found {current_version}") import importlib.util From 1b6de8352b878348974b3f117cbb68ed18daa609 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 17 Sep 2024 15:34:27 +0800 Subject: [PATCH 016/116] [Benchmark] Support sample from HF datasets and image input for benchmark_serving (#8495) --- benchmarks/backend_request_func.py | 6 +- benchmarks/benchmark_serving.py | 239 +++++++++++++++++++++-------- 2 files changed, 177 insertions(+), 68 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 3243bb94f787c..3def4a6d67acf 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -25,6 +25,7 @@ class RequestFuncInput: best_of: int = 1 use_beam_search: bool = False logprobs: Optional[int] = None + multi_modal_content: Optional[dict] = None @dataclass @@ -312,12 +313,15 @@ async def async_request_openai_chat_completions( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) payload = { "model": request_func_input.model, "messages": [ { "role": "user", - "content": request_func_input.prompt, + "content": content }, ], "temperature": 0.0, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9ba3f649810b7..3ace910a6cac6 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -24,6 +24,8 @@ """ import argparse import asyncio +import base64 +import io import 
json import os import random @@ -31,11 +33,13 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) +from datasets import load_dataset +from PIL.Image import Image from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase @@ -84,7 +88,7 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int]]: +) -> List[Tuple[str, int, int, None]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") # Load the dataset. @@ -119,7 +123,7 @@ def sample_sharegpt_requests( if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((prompt, prompt_len, output_len, None)) return filtered_dataset @@ -131,7 +135,7 @@ def sample_sonnet_requests( output_len: int, prefix_len: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, str, int, int]]: +) -> List[Tuple[str, str, int, int, None]]: assert ( input_len > prefix_len ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." @@ -189,7 +193,65 @@ def sample_sonnet_requests( message, add_generation_prompt=True, tokenize=False) prompt_len = len(tokenizer(prompt_formatted).input_ids) sampled_requests.append( - (prompt, prompt_formatted, prompt_len, output_len)) + (prompt, prompt_formatted, prompt_len, output_len, None)) + + return sampled_requests + + +def sample_hf_requests( + dataset_path: str, + dataset_subset: str, + dataset_split: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + assert "conversations" in dataset.features, ( + "HF Dataset must have 'conversations' column.") + filtered_dataset = dataset.shuffle().filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in filtered_dataset: + if len(sampled_requests) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = data["conversations"][0]["value"] + prompt_token_ids = tokenizer(prompt).input_ids + completion = data["conversations"][1]["value"] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. 
+ continue + + if "image" in data and isinstance(data["image"], Image): + image: Image = data["image"] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + else: + mm_content = None + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) return sampled_requests @@ -223,8 +285,8 @@ def sample_random_requests( [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) - input_requests.append( - (prompt, int(prefix_len + input_lens[i]), int(output_lens[i]))) + input_requests.append((prompt, int(prefix_len + input_lens[i]), + int(output_lens[i]), None)) return input_requests @@ -343,7 +405,12 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0]) + if backend != "openai-chat" and test_mm_content is not None: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, prompt=test_prompt, @@ -353,6 +420,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -373,6 +441,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -385,7 +454,7 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request + prompt, prompt_len, output_len, mm_content = request request_func_input = RequestFuncInput( model=model_id, prompt=prompt, @@ -395,6 +464,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=mm_content, ) tasks.append( asyncio.create_task( @@ -575,6 +645,16 @@ def main(args: argparse.Namespace): for prompt, prompt_formatted, prompt_len, output_len in input_requests] + elif args.dataset_name == "hf": + input_requests = sample_hf_requests( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.hf_output_len, + ) + elif args.dataset_name == "random": input_requests = sample_random_requests( prefix_len=args.random_prefix_len, @@ -685,13 +765,14 @@ def main(args: argparse.Namespace): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random"], + choices=["sharegpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset.") + help="Path to the sharegpt/sonnet dataset. 
" + "Or the huggingface dataset ID if using HF dataset.") parser.add_argument( "--model", type=str, @@ -718,26 +799,6 @@ def main(args: argparse.Namespace): default=1000, help="Number of prompts to process.", ) - parser.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.") - parser.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) parser.add_argument( "--logprobs", type=int, @@ -748,42 +809,6 @@ def main(args: argparse.Namespace): "logprob is returned for each token; or (2) if beam search " "is enabled 1 logprob per token is computed"), ) - parser.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--random-input-len", - type=int, - default=1024, - help= - "Number of input tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-output-len", - type=int, - default=128, - help= - "Number of output tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-range-ratio", - type=float, - default=1.0, - help="Range of sampled ratio of input/output length, " - "used only for random sampling.", - ) - parser.add_argument( - "--random-prefix-len", - type=int, - default=0, - help="Number of fixed prefix tokens before random " - " context. The length range of context in a random " - " request is [random-prefix-len, " - " random-prefix-len + random-prefix-len * random-range-ratio).") parser.add_argument( "--request-rate", type=float, @@ -857,5 +882,85 @@ def main(args: argparse.Namespace): "Use \"--percentile-metrics\" to select metrics.", ) + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. 
Overrides the output length " + "from the ShareGPT dataset.") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for random sampling.", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + args = parser.parse_args() main(args) From 1009e93c5d634c724eeff3d4e453369337f502d4 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Tue, 17 Sep 2024 07:35:01 -0700 Subject: [PATCH 017/116] [Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631) --- .buildkite/test-pipeline.yaml | 7 + tests/encoder_decoder/__init__.py | 0 tests/encoder_decoder/test_e2e_correctness.py | 98 ++++++++++ .../test_encoder_decoder_model_runner.py | 182 +++++++++++++++--- vllm/attention/backends/abstract.py | 17 +- vllm/attention/backends/flashinfer.py | 12 +- vllm/attention/backends/utils.py | 113 ++++++++++- vllm/config.py | 41 +--- vllm/engine/arg_utils.py | 5 +- vllm/entrypoints/llm.py | 8 +- vllm/model_executor/models/bart.py | 6 +- vllm/utils.py | 5 - vllm/worker/enc_dec_model_runner.py | 43 ++++- vllm/worker/model_runner.py | 97 ++++++++-- vllm/worker/utils.py | 4 - 15 files changed, 526 insertions(+), 112 deletions(-) create mode 100644 tests/encoder_decoder/__init__.py create mode 100644 tests/encoder_decoder/test_e2e_correctness.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9483adcc5d587..63ce9bff7d4c1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -252,6 +252,13 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: Encoder Decoder tests # 5min + source_file_dependencies: + - vllm/ + - tests/encoder_decoder + commands: + - pytest -v -s encoder_decoder + - label: OpenAI-Compatible Tool Use # 20 min fast_check: false mirror_hardwares: [ amd ] diff --git a/tests/encoder_decoder/__init__.py b/tests/encoder_decoder/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py new file mode 100644 index 0000000000000..9324a737a779c --- /dev/null +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -0,0 +1,98 @@ +"""E2E tests to verify the correctness of the encoder-decoder framework + +Run `pytest 
tests/encoder_decoder/test_e2e_correctness.py`. +""" +from typing import List, Optional, Tuple + +import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.sequence import SampleLogprobs +from vllm.utils import is_cpu + +from ..conftest import DecoderPromptType +from ..models.utils import check_logprobs_close + + +def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "" + hf_output_str + + return output_ids, hf_output_str, out_logprobs + + +@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.skipif( + is_cpu(), + reason="CPU backend is not currently supported with encoder/decoder models" +) +def test_encoder_decoder_e2e( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + decoder_prompt_type: DecoderPromptType, + enforce_eager: bool, +) -> None: + ''' + End-to-End (E2E) test for the encoder-decoder framework. + This test evaluates the encoder-decoder functionality using the BART + model. We compare the outputs of the Hugging Face and vLLM + implementations to ensure that both implementations produce consistent + and correct results. + ''' + test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + with vllm_runner(model, dtype=dtype, + enforce_eager=enforce_eager) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) + + hf_skip_tokens = (1 + if decoder_prompt_type == DecoderPromptType.NONE else 0) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 32bff22f66a8b..a00d46ddeb007 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,3 +1,4 @@ +import itertools from array import array from typing import List @@ -7,13 +8,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, SequenceData, SequenceGroupMetadata) -from vllm.utils import is_cpu +from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner - -# CUDA graph scenarios to test -# -# 
Currently CUDA graph is not supported -ENFORCE_EAGER = [True] +from vllm.worker.model_runner import _get_graph_batch_size BATCH_SIZES = [1, 4, 16, 64, 256] @@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args, reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_empty_seq_group(enforce_eager, ): +def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output for empty seq group list""" @@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ): max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( @@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ): "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_prepare_prompt( - batch_size, - enforce_eager, -): +def test_prepare_prompt(batch_size): ''' Test the ability of the encoder/decoder model runner subclass to produce prefill-phase model inputs & attention metadata. @@ -115,7 +107,7 @@ def test_prepare_prompt( max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_lens: List[int] = [] @@ -281,11 +273,7 @@ def test_prepare_prompt( "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_prepare_decode( - batch_size, - enforce_eager, -): +def test_prepare_decode(batch_size): ''' Test the ability of the encoder/decoder model runner subclass to produce decode-phase model inputs & attention metadata. @@ -311,7 +299,7 @@ def test_prepare_decode( max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_lens: List[int] = [] @@ -428,7 +416,8 @@ def test_prepare_decode( expected, ) - # Cuda graph should is currently not supported for encoder/decoer. + # Model runner's CUDAGraph setting should be propagated to attention + # metadata. assert attn_metadata.use_cuda_graph is False # Verify the lengths of input tokens & positions @@ -484,3 +473,152 @@ def test_prepare_decode( dtype=actual.dtype, ) assert torch.equal(actual, expected) + + +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +def test_prepare_decode_cuda_graph(batch_size): + """ + Tests that for encoder-decoder models with CUDA Graph capture and replay + enabled, the tensors used during the decode phase are correctly padded + for varying input batch sizes. 
+ """ + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=False, + ) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + + model_input = model_runner.prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = attn_metadata.slot_mapping + encoder_input_tokens = model_input.encoder_input_tokens + encoder_input_positions = model_input.encoder_input_positions + cross_slot_mapping = attn_metadata.cross_slot_mapping + + # With CUDA Graph capture and replay enabled, the decoder and encoder + # input sequences will be padded. Create the expected padded tensors + # accordingly. + graph_batch_size = _get_graph_batch_size(batch_size) + cuda_graph_pad_size = graph_batch_size - batch_size + padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) + padded_encoder_seq_lens = encoder_seq_lens + list( + itertools.repeat(1, cuda_graph_pad_size)) + + assert return_seq_lens == padded_seq_lens + assert len(slot_mapping) == len(input_tokens) + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify attention metadata + device = model_runner.device + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_decode_tokens > 0 + assert torch.equal( + attn_metadata.seq_lens_tensor, + torch.tensor(padded_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == padded_seq_lens + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(seq_lens) + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens + assert torch.equal( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens) + + # Verify block tables are correct for prompts + # - Decoder self-attention. Pad the block tables as expected. + expected = [block_tables[0] for _ in range(batch_size)] + expected.extend([[] for _ in range(cuda_graph_pad_size)]) + expected = make_tensor_with_pad( + expected, + max_len=64, + pad=0, + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.equal( + attn_metadata.block_tables, + expected, + ) + # - Encoder/decoder cross-attention. 
Pad the cross-attention block tables + # as expected. + expected = [cross_block_table for _ in range(len(seq_group_metadata_list))] + expected.extend([[] for _ in range(cuda_graph_pad_size)]) + expected = make_tensor_with_pad( + expected, + max_len=64, + pad=0, + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.equal( + attn_metadata.cross_block_tables, + expected, + ) + + # Model runner's CUDAGraph setting should be propagated to attention + # metadata. + assert attn_metadata.use_cuda_graph is True + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == len(padded_seq_lens) + assert len(input_positions) == len(padded_seq_lens) + # -- An indirect check that model_input.input_tokens + # and model_input.input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + input_tokens, + input_positions, + ) + # - Encoder + assert len(encoder_input_tokens) == 0 + assert len(encoder_input_tokens) == 0 + # -- An indirect check that model_input.encoder_input_tokens + # and model_input.encoder_input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + encoder_input_tokens, + encoder_input_positions, + ) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index adc8390e6f9ec..2bc36ff18a96b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -156,18 +156,27 @@ def graph_clone(self, batch_size: int) -> "AttentionState[T]": ... @abstractmethod - def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T: + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: """Get attention metadata for CUDA graph capture of batch_size.""" ... @abstractmethod - def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]: + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: """Get attention-specific input buffers for CUDA graph capture.""" ... @abstractmethod - def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any], - attn_metadata: T) -> None: + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: """In-place modify input buffers dict for CUDA graph replay.""" ... 
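Note on the padding scheme exercised by test_prepare_decode_cuda_graph above: a decode batch is rounded up to the nearest pre-captured graph batch size, and the extra slots are filled with dummy sequences (sequence length 1, empty block tables) so a fixed-shape CUDA graph can be replayed unchanged. The sketch below is a minimal, self-contained illustration of that idea; the rounding rule and helper names are assumptions for illustration, not vLLM's exact _get_graph_batch_size implementation.

    from typing import List, Tuple

    def round_up_to_graph_batch_size(batch_size: int) -> int:
        # Assumed capture schedule: 1, 2, 4, then multiples of 8 (illustrative).
        if batch_size <= 2:
            return batch_size
        if batch_size <= 4:
            return 4
        return (batch_size + 7) // 8 * 8

    def pad_decode_batch(
        seq_lens: List[int],
        block_tables: List[List[int]],
    ) -> Tuple[List[int], List[List[int]]]:
        # Pad a decode batch so it matches a pre-captured CUDA graph size.
        graph_batch_size = round_up_to_graph_batch_size(len(seq_lens))
        pad = graph_batch_size - len(seq_lens)
        # Dummy entries: length-1 sequences with no blocks, mirroring the test.
        return seq_lens + [1] * pad, block_tables + [[] for _ in range(pad)]

    # A batch of 5 decode sequences is padded to the 8-slot captured graph.
    padded_lens, padded_tables = pad_decode_batch([3, 7, 2, 9, 4], [[1]] * 5)
    assert len(padded_lens) == 8 and padded_lens[5:] == [1, 1, 1]

The encoder-side tensors (encoder_seq_lens, cross_block_tables) are padded the same way, which is what the assertions on padded_encoder_seq_lens in the test above verify.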
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4054d337316fe..3a602fbfbbc04 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -172,7 +172,8 @@ def graph_clone(self, batch_size: int): state._prefill_wrapper = self._get_prefill_wrapper() return state - def graph_capture_get_metadata_for_batch(self, batch_size: int): + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): assert self._is_graph_capturing _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] @@ -232,12 +233,17 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): attn_metadata.begin_forward() return attn_metadata - def get_graph_input_buffers(self, attn_metadata): + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): return { "slot_mapping": attn_metadata.slot_mapping, } - def prepare_graph_input_buffers(self, input_buffers, attn_metadata): + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): return def begin_forward(self, model_input): diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 0375d3488eb15..089008967a244 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -304,7 +304,8 @@ def graph_clone(self, batch_size: int) -> "CommonAttentionState": assert self._is_graph_capturing return self.__class__(self.runner) - def graph_capture_get_metadata_for_batch(self, batch_size: int): + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): assert self._is_graph_capturing attn_metadata = self.runner.attn_backend.make_metadata( num_prefills=0, @@ -322,21 +323,121 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): block_tables=self._graph_block_tables[:batch_size], use_cuda_graph=True, ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + return attn_metadata - def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]: - return { + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { "slot_mapping": attn_metadata.slot_mapping, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } - - def prepare_graph_input_buffers(self, input_buffers, - attn_metadata) -> None: + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. 
+ assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) def begin_forward(self, model_input) -> None: return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. + attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + + def _add_additonal_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. 
+ """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) diff --git a/vllm/config.py b/vllm/config.py index 89cffc8b306b2..a0991597d0673 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,9 +16,8 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config) -from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, - cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_openvino, is_xpu, +from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, + is_cpu, is_hip, is_neuron, is_openvino, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -96,15 +95,15 @@ class ModelConfig: enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. - If None, the user did not specify, so default to False - - except for encoder/decoder models, which currently require - eager mode. + If None, the user did not specify, so default to False. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. disable_sliding_window: Whether to disable sliding window. If True, we will disable the sliding window functionality of the model. If the model does not support sliding window, this argument is @@ -186,32 +185,8 @@ def __init__(self, self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - # Choose a default enforce_eager value if the user did not specify - # a value (enforce_eager is None) - if getattr(self.hf_config, 'is_encoder_decoder', False): - if self.enforce_eager is None: - # *Only for encoder/decoder models* and - # *only if enforce_eager is unset*, override - # to enforce_eager=True - # - # Add a logger message since it is *somewhat* non-intuitive that - # enforce_eager is True when the user has not specified its - # value. - logger.info("Forcing enforce_eager == True because " - "enforce_eager setting was unspecified and " - "CUDAGraph is not supported with encoder/ " - "decoder models.") - self.enforce_eager = True - - if not self.enforce_eager: - # Eager mode explicitly disabled by user for an encoder/ - # decoder model; however CUDAGRAPH + encoder/decoder is - # not currently supported - raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH) - elif self.enforce_eager is None: - # *Only for decoder-only models*, enforce_eager - # defaults to False if unset. This is intuitive - # so no logging message needed. + # Set enforce_eager to False if the value is unset. 
+ if self.enforce_eager is None: self.enforce_eager = False if (not self.disable_sliding_window diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 35013eedea9c6..4139eca9c1832 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -472,7 +472,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.max_seq_len_to_capture, help='Maximum sequence length covered by CUDA ' 'graphs. When a sequence has context length ' - 'larger than this, we fall back to eager mode.') + 'larger than this, we fall back to eager mode. ' + 'Additionally for encoder-decoder models, if the ' + 'sequence length of the encoder input is larger ' + 'than this, we fall back to the eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', default=EngineArgs.disable_custom_all_reduce, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c01bffeb4289d..a26b721093521 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -88,7 +88,9 @@ class LLM: to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. disable_custom_all_reduce: See ParallelConfig **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) @@ -137,9 +139,7 @@ def __init__( LLM constructor. Note: if enforce_eager is unset (enforce_eager is None) - it defaults to False for decoder-only models and True - for encoder/decoder models, since encoder/decoder models - do not currently support CUDAGraph. + it defaults to False. 
''' if "disable_log_stats" not in kwargs: diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 9b4c4be7fcb09..cbdacf779b089 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -848,11 +848,13 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, ) -> torch.Tensor: r""" Args: diff --git a/vllm/utils.py b/vllm/utils.py index 1cbd9d55c68b3..29b8a8c2907eb 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -71,10 +71,6 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not " - "currently supported with encoder/" - "decoder models.") - STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " "currently supported with encoder/" "decoder models.") @@ -98,7 +94,6 @@ "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, - "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d6189d82d51d9..09dab0135f390 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,4 +1,5 @@ import dataclasses +import itertools from typing import Any, Dict, List, Optional, Tuple, Type, cast import torch @@ -24,7 +25,8 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + ModelInputForGPUWithSamplingMetadata, + _get_graph_batch_size) from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -178,7 +180,15 @@ def execute_model( raise ValueError("num_steps > 1 is not supported in " "EncoderDecoderModelRunner") - model_executable = self.model + if (model_input.attn_metadata is not None + and model_input.attn_metadata.prefill_metadata is None + and model_input.attn_metadata.decode_metadata.use_cuda_graph): + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[ + model_input.virtual_engine][graph_batch_size] + else: + model_executable = self.model seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -200,6 +210,9 @@ def execute_model( if not self.is_driver_worker: return [] + if model_input.async_callback is not None: + model_input.async_callback() + # Sample the next token. 
output: SamplerOutput = self.model.sample( logits=logits, @@ -231,14 +244,12 @@ def prepare_model_input( """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) - ( attn_metadata, encoder_input_tokens_tensor, encoder_input_positions_tensor, ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, model_input)) - # Inject attn_metadata encoder/cross-attention fields & # encoder input tokens/positions into model_input. # Frozen dataclass fields cannot be modified, so use @@ -437,11 +448,29 @@ def _prepare_encoder_model_input_tensors( cross_block_tables.append([] if ( cross_block_table is None) else cross_block_table) - # Convert cross-attention block tables to encoder input tensor + if (model_input.attn_metadata is not None + and model_input.attn_metadata.use_cuda_graph): + # We will be using CUDA graph replay for this decode. + max_len_of_block_table = self.get_max_block_per_batch() + batch_size = len(encoder_seq_lens) + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + # extend the cross_block_tables and encoder_seq_lens to match + # the graph_batch_size. + cross_block_tables.extend([[] + for _ in range(cuda_graph_pad_size) + ]) + encoder_seq_lens.extend( + itertools.repeat(1, cuda_graph_pad_size)) + + else: + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + cross_block_tables = make_tensor_with_pad( cross_block_tables, - max_len=max( - len(block_table) for block_table in cross_block_tables), + max_len=max_len_of_block_table, pad=0, dtype=torch.int32, device=self.device, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9df9ae783b9fa..e8c472df8b5fc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -243,6 +243,7 @@ def __init__( prefix_cache_hit: bool = False, reinit: bool = False, reinit_use_defaults: bool = False, + encoder_seq_len: int = 0, ): if reinit: assert len(self.seq_ids) == len(seq_ids) # type: ignore @@ -256,6 +257,7 @@ def __init__( self.block_tables = block_tables self.computed_block_nums = computed_block_nums self.n_seqs = n_seqs + self.encoder_seq_len = encoder_seq_len if reinit: if len(self.seq_ids) == 1 and reinit_use_defaults: @@ -702,6 +704,11 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): assert n_seqs == 1 self.decode_only = False + encoder_seq_len = 0 + + if self.runner.model_config.is_encoder_decoder_model: + encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() + inter_data = self.init_cached_inter_data( request_id=seq_group_metadata.request_id, seq_ids=seq_ids, @@ -709,7 +716,8 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): block_tables=seq_group_metadata.block_tables, computed_block_nums=seq_group_metadata.computed_block_nums, reinit=True, - reinit_use_defaults=True) + reinit_use_defaults=True, + encoder_seq_len=encoder_seq_len) self.inter_data_list.append(inter_data) @@ -719,11 +727,15 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): for per_seq_group_fn in self.per_seq_group_compute_fns: per_seq_group_fn(inter_data, seq_group_metadata) - def _use_captured_graph(self, batch_size: int, - max_decode_seq_len: int) -> bool: + def _use_captured_graph(self, + batch_size: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> bool: return (self.decode_only and not self.runner.model_config.enforce_eager - and batch_size <= 
self.runner.max_batchsize_to_capture - and max_decode_seq_len <= self.runner.max_seq_len_to_capture) + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.runner.max_seq_len_to_capture + and max_encoder_seq_len <= self.runner.max_seq_len_to_capture + and batch_size <= self.runner.max_batchsize_to_capture) def build(self) -> ModelInputForGPU: """Finalize the builder intermediate data and @@ -763,15 +775,18 @@ def build(self) -> ModelInputForGPU: input_positions.extend(cur_input_positions) seq_lens = [] + query_lens = [] max_decode_seq_len = 0 + max_encoder_seq_len = 0 for inter_data in self.inter_data_list: seq_lens.extend(inter_data.seq_lens) + query_lens.extend(inter_data.query_lens) if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) - query_lens = [] - for inter_data in self.inter_data_list: - query_lens.extend(inter_data.query_lens) + if self.runner.model_config.is_encoder_decoder_model: + max_encoder_seq_len = max(max_encoder_seq_len, + inter_data.encoder_seq_len) # Mapping from request IDs to sequence IDs. Used for Jamba models # that manages the cache by itself. @@ -781,8 +796,10 @@ def build(self) -> ModelInputForGPU: } batch_size = len(input_tokens) - use_captured_graph = self._use_captured_graph(batch_size, - max_decode_seq_len) + use_captured_graph = self._use_captured_graph( + batch_size, + max_decode_seq_len, + max_encoder_seq_len=max_encoder_seq_len) # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. @@ -1364,7 +1381,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: for batch_size in reversed(batch_size_capture_list): attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( - batch_size)) + batch_size, + is_encoder_decoder_model=self.model_config. + is_encoder_decoder_model)) if self.lora_config: lora_mapping = LoRAMapping( @@ -1380,10 +1399,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: ) self.set_active_prompt_adapters( set(), prompt_adapter_mapping) - graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name(), - self.attn_state.graph_clone(batch_size)) + self.attn_state.graph_clone(batch_size), + self.model_config.is_encoder_decoder_model) capture_inputs = { "input_ids": @@ -1420,6 +1439,12 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.model.get_seqlen_agnostic_capture_inputs( batch_size) }) + if self.model_config.is_encoder_decoder_model: + # add the additional inputs to capture for + # encoder-decoder models. + self._update_inputs_to_capture_for_enc_dec_model( + capture_inputs) + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1430,6 +1455,24 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # This usually takes < 10 seconds. logger.info("Graph capturing finished in %.0f secs.", elapsed_time) + def _update_inputs_to_capture_for_enc_dec_model(self, + capture_inputs: Dict[str, + Any]): + """ + Updates the set of input tensors needed for CUDA graph capture in an + encoder-decoder model. + + This method modifies the provided `capture_inputs` dictionary by + adding tensors specific to encoder-decoder specific models that + need to be captured for CUDA Graph replay. + """ + # During the decode phase encoder_input_ids and encoder_positions are + # unset. Do the same thing for graph capture. 
+ capture_inputs["encoder_input_ids"] = torch.tensor( + [], dtype=torch.long).cuda() + capture_inputs["encoder_positions"] = torch.tensor( + [], dtype=torch.long).cuda() + @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() @@ -1629,7 +1672,7 @@ def execute_model( class CUDAGraphRunner: def __init__(self, model: nn.Module, backend_name: str, - attn_state: AttentionState): + attn_state: AttentionState, is_encoder_decoder_model: bool): self.model = model self.backend_name = backend_name self.attn_state = attn_state @@ -1638,6 +1681,7 @@ def __init__(self, model: nn.Module, backend_name: str, self.output_buffers: Dict[str, torch.Tensor] = {} self._graph: Optional[torch.cuda.CUDAGraph] = None + self._is_encoder_decoder_model = is_encoder_decoder_model @property def graph(self): @@ -1671,8 +1715,9 @@ def capture( intermediate_tensors=intermediate_inputs, **kwargs, ) + # Wait for the warm up operations to finish before proceeding with + # Graph Capture. torch.cuda.synchronize() - # Capture the graph. self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): @@ -1704,10 +1749,14 @@ def capture( # Save the input and output buffers. self.input_buffers = { - "input_ids": input_ids, - "positions": positions, - "kv_caches": kv_caches, - **self.attn_state.get_graph_input_buffers(attn_metadata), + "input_ids": + input_ids, + "positions": + positions, + "kv_caches": + kv_caches, + **self.attn_state.get_graph_input_buffers( + attn_metadata, self._is_encoder_decoder_model), **kwargs, } if intermediate_inputs is not None: @@ -1737,8 +1786,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.attn_state.prepare_graph_input_buffers(self.input_buffers, - attn_metadata) + self.attn_state.prepare_graph_input_buffers( + self.input_buffers, attn_metadata, self._is_encoder_decoder_model) if "seqlen_agnostic_capture_inputs" in self.input_buffers: self.model.copy_inputs_before_cuda_graphs(self.input_buffers, **kwargs) @@ -1752,6 +1801,12 @@ def forward( if key != "model_execute_time" and key != "model_forward_time": self.input_buffers[key].copy_(intermediate_tensors[key], non_blocking=True) + if self._is_encoder_decoder_model: + self.input_buffers["encoder_input_ids"].copy_( + kwargs['encoder_input_ids'], non_blocking=True) + self.input_buffers["encoder_positions"].copy_( + kwargs['encoder_positions'], non_blocking=True) + # Run the graph. self.graph.replay() # Return the output tensor. 
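Note on the CUDAGraphRunner changes above: CUDA graph replay requires every tensor the graph touches to live in a fixed buffer that was captured once, so at run time fresh data (now including encoder_input_ids and encoder_positions for encoder-decoder models) is copied into those buffers before graph.replay() instead of being passed as new tensors. Below is a minimal sketch of that capture-once/replay pattern, assuming a CUDA device is available; it shows the underlying PyTorch mechanism, not vLLM's actual runner.

    import torch

    def capture_and_replay_demo() -> torch.Tensor:
        assert torch.cuda.is_available(), "CUDA graphs require a GPU"
        # Static buffers: the captured graph records their addresses, so they
        # must be reused (copied into) rather than replaced.
        static_x = torch.zeros(8, device="cuda")
        static_out = torch.zeros(8, device="cuda")

        # Warm up before capture so one-time CUDA initialization is not recorded.
        static_out.copy_(static_x * 2 + 1)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out.copy_(static_x * 2 + 1)

        # Replay with new inputs: copy into the captured buffer, then replay.
        static_x.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
        graph.replay()
        torch.cuda.synchronize()
        return static_out.clone()  # tensor([1., 3., 5., ..., 15.])

This is the same reason forward() above copies encoder_input_ids / encoder_positions into self.input_buffers with copy_(..., non_blocking=True) rather than handing the captured graph new tensors.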
diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index d73023e8e1724..a58b80e4f2adb 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -47,10 +47,6 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) - if not enc_dec_mr.model_config.enforce_eager: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH']) - if enc_dec_mr.prompt_adapter_config is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) From 9855b99502c7537db5ef018129e603650800ac46 Mon Sep 17 00:00:00 2001 From: chenqianfzh <51831990+chenqianfzh@users.noreply.github.com> Date: Tue, 17 Sep 2024 08:09:12 -0700 Subject: [PATCH 018/116] [Feature][kernel] tensor parallelism with bitsandbytes quantization (#8434) --- tests/quantization/test_bitsandbytes.py | 26 ++++++++++--- vllm/config.py | 6 --- vllm/model_executor/layers/linear.py | 21 ++++++++--- vllm/model_executor/model_loader/loader.py | 44 +++++++++++++++++++++- 4 files changed, 80 insertions(+), 17 deletions(-) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 87200b1dcc534..36167cf95f589 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name) +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason='Test requires at least 2 GPUs.') +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_test) +@fork_new_process_for_each_test +def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = {"load_in_4bit": True} + validate_generated_texts(hf_runner, + vllm_runner, + example_prompts[:1], + model_name, + hf_model_kwargs, + vllm_tp_size=2) + + def log_generated_texts(prompts, outputs, runner_name): logged_texts = [] for i, (_, generated_text) in enumerate(outputs): @@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner, vllm_runner, prompts, model_name, - hf_model_kwargs=None): + hf_model_kwargs=None, + vllm_tp_size=1): # NOTE: run vLLM first, as it requires a clean process # when using distributed inference - - #Run with vLLM runner with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', + tensor_parallel_size=vllm_tp_size, enforce_eager=True, gpu_memory_utilization=0.8) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") # Clean up the GPU memory for the next test - torch.cuda.synchronize() gc.collect() torch.cuda.empty_cache() @@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner, hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") # Clean up the GPU memory for the next test - torch.cuda.synchronize() gc.collect() torch.cuda.empty_cache() diff --git a/vllm/config.py b/vllm/config.py index a0991597d0673..6c24d15640e99 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -393,12 +393,6 @@ def verify_with_parallel_config( "Pipeline parallelism is only supported for the following " f" architectures: {_PP_SUPPORTED_MODELS}.") - if self.quantization == "bitsandbytes" and ( - parallel_config.tensor_parallel_size > 1 - or parallel_config.pipeline_parallel_size > 1): - raise 
ValueError( - "BitAndBytes quantization with TP or PP is not supported yet.") - # Remove the constraint after the bitsandbytes issue is fixed: # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 if self.quantization == "bitsandbytes" and self.enforce_eager is False: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index cea768469aeb8..568892778abe2 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -530,8 +530,11 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -899,8 +902,13 @@ def weight_loader(self, else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -1000,6 +1008,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) # Special case for GGUF is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -1015,7 +1024,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data - if input_dim is not None: + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if input_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index ac869e56ce198..fd9533ab156a5 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -22,6 +22,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -689,6 +691,8 @@ def save_model( class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" + # TODO: these module names are for Llama only, + # change so that it works with other models as well default_target_modules = [ "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj" @@ -911,13 +915,44 @@ def _parse_quant_state(param_name: str, def _unquantized_generator(self, 
hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if any(target_module in weight_name for target_module in self.target_modules): weight_name = weight_name.replace(".weight", ".qweight") + + # weight partitions of different modules occur at + # different dimensions + # TODO: these module names are for Llama only, + # change so that it works with other models as well + if 'down_proj' in weight_name or 'o_proj' in weight_name: + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] + # bitsandbytes requires data in GPU - loaded_weight = weight_tensor.cuda().data + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + with set_default_torch_dtype(torch.float32): processed_weight, quant_state = quantize_4bit( loaded_weight, @@ -958,6 +993,13 @@ def _load_weights(self, model_config: ModelConfig, f"BitsAndBytes loader does not support {quant_method} " "quantization") + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with TP is not supported." 
+                " Please try with PP.")
+
         load_8bit = False
         if pre_quant:
             load_8bit = quant_config.get('load_in_8bit', False)

From a54ed8024953dc6b59906072a7a89cd4791ec4f0 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 17 Sep 2024 19:50:37 +0200
Subject: [PATCH 019/116] [Model] Add mistral function calling format to all
 models loaded with "mistral" format (#8515)

Co-authored-by: Cyrus Leung
---
 examples/offline_chat_with_tools.py           | 138 ++++++++++++++++++
 .../decoder_only/language/test_mistral.py     |  67 +++++++++
 vllm/entrypoints/llm.py                       |   6 +-
 vllm/entrypoints/openai/serving_chat.py       |   9 +-
 vllm/transformers_utils/tokenizers/mistral.py |   8 +-
 5 files changed, 219 insertions(+), 9 deletions(-)
 create mode 100644 examples/offline_chat_with_tools.py

diff --git a/examples/offline_chat_with_tools.py b/examples/offline_chat_with_tools.py
new file mode 100644
index 0000000000000..e69a6c067e4da
--- /dev/null
+++ b/examples/offline_chat_with_tools.py
@@ -0,0 +1,138 @@
+# ruff: noqa
+import json
+import random
+import string
+
+from vllm import LLM
+from vllm.sampling_params import SamplingParams
+
+# This script is an offline demo for function calling
+#
+# If you want to run a server/client setup, please follow this code:
+#
+# - Server:
+#
+# ```bash
+# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
+# ```
+#
+# - Client:
+#
+# ```bash
+# curl --location 'http://:8000/v1/chat/completions' \
+# --header 'Content-Type: application/json' \
+# --header 'Authorization: Bearer token' \
+# --data '{
+#     "model": "mistralai/Mistral-7B-Instruct-v0.3",
+#     "messages": [
+#       {
+#         "role": "user",
+#         "content": [
+#             {"type" : "text", "text": "Describe this image in detail please."},
+#             {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
+#             {"type" : "text", "text": "and this one as well. Answer in French."},
+#             {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
+#         ]
+#       }
+#     ]
+#   }'
+# ```
+#
+# Usage:
+#    python demo.py simple
+#    python demo.py advanced
+
+model_name = "mistralai/Mistral-7B-Instruct-v0.3"
+# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
+# or "mistralai/Mistral-Large-Instruct-2407"
+# or any other mistral model with function calling ability
+
+sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
+llm = LLM(model=model_name,
+          tokenizer_mode="mistral",
+          config_format="mistral",
+          load_format="mistral")
+
+
+def generate_random_id(length=9):
+    characters = string.ascii_letters + string.digits
+    random_id = ''.join(random.choice(characters) for _ in range(length))
+    return random_id
+
+
+# simulate an API that can be called
+def get_current_weather(city: str, state: str, unit: 'str'):
+    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
+            "partly cloudy, with highs in the 90's.")
+
+
+tool_functions = {"get_current_weather": get_current_weather}
+
+tools = [{
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {
+                    "type":
+                    "string",
+                    "description":
+                    "The city to find the weather for, e.g. 'San Francisco'"
+                },
+                "state": {
+                    "type":
+                    "string",
+                    "description":
+                    "the two-letter abbreviation for the state that the city is"
+                    " in, e.g. 'CA' which would mean 'California'"
+                },
+                "unit": {
+                    "type": "string",
+                    "description": "The unit to fetch the temperature in",
+                    "enum": ["celsius", "fahrenheit"]
+                }
+            },
+            "required": ["city", "state", "unit"]
+        }
+    }
+}]
+
+messages = [{
+    "role":
+    "user",
+    "content":
+    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
+}]
+
+outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
+output = outputs[0].outputs[0].text.strip()
+
+# append the assistant message
+messages.append({
+    "role": "assistant",
+    "content": output,
+})
+
+# let's now actually parse and execute the model's output simulating an API call by using the
+# above defined function
+tool_calls = json.loads(output)
+tool_answers = [
+    tool_functions[call['name']](**call['arguments']) for call in tool_calls
+]
+
+# append the answer as a tool message and let the LLM give you an answer
+messages.append({
+    "role": "tool",
+    "content": "\n\n".join(tool_answers),
+    "tool_call_id": generate_random_id(),
+})
+
+outputs = llm.chat(messages, sampling_params, tools=tools)
+
+print(outputs[0].outputs[0].text.strip())
+# yields
+#   'The weather in Dallas, TX is 85 degrees fahrenheit. '
+#   'It is partly cloudy, with highs in the 90's.'
diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py
index 687ba6a03a691..26f90456849f1 100644
--- a/tests/models/decoder_only/language/test_mistral.py
+++ b/tests/models/decoder_only/language/test_mistral.py
@@ -4,13 +4,61 @@
 """
 import pytest
 
+from vllm import SamplingParams
+
 from ...utils import check_logprobs_close
 
 MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.1",
     "mistralai/Mistral-7B-Instruct-v0.3",
+    # Mistral-Nemo is too big for CI, but passes locally
+    # "mistralai/Mistral-Nemo-Instruct-2407"
 ]
 
+SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+
+# for function calling
+TOOLS = [{
+    "type": "function",
+    "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {
+                    "type":
+                    "string",
+                    "description":
+                    "The city to find the weather for, e.g. 'San Francisco'"
+                },
+                "state": {
+                    "type":
+                    "string",
+                    "description":
+                    "the two-letter abbreviation for the state that the city is"
+                    " in, e.g. 
'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] +MSGS = [{ + "role": + "user", + "content": ("Can you tell me what the temperate" + " will be in Dallas, in fahrenheit?") +}] +EXPECTED_FUNC_CALL = ( + '[{"name": "get_current_weather", "arguments": ' + '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]') + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -81,3 +129,22 @@ def test_mistral_format( name_0="hf", name_1="mistral", ) + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling +def test_mistral_function_calling( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, + dtype=dtype, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") as vllm_model: + outputs = vllm_model.model.chat(MSGS, + tools=TOOLS, + sampling_params=SAMPLING_PARAMS) + + assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a26b721093521..248b070611cd2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ from contextlib import contextmanager -from typing import ClassVar, List, Optional, Sequence, Union, cast, overload +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast, + overload) from tqdm import tqdm @@ -357,6 +358,7 @@ def chat( lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = True, + tools: Optional[List[Dict[str, Any]]] = None, ) -> List[RequestOutput]: """ Generate responses for a chat conversation. 
@@ -401,6 +403,7 @@ def chat( messages=messages, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) else: prompt = apply_hf_chat_template( @@ -408,6 +411,7 @@ def chat( conversation=conversation, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) inputs: PromptInputs diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 58e42fb5363fb..d28362a12abdb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -123,7 +123,8 @@ async def create_chat_completion( ] prompt: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: prompt = apply_mistral_chat_template( tokenizer, messages=request.messages, @@ -159,10 +160,10 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") - # "auto" tools requires --enable-auto-tool-choice - # and --tool-call-parser - if request.tool_choice == "auto" and not ( + if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( self.enable_auto_tools and self.tool_parser is not None): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( "\"auto\" tool choice requires " "--enable-auto-tool-choice and --tool-call-parser to be set") diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index ea1910ed20ec3..7a228a3efa6e8 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -165,10 +165,9 @@ def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], tools: Optional[Dict[str, Any]] = None, **kwargs) -> List[int]: - assert tools is None, "`tools` are not yet supported." 
- request = ChatCompletionRequest( - messages=messages) # type: ignore[type-var] + request = ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] encoded = self.mistral.encode_chat_completion(request) # encode-decode to get clean prompt @@ -176,7 +175,8 @@ def apply_chat_template(self, def convert_tokens_to_string(self, tokens: List[str]) -> str: if isinstance(self.tokenizer, Tekkenizer): - return "".join(tokens) + return "".join(t for t in tokens + if t not in self.tokenizer._all_special_tokens) else: return self.tokenizer.decode(tokens) # type: ignore[arg-type] From 56c3de018c35580fd088655c2f9951cd4da5335d Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 17 Sep 2024 20:24:29 +0100 Subject: [PATCH 020/116] [Misc] Don't dump contents of kvcache tensors on errors (#8527) --- vllm/worker/model_runner_base.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 94d2507968382..975b88c0e79a2 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -3,11 +3,13 @@ from abc import ABC, abstractmethod from datetime import datetime from functools import wraps -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Type, TypeVar) import torch +from torch import is_tensor +from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceGroupMetadata @@ -17,6 +19,8 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata +logger = init_logger(__name__) + T = TypeVar('T', bound="BroadcastableModelInput") @@ -113,6 +117,8 @@ def _wrapper(*args, **kwargs): except Exception as err: timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl" + logger.info("Writing input of failed execution to %s...", + filename) with open(filename, "wb") as filep: dumped_inputs = { k: v @@ -122,7 +128,19 @@ def _wrapper(*args, **kwargs): for i, arg in enumerate(args): if i not in (exclude_args or []): dumped_inputs[f"arg_{i}"] = arg + + # Only persist dtype and shape for kvcache tensors + # (can be way to big otherwise) + if (kv_caches := dumped_inputs.get("kv_caches")) \ + and isinstance(kv_caches, Iterable): + dumped_inputs["kv_caches"] = [(t.dtype, t.shape) + for t in kv_caches + if is_tensor(t)] + pickle.dump(dumped_inputs, filep) + logger.info( + "Completed writing input of failed execution to %s.", + filename) raise type(err)( f"Error in model execution (input dumped to {filename}): " f"{str(err)}") from err From 98f9713399bd602ff954a83e6e6abcb4cf8b8864 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 17 Sep 2024 17:17:08 -0600 Subject: [PATCH 021/116] [Bugfix] Fix TP > 1 for new granite (#8544) Signed-off-by: Joe Runde --- vllm/model_executor/models/granite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index b0325e8b616c8..5f365bbc30670 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -428,7 +428,8 @@ def compute_logits( sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, 
sampling_metadata) - logits /= self.config.logits_scaling + if logits is not None: + logits /= self.config.logits_scaling return logits def sample( From fa0c114fad4e2b807503e78d5110558cfee92ba4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 17 Sep 2024 16:24:06 -0700 Subject: [PATCH 022/116] [doc] improve installation doc (#8550) Co-authored-by: Andy Dai <76841985+Imss27@users.noreply.github.com> --- docs/source/getting_started/installation.rst | 2 ++ tests/compile/test_full_graph.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 50a761b49490c..0322503a89a56 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -95,6 +95,8 @@ You can also build and install vLLM from source: $ export MAX_JOBS=6 $ pip install -e . + This is especially useful when you are building on less powerful machines. For example, when you use WSL, it only `gives you half of the memory by default `_, and you'd better use ``export MAX_JOBS=1`` to avoid compiling multiple files simultaneously and running out of memory. The side effect is that the build process will be much slower. If you only touch the Python code, slow compilation is okay, as you are building in an editable mode: you can just change the code and run the Python script without any re-compilation or re-installation. + .. tip:: If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image. diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 6fc445539bbbe..2e309aaa58d48 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -28,7 +28,10 @@ def test_full_graph(model, tp_size): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size) + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tp_size, + disable_custom_all_reduce=True) outputs = llm.generate(prompts, sampling_params) From 09deb4721f830602d0417604c7e18b7e384f9594 Mon Sep 17 00:00:00 2001 From: "Alexey Kondratiev(AMD)" <143633163+alexeykondrat@users.noreply.github.com> Date: Tue, 17 Sep 2024 19:40:29 -0400 Subject: [PATCH 023/116] [CI/Build] Excluding kernels/test_gguf.py from ROCm (#8520) --- .buildkite/run-amd-test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 6659440135ff4..9274a30e04325 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -83,6 +83,7 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_encoder_decoder_attn.py \ --ignore=kernels/test_flash_attn.py \ --ignore=kernels/test_flashinfer.py \ + --ignore=kernels/test_gguf.py \ --ignore=kernels/test_int8_quant.py \ --ignore=kernels/test_machete_gemm.py \ --ignore=kernels/test_mamba_ssm.py \ From 8110e44529f431d54b02060528601c0d3e3f7d02 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 17 Sep 2024 19:44:27 -0400 Subject: [PATCH 024/116] [Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (#8012) --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 30 +++++++++- csrc/mamba/causal_conv1d/causal_conv1d.h | 4 ++ csrc/ops.h | 9 ++- csrc/torch_bindings.cpp | 5 +- tests/kernels/test_causal_conv1d.py | 58 +++++++++++++++++++ vllm/_custom_ops.py | 14 +++-- .../layers/mamba/ops/causal_conv1d.py | 10 +++- 7 files changed, 114 
insertions(+), 16 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 88a64a8ece585..32261ec17d897 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, - bool silu_activation) { + bool silu_activation, + const c10::optional &conv_state_indices_) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x, const int width = weight.size(-1); CHECK_SHAPE(x, batch_size, dim); - CHECK_SHAPE(conv_state, batch_size, dim, width); CHECK_SHAPE(weight, dim, width); TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); @@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x, params.conv_state_c_stride = conv_state.stride(1); params.conv_state_l_stride = conv_state.stride(2); + if (conv_state_indices_.has_value()) { + auto conv_state_indices = conv_state_indices_.value(); + TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) + TORCH_CHECK(conv_state_indices.is_cuda()); + TORCH_CHECK(conv_state_indices.stride(0) == 1) + CHECK_SHAPE(conv_state_indices, batch_size); + + int conv_state_entries = conv_state.size(0); + CHECK_SHAPE(conv_state, conv_state_entries, dim, width); + + params.conv_state_indices_ptr = conv_state_indices.data_ptr(); + } else { + CHECK_SHAPE(conv_state, batch_size, dim, width); + params.conv_state_indices_ptr = nullptr; + } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)x.get_device()}; @@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int channel_id = blockIdx.y * kNThreads + tidx; input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor + // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. + const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr + ? batch_id + : params.conv_state_indices_ptr[batch_id]; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + channel_id * params.out_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index bb25314c8bbbd..32a7d83c09b8d 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -36,6 +36,10 @@ struct ConvParamsBase { void *__restrict__ conv_state_ptr; + // For the continuous batching case. Makes it so that the mamba state for + // the current batch doesn't need to be a contiguous tensor. 
+ int32_t *__restrict__ conv_state_indices_ptr; + void *__restrict__ seq_idx_ptr; // No __restrict__ since initial_states could be the same as final_states. diff --git a/csrc/ops.h b/csrc/ops.h index ee89ad32cb025..15e9ebe87408a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -222,11 +222,10 @@ std::vector selective_scan_fwd( const c10::optional& index_, const c10::optional& x); -at::Tensor causal_conv1d_update(const at::Tensor& x, - const at::Tensor& conv_state, - const at::Tensor& weight, - const c10::optional& bias_, - bool silu_activation); +at::Tensor causal_conv1d_update( + const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, + const c10::optional& bias, bool silu_activation, + const c10::optional& conv_state_indices); at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7009180a8687c..045203c3de8a8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "causal_conv1d_update(Tensor! x," "Tensor! conv_state," "Tensor! weight," - "Tensor? bias_," - "bool silu_activation) -> Tensor"); + "Tensor? bias," + "bool silu_activation," + "Tensor? conv_state_indices) -> Tensor"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 7bf338b36953a..344e07e739454 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("seqlen", [1, 4, 5]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, + silu_activation, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + # set seed + torch.random.manual_seed(0) + batch = 64 + + x = torch.randn(batch, dim, device=device, dtype=itype) + + total_entries = 10 * batch + conv_state = torch.randn(total_entries, + dim, + width, + device=device, + dtype=itype) + conv_state_indices = torch.randperm(total_entries)[:batch].to( + dtype=torch.int32, device=device) + + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state[conv_state_indices, :].detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) + assert torch.allclose(out, 
out_ref, rtol=rtol, atol=atol) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ac90895b11c37..ff5aa8bee3c27 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, silu_activation) -def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + conv_state_indices: Optional[torch.Tensor], +) -> torch.Tensor: return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation) + silu_activation, + conv_state_indices) def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 413c8bc227ae8..196d81267f32f 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py from typing import Optional @@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, - activation: Optional[str] = None): + activation: Optional[str] = None, + conv_state_indices: Optional[torch.Tensor] = None): """ x: (batch, dim) conv_state: (batch, dim, width) weight: (dim, width) bias: (dim,) + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
out: (batch, dim) """ @@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor, raise NotImplementedError("activation must be None, silu, or swish") activation_bool = activation in ["silu", "swish"] return ops.causal_conv1d_update(x, conv_state, weight, bias, - activation_bool) + activation_bool, conv_state_indices) From 95965d31b6ac2c9557816a6ffabe4a3117a5ccb2 Mon Sep 17 00:00:00 2001 From: Daniele <36171005+dtrifiro@users.noreply.github.com> Date: Wed, 18 Sep 2024 04:49:53 +0200 Subject: [PATCH 025/116] [CI/Build] fix Dockerfile.cpu on podman (#8540) --- Dockerfile.cpu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 34b4c95e34ffc..4d7289366296b 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -24,6 +24,8 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl +WORKDIR /workspace + ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ From e351572900f7d87e14fe203ea3a49c1c7ddae0d6 Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Wed, 18 Sep 2024 02:51:59 -0700 Subject: [PATCH 026/116] [Misc] Add argument to disable FastAPI docs (#8554) --- vllm/entrypoints/openai/api_server.py | 8 +++++++- vllm/entrypoints/openai/cli_args.py | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3d1d832986c1e..b891debfd2b91 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -417,7 +417,13 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest, def build_app(args: Namespace) -> FastAPI: - app = FastAPI(lifespan=lifespan) + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 7ccee0b6b55b7..bbb0823de9a51 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -190,6 +190,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'ID numbers being printed in log.' 
'\n\nDefault: Unlimited') + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" + ) + return parser From 6ffa3f314c59e42238f1c5f923ff2839e0af9698 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 18 Sep 2024 18:38:11 +0800 Subject: [PATCH 027/116] [CI/Build] Avoid CUDA initialization (#8534) --- benchmarks/kernels/benchmark_layernorm.py | 9 +-- benchmarks/kernels/benchmark_moe.py | 6 +- .../kernels/benchmark_paged_attention.py | 7 +-- benchmarks/kernels/benchmark_quant.py | 9 +-- benchmarks/kernels/benchmark_rope.py | 6 +- tests/kernels/test_activation.py | 9 +-- tests/kernels/test_attention.py | 18 ++---- tests/kernels/test_attention_selector.py | 2 +- tests/kernels/test_awq_triton.py | 5 +- tests/kernels/test_blocksparse_attention.py | 12 +--- tests/kernels/test_cache.py | 25 +++----- tests/kernels/test_causal_conv1d.py | 5 +- tests/kernels/test_cutlass.py | 11 ++-- tests/kernels/test_flash_attn.py | 5 +- tests/kernels/test_flashinfer.py | 10 +-- tests/kernels/test_fp8_quant.py | 10 ++- tests/kernels/test_gguf.py | 5 +- tests/kernels/test_int8_quant.py | 13 ++-- tests/kernels/test_layernorm.py | 5 +- tests/kernels/test_machete_gemm.py | 2 +- tests/kernels/test_mamba_ssm.py | 5 +- tests/kernels/test_moe.py | 3 +- tests/kernels/test_pos_encoding.py | 14 ++--- tests/kernels/test_prefix_prefill.py | 12 +--- tests/lora/test_layers.py | 5 +- tests/lora/test_punica_sizes.py | 18 ++---- tests/lora/test_punica_variation.py | 18 ++---- .../decoder_only/language/test_granite.py | 9 +-- tests/quantization/test_fp8.py | 4 +- tests/quantization/utils.py | 8 ++- vllm/attention/backends/rocm_flash_attn.py | 3 +- .../ops/blocksparse_attention/interface.py | 5 +- vllm/attention/ops/prefix_prefill.py | 3 +- vllm/attention/selector.py | 4 +- vllm/config.py | 12 ++-- vllm/distributed/parallel_state.py | 3 +- vllm/envs.py | 1 + .../compressed_tensors/compressed_tensors.py | 6 +- .../layers/quantization/fbgemm_fp8.py | 4 +- .../model_executor/layers/quantization/fp8.py | 5 +- .../layers/quantization/utils/marlin_utils.py | 10 +-- .../quantization/utils/marlin_utils_fp8.py | 3 +- .../layers/quantization/utils/w8a8_utils.py | 5 +- vllm/model_executor/model_loader/loader.py | 6 +- vllm/model_executor/models/qwen2_vl.py | 2 +- vllm/model_executor/utils.py | 10 +-- vllm/platforms/cpu.py | 8 +-- vllm/platforms/cuda.py | 17 ++--- vllm/platforms/interface.py | 62 ++++++++++++++++--- vllm/platforms/rocm.py | 14 ++--- vllm/platforms/tpu.py | 8 ++- vllm/prompt_adapter/utils.py | 4 +- vllm/usage/usage_lib.py | 3 +- vllm/utils.py | 28 ++++++--- vllm/worker/worker.py | 16 +++-- 55 files changed, 256 insertions(+), 256 deletions(-) diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 4947fda02e1cc..92f6053cc6d7e 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -1,10 +1,10 @@ -import random import time import torch from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -16,10 +16,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + 
seed_everything(seed) torch.set_default_device("cuda") layer = RMSNorm(hidden_size).to(dtype=dtype) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index fd233c71b10a6..c2ad98b7e2656 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -10,7 +10,7 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything class BenchmarkConfig(TypedDict): @@ -166,7 +166,7 @@ class BenchmarkWorker: def __init__(self, seed: int) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(seed) + seed_everything(seed) self.seed = seed def benchmark( @@ -180,7 +180,7 @@ def benchmark( use_fp8_w8a8: bool, use_int8_w8a16: bool, ) -> Tuple[Dict[str, int], float]: - torch.cuda.manual_seed_all(self.seed) + seed_everything(self.seed) dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a04433142da42..87864d038d593 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,7 +6,7 @@ from vllm import _custom_ops as ops from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) + create_kv_caches_with_random, seed_everything) NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -28,10 +28,7 @@ def main( device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 4c1a7b26213a5..743a5744e8614 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -1,10 +1,10 @@ -import random import time import torch from vllm import _custom_ops as ops -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -17,10 +17,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device("cuda") x = torch.randn(num_tokens, hidden_size, dtype=dtype) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index f542684a9a2a9..73fc9e9dbf461 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,7 +6,7 @@ from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, get_rope) -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything def benchmark_rope_kernels_multi_lora( @@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git 
a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index ed050ce851535..9b476585fa19e 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, NewGELU, QuickGELU, SiluAndMul) +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -34,9 +35,7 @@ def test_act_and_mul( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, 2 * d, dtype=dtype) if activation == "silu": @@ -77,9 +76,7 @@ def test_activation( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, d, dtype=dtype) layer = activation[0]() diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 46831b506aff3..4bd6f7863a658 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,7 +6,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -139,10 +139,8 @@ def test_paged_attention( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -354,10 +352,7 @@ def test_paged_attention_rocm( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -506,10 +501,7 @@ def test_multi_query_kv_attention( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. 
# As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index a20a741c27f74..c1fb45955a0e5 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -45,7 +45,7 @@ def test_flash_attn(monkeypatch): override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + with patch("torch.cuda.get_device_capability", return_value=(7, 5)): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != STR_FLASH_ATTN_VAL diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py index 198d40a155ccb..e95e5bd948212 100644 --- a/tests/kernels/test_awq_triton.py +++ b/tests/kernels/test_awq_triton.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.quantization.awq_triton import ( AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) +from vllm.utils import seed_everything device = "cuda" @@ -79,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): zeros_cols = qweight_cols zeros_dtype = torch.int32 - torch.manual_seed(0) + seed_everything(0) qweight = torch.randint(0, torch.iinfo(torch.int32).max, @@ -133,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size): qzeros_rows = scales_rows qzeros_cols = qweight_cols - torch.manual_seed(0) + seed_everything(0) input = torch.rand((input_rows, input_cols), dtype=input_dtype, diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 7357508751ae1..f3bd8f0524264 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -172,10 +172,7 @@ def test_paged_attention( blocksparse_block_size: int, blocksparse_head_sliding_step: int, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. 
# As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 19402a337b8d6..b0e7097fdfbd4 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -6,6 +6,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops +from vllm.utils import seed_everything COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -55,10 +56,7 @@ def test_copy_blocks( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Generate random block mappings where each source block is mapped to two # destination blocks. @@ -134,10 +132,7 @@ def test_reshape_and_cache( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks @@ -229,9 +224,7 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. @@ -345,10 +338,8 @@ def test_swap_blocks( pytest.skip() if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) src_device = device if direction[0] == "cuda" else 'cpu' dst_device = device if direction[1] == "cuda" else 'cpu' @@ -417,9 +408,7 @@ def test_fp8_e4m3_conversion( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) low = -224.0 high = 224.0 diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 344e07e739454..043c4923bd660 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) +from vllm.utils import seed_everything def causal_conv1d_ref( @@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) if not channel_last: x = torch.randn(batch, 4096 + dim + 64, @@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch = 2 x = torch.randn(batch, dim, device=device, dtype=itype) conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index d1f0524f83c4c..cc4ca2e91e76f 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -15,9 +15,6 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -capability = current_platform.get_device_capability() -capability = capability[0] * 10 + capability[1] - 
def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) @@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool): @@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, out_dtype: Type[torch.dtype], @@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, use_bias: bool, device: str): @@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool, use_bias: bool): diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 870a8bf65eb92..8e960d098c408 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,6 +4,7 @@ import torch import vllm.attention.backends.flash_attn # noqa: F401 +from vllm.utils import seed_everything NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] @@ -87,7 +88,7 @@ def test_flash_attn_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -174,7 +175,7 @@ def test_varlen_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index 696cc0c6cdf10..80a388db6530e 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -4,6 +4,8 @@ import pytest import torch +from vllm.utils import seed_everything + NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] @@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv( soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") - 
torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( head_size: int, dtype: torch.dtype, block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( ) -> None: # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index bae9b39203ff9..49f5ce53aab54 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,6 +5,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, @@ -24,8 +25,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans @@ -49,8 +49,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -67,8 +66,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() @pytest.mark.parametrize("seed", SEEDS) def test_fp8_quant_large(seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings hidden_size = 1152 # Smallest hidden_size to reproduce the error diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py index ee29ed93b61fc..1513fc196153c 100644 --- a/tests/kernels/test_gguf.py +++ b/tests/kernels/test_gguf.py @@ -7,6 +7,7 @@ from huggingface_hub import snapshot_download import vllm._custom_ops as ops +from vllm.utils import seed_everything GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") @@ -74,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - torch.cuda.manual_seed_all(0) + seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") @@ -110,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype, @torch.inference_mode() def test_mmq(num_tokens: int, 
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): - torch.cuda.manual_seed_all(0) + seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index e93cb535d715a..41e103e1d09f9 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,6 +4,7 @@ from tests.kernels.quant_utils import ref_dynamic_per_token_quant from tests.kernels.utils import opcheck from vllm._custom_ops import scaled_int8_quant +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -44,8 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): @torch.inference_mode() def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -68,8 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, @@ -113,8 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -140,8 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float, azp: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 6eaf67ec75f41..382079d472ee9 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -3,6 +3,7 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing @@ -30,9 +31,7 @@ def test_rms_norm( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) layer = RMSNorm(hidden_size).to(dtype=dtype) layer.weight.data.normal_(mean=1.0, std=0.1) diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index ce65aaef60ac6..0a90882223077 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -48,7 +48,7 @@ # `is_quant_method_supported` conflates kernels with quantization methods # an assumption which is breaking down as quantizations methods can have # have kernels and some kernels 
support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) def rand_data(shape, dtype=torch.float16): diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index d3cb0a8656a02..f582445692344 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) +from vllm.utils import seed_everything def selective_state_update_ref(state, @@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 2 dim = 4 dstate = 8 @@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): if torch.version.hip: atol *= 2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 1 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 8072cf09e5b65..b1f0516dfa0b3 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -18,6 +18,7 @@ marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.scalar_type import scalar_types +from vllm.utils import seed_everything def torch_moe(a, w1, w2, score, topk): @@ -151,7 +152,7 @@ def test_fused_marlin_moe( act_order: bool, num_bits: int, ): - torch.manual_seed(7) + seed_everything(7) if topk > e: return diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 65242e275650c..ba9d2d4389b21 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,6 +5,7 @@ import torch from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -46,9 +47,8 @@ def test_rotary_embedding( ) -> None: if rotary_dim is None: rotary_dim = head_size - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -100,9 +100,7 @@ def test_batched_rotary_embedding( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 60f9a4dc9f90f..3181d92562399 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -9,7 +9,7 @@ from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything NUM_HEADS = [64] 
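A note on the recurring change in the kernel-test hunks above: each one replaces per-test torch.manual_seed / torch.cuda.manual_seed boilerplate with a single seed_everything helper. That helper is only added to vllm/utils.py further down in this patch, so as a rough sketch of the behavior these tests rely on (the torch.manual_seed call and the plain torch.cuda.is_available() gate are simplifications made for this sketch; the real helper goes through the platform layer and also seeds XPU devices):

    # Editorial sketch, not part of the patch; see the vllm/utils.py hunk below
    # for the actual implementation.
    import random

    import numpy as np
    import torch


    def seed_everything(seed: int) -> None:
        """Seed every RNG source the kernel tests draw from."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)  # assumed; not visible in the excerpted hunk
        if torch.cuda.is_available():  # real helper: current_platform.is_cuda_alike()
            torch.cuda.manual_seed_all(seed)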
NUM_QUERIES_PER_KV = [1, 8, 64] @@ -39,10 +39,7 @@ def test_contexted_kv_attention( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process @@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index effcffc5c174e..e3233c6b60696 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) from vllm.model_executor.utils import set_random_seed +from vllm.utils import seed_everything from .utils import DummyLoRAManager @@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seq_len) -> None: dtype = torch.float16 seed = 0 - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index c36fb3afb0cc3..314d6215cbd9c 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -4,7 +4,6 @@ whether the corresponding Triton kernel can run normally when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64]. """ -import random from unittest.mock import patch import pytest @@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -145,11 +145,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -238,11 +235,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -329,11 +323,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index d026e34878e04..28a395af19e6d 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -3,7 +3,6 @@ under different conditions, including various batches, numbers of LoRA , and maximum ranks. 
""" -import random from unittest.mock import patch import pytest @@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -60,11 +60,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -153,11 +150,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -244,11 +238,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/decoder_only/language/test_granite.py index 82c753855e714..e5c5ce4a8f745 100644 --- a/tests/models/decoder_only/language/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -2,23 +2,18 @@ Run `pytest tests/models/test_granite.py`. """ -import importlib.metadata - import pytest +import transformers from ...utils import check_logprobs_close -TRANSFORMERS_VERSION = tuple( - map(int, - importlib.metadata.version("transformers").split("."))) - MODELS = [ "ibm/PowerLM-3b", ] # GraniteForCausalLM will be in transformers >= 4.45 -@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), +@pytest.mark.skipif(transformers.__version__ < "4.45", reason="granite model test requires transformers >= 4.45") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 58864e83173f9..a0c1d7e24c503 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability >= 89 and not force_marlin: + if current_platform.has_device_capability(89) and not force_marlin: # For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == torch.float8_e4m3fn else: diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 5fad06878f4a3..061a077592e80 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -8,6 +8,8 @@ def is_quant_method_supported(quant_method: str) -> bool: return False capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - return (capability >= - QUANTIZATION_METHODS[quant_method].get_min_capability()) + assert capability is not None + + min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability() + + return capability.to_int() >= min_capability diff --git 
a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index f1404b8b6bfe7..6bd276ade1d41 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -13,6 +13,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -299,7 +300,7 @@ def __init__( else: # if not using triton, navi3x/navi21/navi10 do not use flash-attn # either - if torch.cuda.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): self.use_naive_attn = True else: try: diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py index e870a8e614d12..1ead541f391b5 100644 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -8,8 +8,7 @@ from .utils import (dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask) -IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() - and current_platform.get_device_capability()[0] >= 8) +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) if IS_COMPUTE_8_OR_ABOVE: from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd @@ -36,7 +35,7 @@ def __init__( use_spda = is_hip() or is_cpu() or not \ IS_COMPUTE_8_OR_ABOVE device = device or (torch.cuda.current_device() - if torch.cuda.is_available() else "cpu") + if current_platform.is_cuda_alike() else "cpu") device = torch.device(device) # NOTE: vllm CPU backend support BF16 instead of FP16. dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 558b2f3eeac7e..a2a649c8ebcfd 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -709,8 +709,7 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None): - cap = current_platform.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 + BLOCK = 128 if current_platform.has_device_capability(80) else 64 NUM_WARPS = 8 # need to reduce num. blocks when using fp32 diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 855586d4e5961..fbda263ba8e08 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -203,7 +203,7 @@ def which_attn_to_use( selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: - if current_platform.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): # not Instinct series GPUs. logger.info("flash_attn is not supported on NAVI GPUs.") else: @@ -212,7 +212,7 @@ def which_attn_to_use( # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: - if current_platform.get_device_capability()[0] < 8: + if not current_platform.has_device_capability(80): # Volta and Turing NVIDIA GPUs. 
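These attention-backend hunks follow the same pattern as the test changes earlier in the commit: instead of pulling the (major, minor) capability tuple and repeating the major * 10 + minor arithmetic at every call site, callers now ask the platform a yes/no question. A minimal before/after sketch, using the has_device_capability helper that the vllm/platforms/interface.py hunk later in this commit introduces:

    from vllm.platforms import current_platform

    # Before: tuple arithmetic at every call site, and an error if
    # get_device_capability() returns None (e.g. no GPU visible).
    capability = current_platform.get_device_capability()
    supports_flash_attn = capability[0] * 10 + capability[1] >= 80

    # After: accepts an int (80) or a tuple ((8, 0)) and simply returns
    # False when the capability cannot be queried.
    supports_flash_attn = current_platform.has_device_capability(80)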
logger.info( "Cannot use FlashAttention-2 backend for Volta and Turing " diff --git a/vllm/config.py b/vllm/config.py index 6c24d15640e99..9d42b75c1c462 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -17,7 +17,7 @@ get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - is_cpu, is_hip, is_neuron, is_openvino, is_xpu, + is_hip, is_neuron, is_openvino, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -1035,20 +1035,20 @@ class DeviceConfig: def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if is_neuron(): + if current_platform.is_cuda_alike(): + self.device_type = "cuda" + elif is_neuron(): self.device_type = "neuron" elif is_openvino(): self.device_type = "openvino" elif current_platform.is_tpu(): self.device_type = "tpu" - elif is_cpu(): + elif current_platform.is_cpu(): self.device_type = "cpu" elif is_xpu(): self.device_type = "xpu" else: - # We don't call torch.cuda.is_available() here to - # avoid initializing CUDA before workers are forked - self.device_type = "cuda" + raise RuntimeError("Failed to infer device type") else: # Device type is assigned explicitly self.device_type = device diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1c864bcd5d708..df07842edfa56 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -35,6 +35,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform @dataclass @@ -191,7 +192,7 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") diff --git a/vllm/envs.py b/vllm/envs.py index 2003ede95d2d8..6edb06ecd2e20 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -60,6 +60,7 @@ VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b5b2570966600..ab8207f128348 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -116,10 +116,10 @@ def get_config_filenames(cls) -> List[str]: def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() supported = capability >= min_capability if error and not supported: raise RuntimeError( diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3ccf1af9eb898..eb59344f36d2e 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -32,9 +32,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float): # For GPUs that lack FP8 hardware support, we can leverage 
the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + self.use_marlin = not current_platform.has_device_capability(89) @classmethod def get_name(cls) -> str: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 32affe06b89b7..b5feb55db0e74 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -120,9 +120,8 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm if is_hip(): self.use_marlin = False diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 699d5f1844146..fea94cf7322ad 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool, device_capability: Optional[int] = None ): if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) if device_capability < 80: return [] @@ -52,8 +53,9 @@ def _check_marlin_supported( device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) supported_types = query_marlin_supported_quant_types( has_zp, device_capability) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5f9d8658a342f..8b3dfaae971c3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -10,8 +10,7 @@ def is_fp8_marlin_supported(): - capability = current_platform.get_device_capability() - return capability[0] >= 8 + return current_platform.has_device_capability(80) def apply_fp8_marlin_linear( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 887ee6605560c..d86fea63d8a1b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -17,8 +17,9 @@ def cutlass_fp8_supported() -> bool: # cutlass is not supported on Rocm if is_hip(): return False - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() return 
ops.cutlass_scaled_mm_supports_fp8(capability) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index fd9533ab156a5..f0d2a9e7f06be 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -97,10 +97,10 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() if capability < quant_config.get_min_capability(): raise ValueError( f"The quantization method {model_config.quantization} " diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 179399a12a3d5..a9a0329e99f08 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -207,7 +207,7 @@ def __init__( selected_backend = backend_name_to_enum(backend_by_env_var) if selected_backend is None: # For Volta and Turing GPUs, use xformers instead. - device_available = current_platform.get_device_capability()[0] >= 8 + device_available = current_platform.has_device_capability(80) if device_available: from transformers.utils import is_flash_attn_2_available diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 336bc1cd005cf..d7eec818cbba4 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,17 +1,13 @@ """Utils for model executor.""" -import random from typing import Any, Dict, Optional -import numpy as np import torch +from vllm.utils import seed_everything + def set_random_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + seed_everything(seed) def set_weight_attrs( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4736e898b6a52..9b348f3e17a5f 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -6,10 +6,10 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: return "cpu" - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8d18527e7c973..a9978d5d84d7c 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -11,7 +11,7 @@ from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int: class CudaPlatform(Platform): _enum = PlatformEnum.CUDA - @staticmethod - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: physical_device_id = device_id_to_physical_device_id(device_id) - return get_physical_device_capability(physical_device_id) + major, minor = get_physical_device_capability(physical_device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod - def get_device_name(device_id: int = 0) -> str: + 
@classmethod + def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) return get_physical_device_name(physical_device_id) - @staticmethod + @classmethod @with_nvml_context - def is_full_nvlink(physical_device_ids: List[int]) -> bool: + def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 676f4c9fccf5a..360590d7d5eb6 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,5 @@ import enum -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple, Union import torch @@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum): UNSPECIFIED = enum.auto() +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + class Platform: _enum: PlatformEnum @@ -27,16 +44,47 @@ def is_tpu(self) -> bool: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU - @staticmethod - def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]: + def is_cuda_alike(self) -> bool: + """Stateless version of :func:`torch.cuda.is_available`.""" + return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + + @classmethod + def get_device_capability( + cls, + device_id: int = 0, + ) -> Optional[DeviceCapability]: + """Stateless version of :func:`torch.cuda.get_device_capability`.""" return None - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def has_device_capability( + cls, + capability: Union[Tuple[int, int], int], + device_id: int = 0, + ) -> bool: + """ + Test whether this platform is compatible with a device capability. + + The ``capability`` argument can either be: + + - A tuple ``(major, minor)``. + - An integer ````. (See :meth:`DeviceCapability.to_int`) + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + + if isinstance(capability, tuple): + return current_capability >= capability + + return current_capability.to_int() >= capability + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: raise NotImplementedError - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. 
This wrapper is recommended because some hardware backends such as TPU diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 28525e8ff8811..b6a19eca01745 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,12 +1,11 @@ import os from functools import lru_cache -from typing import Tuple import torch from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -20,12 +19,13 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: - return torch.cuda.get_device_capability(device_id) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_name(device_id: int = 0) -> str: + def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 393fc230da0b9..b30bccb103af3 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,6 +6,10 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU - @staticmethod - def inference_mode(): + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py index 989cc5a0f87c8..4cde2a0254b90 100644 --- a/vllm/prompt_adapter/utils.py +++ b/vllm/prompt_adapter/utils.py @@ -8,13 +8,15 @@ from huggingface_hub.utils import EntryNotFoundError from safetensors.torch import load_file as safe_load_file +from vllm.platforms import current_platform + WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" # Get current device name based on available devices def infer_device() -> str: - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): return "cuda" return "cpu" diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 515e0a4d8abe7..7fadfd5dfffb4 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.connections import global_http_connection +from vllm.platforms import current_platform from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -151,7 +152,7 @@ def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, extra_kvs: Dict[str, Any]) -> None: # Platform information - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): device_property = torch.cuda.get_device_properties(0) self.gpu_count = torch.cuda.device_count() self.gpu_type = device_property.name diff --git a/vllm/utils.py b/vllm/utils.py index 29b8a8c2907eb..060b387ec7834 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import enum import gc import os +import random import socket import subprocess import sys @@ -32,6 +33,7 @@ import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -373,6 +375,22 @@ def get_cpu_memory() -> int: return psutil.virtual_memory().total +def seed_everything(seed: int) -> None: + """ + Set the seed of each random 
module. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + + if current_platform.is_cuda_alike(): + torch.cuda.manual_seed_all(seed) + + if is_xpu(): + torch.xpu.manual_seed_all(seed) + + def random_uuid() -> str: return str(uuid.uuid4().hex) @@ -634,9 +652,7 @@ def create_kv_caches_with_random_flash( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) @@ -678,9 +694,7 @@ def create_kv_caches_with_random( f"Does not support key cache of type fp8 with head_size {head_size}" ) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -750,7 +764,7 @@ def __init__(self, device: Optional[torch.types.Device] = None): def current_memory_usage(self) -> float: # Return the memory usage in bytes. - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): torch.cuda.reset_peak_memory_stats(self.device) mem = torch.cuda.max_memory_allocated(self.device) elif is_xpu(): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 52092dc2dc291..3851843afc960 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -454,14 +454,20 @@ def init_worker_distributed_environment( def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: - compute_capability = current_platform.get_device_capability() - if compute_capability[0] < 8: + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}. " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. 
" "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.") From 9d104b5beb7bbb51c64b680e007f39169489ea86 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 18 Sep 2024 07:00:56 -0400 Subject: [PATCH 028/116] [CI/Build] Update Ruff version (#8469) Signed-off-by: Aaron Pham Co-authored-by: Cyrus Leung --- .github/workflows/ruff.yml | 4 ++-- benchmarks/kernels/graph_machete_bench.py | 4 +--- format.sh | 4 ++-- pyproject.toml | 2 ++ requirements-lint.txt | 2 +- tests/conftest.py | 5 +---- tests/lora/conftest.py | 5 +---- tests/multimodal/test_base.py | 2 +- tests/test_cache_block_hashing.py | 5 +---- tests/test_logger.py | 4 ++-- tests/worker/test_encoder_decoder_model_runner.py | 4 +--- tests/worker/test_model_runner.py | 4 +--- vllm/adapter_commons/utils.py | 2 +- vllm/attention/backends/utils.py | 6 ++---- vllm/core/block/prefix_caching_block.py | 4 +--- vllm/core/block_manager_v2.py | 4 +--- vllm/engine/async_llm_engine.py | 6 +++--- vllm/engine/llm_engine.py | 6 +++--- .../guided_decoding/outlines_logits_processors.py | 4 ++-- .../layers/quantization/awq_marlin.py | 6 +++--- .../compressed_tensors/compressed_tensors.py | 14 +++++++------- .../layers/quantization/gptq_marlin.py | 8 ++++---- vllm/model_executor/model_loader/tensorizer.py | 4 +--- vllm/model_executor/models/minicpmv.py | 2 +- vllm/spec_decode/draft_model_runner.py | 5 +---- vllm/spec_decode/metrics.py | 7 ++----- vllm/triton_utils/libentry.py | 4 ++-- 27 files changed, 50 insertions(+), 77 deletions(-) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 1a794af572fef..90735d6e2bbf9 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -25,10 +25,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2 + pip install -r requirements-lint.txt - name: Analysing the code with ruff run: | - ruff . + ruff check . - name: Spelling check with codespell run: | codespell --toml pyproject.toml diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index 1d076ed6d5c18..de608fd05af70 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -45,8 +45,7 @@ rows = int(math.ceil(len(results) / 2)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) axs = axs.flatten() - axs_idx = 0 - for shape, data in results.items(): + for axs_idx, (shape, data) in enumerate(results.items()): plt.sca(axs[axs_idx]) df = pd.DataFrame(data) sns.lineplot(data=df, @@ -59,6 +58,5 @@ palette="Dark2") plt.title(f"Shape: {shape}") plt.ylabel("time (median, s)") - axs_idx += 1 plt.tight_layout() plt.savefig("graph_machete_bench.pdf") diff --git a/format.sh b/format.sh index 2204b3ba59498..6563d89b192ea 100755 --- a/format.sh +++ b/format.sh @@ -159,7 +159,7 @@ echo 'vLLM codespell: Done' # Lint specified files lint() { - ruff "$@" + ruff check "$@" } # Lint files that differ from main branch. Ignores dirs that are not slated @@ -175,7 +175,7 @@ lint_changed() { if ! 
git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - ruff + ruff check fi } diff --git a/pyproject.toml b/pyproject.toml index 6b682f5d4dd4d..14f0934499c46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,8 @@ ignore = [ "E731", # Loop control variable not used within loop body "B007", + # f-string format + "UP032", ] [tool.mypy] diff --git a/requirements-lint.txt b/requirements-lint.txt index d0b2fef6deaef..07f738873e1a8 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -2,7 +2,7 @@ yapf==0.32.0 toml==0.10.2 tomli==2.0.1 -ruff==0.1.5 +ruff==0.6.5 codespell==2.3.0 isort==5.13.2 clang-format==18.1.5 diff --git a/tests/conftest.py b/tests/conftest.py index e4c7b96e82429..e9c7fc7bf9c67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -158,10 +158,7 @@ def should_do_global_cleanup_after_test(request) -> bool: to initialize torch. """ - if request.node.get_closest_marker("skip_global_cleanup"): - return False - - return True + return not request.node.get_closest_marker("skip_global_cleanup") @pytest.fixture(autouse=True) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 0bcae5b0c96dc..4834a9d35a3ee 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -65,10 +65,7 @@ def should_do_global_cleanup_after_test(request) -> bool: to initialize torch. """ - if request.node.get_closest_marker("skip_global_cleanup"): - return False - - return True + return not request.node.get_closest_marker("skip_global_cleanup") @pytest.fixture(autouse=True) diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index e9562d2048f06..68d05de904ba8 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -5,7 +5,7 @@ def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors): - assert type(expected) == type(actual) + assert type(expected) == type(actual) # noqa: E721 if isinstance(expected, torch.Tensor): assert torch.equal(expected, actual) else: diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index fe413d1228021..3576a4834ebc3 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -66,8 +66,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, hashes.append([]) prompts = [prefix + prompt for prompt in sample_prompts] - seq_id = 0 - for prompt in prompts: + for seq_id, prompt in enumerate(prompts): hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) seq = Sequence(seq_id, @@ -83,8 +82,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, for idx in range(num_blocks): hashes[-1][-1].append(seq.hash_of_block(idx)) - seq_id += 1 - # Check that hashes made with two prefixes with different first blocks are # different everywhere. 
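Several hunks in this Ruff-update commit apply the same suggested cleanup seen just above in test_cache_block_hashing.py: a hand-maintained counter next to a for loop is folded into enumerate(). A tiny before/after with a hypothetical loop body:

    def process(idx: int, item: str) -> None:  # stand-in for the real loop body
        print(idx, item)

    items = ["a", "b", "c"]

    # Before: manual index bookkeeping.
    idx = 0
    for item in items:
        process(idx, item)
        idx += 1

    # After: enumerate() yields the index and the element together.
    for idx, item in enumerate(items):
        process(idx, item)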
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])): diff --git a/tests/test_logger.py b/tests/test_logger.py index 8f3d218416870..fadf66f2b61d4 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -111,7 +111,7 @@ def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist(): configuration occurs.""" with pytest.raises(RuntimeError) as ex_info: _configure_vllm_root_logger() - assert ex_info.type == RuntimeError + assert ex_info.type == RuntimeError # noqa: E721 assert "File does not exist" in str(ex_info) @@ -152,7 +152,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( logging_config_file.name): with pytest.raises(ValueError) as ex_info: _configure_vllm_root_logger() - assert ex_info.type == ValueError + assert ex_info.type == ValueError # noqa: E721 assert "Invalid logging config. Expected Dict, got" in str(ex_info) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index a00d46ddeb007..c0654712b71b5 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -453,8 +453,7 @@ def test_prepare_decode(batch_size): # each sequence) in the decode phase expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: + for selected_token_start_idx, seq_len in enumerate(seq_lens): # Compute the index offset of the final token in each # sequence's decoded outputs; since a single token is # decoded per iteration per sequence, then the length @@ -463,7 +462,6 @@ def test_prepare_decode(batch_size): # generated tokens is 0 (i.e. the expected sampling index # for a given sequence is just `selected_token_start_idx`) expected_selected_token_indices.append(selected_token_start_idx) - selected_token_start_idx += 1 sampling_metadata = model_input.sampling_metadata actual = sampling_metadata.selected_token_indices diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index a20aa37bcc1e2..42b2337f46914 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -241,10 +241,8 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] - selected_token_start_idx = 0 - for _ in context_lens: + for selected_token_start_idx, _ in enumerate(context_lens): expected_selected_token_indices.append(selected_token_start_idx) - selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index 6c5411f7d3d5c..1e9adca50093b 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -42,7 +42,7 @@ def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: def get_adapter(adapter_id: int, registered_adapters: Dict[int, Any]) -> Optional[Any]: - return registered_adapters.get(adapter_id, None) + return registered_adapters.get(adapter_id) ## worker functions diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 089008967a244..49fbb25f4547b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -33,10 +33,8 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): """ if block_tables is None: return True - if isinstance(block_tables, dict) and all( - value is None for value in block_tables.values()): - return True - return False + return (isinstance(block_tables, 
dict) + and all(value is None for value in block_tables.values())) def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index a87e814cfb041..db67c95c32429 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -417,9 +417,7 @@ def get_prefix_cache_hit_rate(self) -> float: def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None - if block.content_hash in self._cached_blocks: - return True - return False + return block.content_hash in self._cached_blocks def promote_to_immutable_block(self, block: Block) -> BlockId: """Once a mutable block is full, it can be promoted to an immutable diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b06385b062e83..54818c7e3e9a6 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -399,9 +399,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: """ alloc_status = self._can_swap(seq_group, Device.CPU, SequenceStatus.RUNNING) - if alloc_status == AllocStatus.OK: - return True - return False + return alloc_status == AllocStatus.OK def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: """Returns the block id mapping (from GPU to CPU) generated by diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 410e6ffaa2d50..82cdd41ad497e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -826,7 +826,7 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request to use + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Yields: @@ -1042,7 +1042,7 @@ def remove_logger(self, logger_name: str) -> None: async def start_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 self.engine.model_executor.start_profile() else: self.engine.model_executor._run_workers("start_profile") @@ -1050,7 +1050,7 @@ async def start_profile(self) -> None: async def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes - if type(self.engine.model_executor) == GPUExecutorAsync: + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 self.engine.model_executor.stop_profile() else: self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8b5009b2c6668..bdf1af014342a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -144,7 +144,7 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving + prompt_adapter_config (Optional): The configuration related to serving prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. 
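The # noqa: E721 markers added in these hunks keep type(x) == SomeClass on purpose: Ruff's E721 normally asks for isinstance(), but, as the in-code comments note, isinstance() would also match inherited classes such as MultiprocessingGPUExecutor, which the profiling code must not treat as a plain GPUExecutor. A small illustration with stand-in classes:

    class GPUExecutor:
        pass

    class MultiprocessingGPUExecutor(GPUExecutor):
        pass

    executor = MultiprocessingGPUExecutor()

    isinstance(executor, GPUExecutor)   # True  -- matches subclasses as well
    type(executor) == GPUExecutor       # False -- exact class only, hence the noqa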
@@ -1605,7 +1605,7 @@ def check_health(self) -> None: def start_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: + if type(self.model_executor) == GPUExecutor: # noqa: E721 self.model_executor.start_profile() else: self.model_executor._run_workers("start_profile") @@ -1613,7 +1613,7 @@ def start_profile(self) -> None: def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) - if type(self.model_executor) == GPUExecutor: + if type(self.model_executor) == GPUExecutor: # noqa: E721 self.model_executor.stop_profile() else: self.model_executor._run_workers("stop_profile") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 554dcc0ed43ed..c28bd71c9f682 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -67,9 +67,9 @@ def __call__(self, input_ids: List[int], instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) - if type(instruction) == Generate: + if type(instruction) == Generate: # noqa: E721 allowed_tokens = instruction.tokens - elif type(instruction) == Write: + elif type(instruction) == Write: # noqa: E721 # TODO: support fast forward tokens allowed_tokens = [instruction.tokens[0]] else: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff49..eed01953fb4af 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -110,9 +110,9 @@ def get_scaled_act_names(self) -> List[str]: def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() - num_bits = quant_config.get("bits", None) - group_size = quant_config.get("group_size", None) - has_zp = quant_config.get("zero_point", None) + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + has_zp = quant_config.get("zero_point") if quant_method != "awq": return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ab8207f128348..e536fae45c845 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast import torch from pydantic import BaseModel @@ -79,8 +79,8 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": target_scheme_map: Dict[str, Any] = dict() - ignore: List[str] = config.get("ignore", None) - quant_format: str = config.get("format", None) + ignore = cast(List[str], config.get("ignore")) + quant_format = cast(str, config.get("format")) # The quant_config has multiple config_groups, each containing # an input_activations key with details about how the activations are @@ -200,7 +200,7 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, is_per_tensor_or_channel_weight = (weight_quant.strategy in [ QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL ]) - if not (is_symmetric_weight and is_static_weight + if not (is_symmetric_weight and is_static_weight # noqa: SIM103 and is_per_tensor_or_channel_weight): return False @@ -333,7 +333,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """ - Use the CompressedTensorsScheme associated with each layer to create + Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param details """ @@ -352,8 +352,8 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None): """ - Use the output of create_weights and the CompressedTensorsScheme - associated with the layer to apply the forward pass with the + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the layer input. See LinearMethodBase for param details """ diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index cc699f5b4554f..5a1b2d701ab0d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -132,10 +132,10 @@ def get_scaled_act_names(self) -> List[str]: def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() - num_bits = quant_config.get("bits", None) - group_size = quant_config.get("group_size", None) - sym = quant_config.get("sym", None) - desc_act = quant_config.get("desc_act", None) + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") if quant_method != "gptq": return False diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 3aac5cd2b43a5..36f33d6d139ee 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: "inferred as vLLM models, so setting vllm_tensorized=True is " "only necessary for models serialized prior to this change.") return True - if (".vllm_tensorized_marker" in deserializer): - return True - return False + return ".vllm_tensorized_marker" in deserializer def serialize_vllm_model( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f8be9490ee55d..f0fc950defed7 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -884,7 +884,7 @@ def __new__( version = str(config.version).split(".") version = tuple([int(x) for x in version]) # Dispatch class based on version - instance_class = _SUPPORT_VERSION.get(version, None) + instance_class = _SUPPORT_VERSION.get(version) if instance_class is None: raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 1e403637d2388..cf64af72a14a5 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -183,10 +183,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add soft-tuning prompt adapter support - if self.prompt_adapter_config: - return False - - return True + return not self.prompt_adapter_config @torch.inference_mode() def execute_model( diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index ad4e2dc879d7b..89ccaba70e93c 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -104,13 +104,10 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool: if self._rank != 0: return False - if (now - self._last_metrics_collect_time < - self._rejsample_metrics_collect_interval_s): - return False - return True + return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501 def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: - """Copy rejection/typical-acceptance sampling metrics + """Copy rejection/typical-acceptance sampling metrics (number of accepted tokens, etc) to CPU asynchronously. Returns a CUDA event recording when the copy is complete. 
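The quant-config hunks above make a purely cosmetic change that the new Ruff version flags: dict.get() already defaults to None, so spelling the default out is redundant. For example, with a hypothetical config dict:

    quant_config = {"quant_method": "awq", "bits": 4}

    # The two calls are equivalent; the second simply drops the redundant default.
    group_size = quant_config.get("group_size", None)
    group_size = quant_config.get("group_size")
    assert group_size is None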
diff --git a/vllm/triton_utils/libentry.py b/vllm/triton_utils/libentry.py index ae00af44a048a..4335c7adfc13b 100644 --- a/vllm/triton_utils/libentry.py +++ b/vllm/triton_utils/libentry.py @@ -35,8 +35,8 @@ def key(self, spec_args, dns_args, const_args): dns_key = [ arg.dtype if hasattr( arg, "data_ptr") else type(arg) if not isinstance(arg, int) - else "i32" if -(2**31) <= arg and arg <= 2**31 - - 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" + else "i32" if arg >= -(2**31) and arg <= 2**31 - + 1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64" for arg in dns_args ] # const args passed by position From 7c7714d856eee6fa94aade729b67f00584f72a4c Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Wed, 18 Sep 2024 09:56:58 -0400 Subject: [PATCH 029/116] [Core][Bugfix][Perf] Introduce `MQLLMEngine` to avoid `asyncio` OH (#8157) Co-authored-by: Nick Hill Co-authored-by: rshaw@neuralmagic.com Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Co-authored-by: Simon Mo --- .buildkite/test-pipeline.yaml | 4 +- docs/source/dev/profiling/profiling_index.rst | 4 +- tests/async_engine/test_openapi_server.py | 106 ---- .../entrypoints/openai/rpc/test_zmq_client.py | 120 ----- tests/entrypoints/openai/test_accuracy.py | 56 +-- .../openai}/test_chat_template.py | 2 +- .../entrypoints/openai/test_mp_api_server.py | 40 -- tests/entrypoints/openai/test_serving_chat.py | 5 +- .../entrypoints/openai/test_serving_engine.py | 4 +- tests/entrypoints/openai/test_shutdown.py | 2 +- .../openai/rpc => mq_llm_engine}/__init__.py | 0 tests/mq_llm_engine/test_abort.py | 67 +++ tests/mq_llm_engine/test_error_handling.py | 244 ++++++++++ tests/mq_llm_engine/test_load.py | 57 +++ tests/mq_llm_engine/utils.py | 78 +++ tests/tpu/test_custom_dispatcher.py | 7 + tests/utils.py | 2 +- vllm/engine/async_llm_engine.py | 9 +- vllm/engine/llm_engine.py | 1 + vllm/engine/multiprocessing/__init__.py | 73 +++ vllm/engine/multiprocessing/client.py | 452 ++++++++++++++++++ vllm/engine/multiprocessing/engine.py | 321 +++++++++++++ vllm/engine/protocol.py | 8 +- vllm/entrypoints/launcher.py | 30 +- vllm/entrypoints/openai/api_server.py | 121 +++-- vllm/entrypoints/openai/rpc/__init__.py | 50 -- vllm/entrypoints/openai/rpc/client.py | 451 ----------------- vllm/entrypoints/openai/rpc/server.py | 243 ---------- vllm/entrypoints/openai/serving_chat.py | 21 +- vllm/entrypoints/openai/serving_completion.py | 21 +- vllm/entrypoints/openai/serving_embedding.py | 11 +- vllm/entrypoints/openai/serving_engine.py | 8 +- .../openai/serving_tokenization.py | 10 +- vllm/envs.py | 6 +- vllm/executor/cpu_executor.py | 1 + vllm/executor/multiproc_worker_utils.py | 4 + 36 files changed, 1467 insertions(+), 1172 deletions(-) delete mode 100644 tests/async_engine/test_openapi_server.py delete mode 100644 tests/entrypoints/openai/rpc/test_zmq_client.py rename tests/{async_engine => entrypoints/openai}/test_chat_template.py (99%) delete mode 100644 tests/entrypoints/openai/test_mp_api_server.py rename tests/{entrypoints/openai/rpc => mq_llm_engine}/__init__.py (100%) create mode 100644 tests/mq_llm_engine/test_abort.py create mode 100644 tests/mq_llm_engine/test_error_handling.py create mode 100644 tests/mq_llm_engine/test_load.py create mode 100644 tests/mq_llm_engine/utils.py create mode 100644 vllm/engine/multiprocessing/__init__.py create mode 100644 vllm/engine/multiprocessing/client.py create mode 100644 
vllm/engine/multiprocessing/engine.py delete mode 100644 vllm/entrypoints/openai/rpc/__init__.py delete mode 100644 vllm/entrypoints/openai/rpc/client.py delete mode 100644 vllm/entrypoints/openai/rpc/server.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 63ce9bff7d4c1..37207b677a1ee 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -43,13 +43,15 @@ steps: fast_check: true source_file_dependencies: - vllm/ + - tests/mq_llm_engine - tests/async_engine - tests/test_inputs - tests/multimodal - tests/test_utils - tests/worker commands: - - pytest -v -s async_engine # Async Engine + - pytest -v -s mq_llm_engine # MQLLMEngine + - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index e22d547293445..9e8b2f1817567 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/. .. tip:: To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. - Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes. - ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000`` + Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes. + ``export VLLM_RPC_TIMEOUT=1800000`` Example commands and usage: =========================== diff --git a/tests/async_engine/test_openapi_server.py b/tests/async_engine/test_openapi_server.py deleted file mode 100644 index 9e5c7c04287eb..0000000000000 --- a/tests/async_engine/test_openapi_server.py +++ /dev/null @@ -1,106 +0,0 @@ -import openai # use the official client for correctness check -import pytest -import pytest_asyncio - -from ..utils import VLLM_PATH, RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "facebook/opt-125m" -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() - - -@pytest.fixture(scope="module") -def server(): - args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "2048", - "--enforce-eager", - "--chat-template", - str(chatml_jinja_path), - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): - models = await client.models.list() - models = models.data - served_model = models[0] - assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in models) - - -@pytest.mark.asyncio -async def test_single_completion(client: openai.AsyncOpenAI): - completion = await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert len(completion.choices) == 1 - assert len(completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" - assert 
completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 5 - - -@pytest.mark.asyncio -async def test_single_chat_session(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] - - # test single completion - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10, - logprobs=True, - top_logprobs=5) - assert chat_completion.id is not None - assert len(chat_completion.choices) == 1 - - choice = chat_completion.choices[0] - assert choice.finish_reason == "length" - assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=55, total_tokens=65) - - message = choice.message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" - messages.append({"role": "assistant", "content": message.content}) - - # test multi-turn dialogue - messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=10, - ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py deleted file mode 100644 index cafd125c5a598..0000000000000 --- a/tests/entrypoints/openai/rpc/test_zmq_client.py +++ /dev/null @@ -1,120 +0,0 @@ -import asyncio -import tempfile -import unittest -import unittest.mock -import uuid - -import pytest -import pytest_asyncio - -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, - RPCClientClosedError) -from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest_asyncio.fixture(scope="function") -async def dummy_server(tmp_socket, monkeypatch): - dummy_engine = unittest.mock.AsyncMock() - - def dummy_engine_builder(*args, **kwargs): - return dummy_engine - - with monkeypatch.context() as m: - m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) - server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) - - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - try: - yield server - finally: - server_task.cancel() - server.cleanup() - - -@pytest_asyncio.fixture(scope="function") -async def client(tmp_socket): - client = AsyncEngineRPCClient(rpc_path=tmp_socket) - # Sanity check: the server is connected - await client._wait_for_server_rpc() - - try: - yield client - finally: - client.close() - - -@pytest.mark.asyncio -async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server _not_ reply with a model config - m.setattr(dummy_server, "get_config", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # And ensure the task completes anyway - # (client.setup() invokes server.get_config()) - client_task = 
asyncio.get_running_loop().create_task(client.setup()) - with pytest.raises(TimeoutError, match="Server didn't reply within"): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Hang all abort requests - m.setattr(dummy_server, "abort", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # The client should suppress timeouts on `abort`s - # and return normally, assuming the server will eventually - # abort the request. - client_task = asyncio.get_running_loop().create_task( - client.abort("test request id")) - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_data_methods_reraise_exceptions( - monkeypatch, dummy_server, client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server raise some random exception - exception = RuntimeError("Client test exception") - - def raiser(): - raise exception - - m.setattr(dummy_server.engine, "get_model_config", raiser) - m.setattr(client, "_data_timeout", 10) - - client_task = asyncio.get_running_loop().create_task(client.setup()) - # And ensure the task completes, raising the exception - with pytest.raises(RuntimeError, match=str(exception)): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_errors_after_closing(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - - client.close() - - # Healthchecks and generate requests will fail with explicit errors - with pytest.raises(RPCClientClosedError): - await client.check_health() - with pytest.raises(RPCClientClosedError): - async for _ in client.generate(None, None, None): - pass - - # But no-ops like aborting will pass - await client.abort("test-request-id") - await client.do_log_stats() diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b442a903c33ae..2ad8460023c25 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -18,38 +18,32 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 +DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] +MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] -@pytest.fixture(scope="module") -def server(): - args = [ - "--max-model-len", "4096", "--enable-chunked-prefill", - "--disable-log-requests", "--enforce-eager" - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def server_data(server): - return { - "url": f"{server.url_for('v1')}/completions", - } +@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) +def test_lm_eval_accuracy(more_args): + args = list(DEFAULT_ARGS) + args.extend(more_args) + print(f"Running with: {args}") -def test_lm_eval_accuracy(server_data): - model_args = (f"model={MODEL_NAME}," - f"base_url={server_data['url']}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") - - results = lm_eval.simple_evaluate( - model="local-completions", - model_args=model_args, - tasks=TASK, - ) - - measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + url = 
f"{remote_server.url_for('v1')}/completions" + + model_args = ( + f"model={MODEL_NAME}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/async_engine/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py similarity index 99% rename from tests/async_engine/test_chat_template.py rename to tests/entrypoints/openai/test_chat_template.py index 61a6d77cd8756..b98ab2e30d78d 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -5,7 +5,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer -from ..utils import VLLM_PATH +from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" assert chatml_jinja_path.exists() diff --git a/tests/entrypoints/openai/test_mp_api_server.py b/tests/entrypoints/openai/test_mp_api_server.py deleted file mode 100644 index fbfe0db19dd03..0000000000000 --- a/tests/entrypoints/openai/test_mp_api_server.py +++ /dev/null @@ -1,40 +0,0 @@ -import time - -import pytest - -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import FlexibleArgumentParser - - -@pytest.mark.asyncio -async def test_mp_crash_detection(): - - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - # use an invalid tensor_parallel_size to trigger the - # error in the server - args.tensor_parallel_size = 65536 - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index c3a6c65be1d90..de2a932199a01 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from vllm.config import MultiModalConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.transformers_utils.tokenizer import get_tokenizer @@ -52,8 +52,9 @@ def test_async_serving_chat_init(): def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False serving_chat 
= OpenAIServingChat(mock_engine, MockModelConfig(), diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 325bc03434287..6d9e620b4af7d 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -4,7 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) @@ -18,7 +18,7 @@ async def _async_serving_engine_init(): - mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_engine_client = MagicMock(spec=EngineClient) mock_model_config = MagicMock(spec=ModelConfig) # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 73ecb74007272..25ab91ef69333 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path): prompt="Hello, my name is") # Now the server should shut down - return_code = remote_server.proc.wait(timeout=3) + return_code = remote_server.proc.wait(timeout=8) assert return_code is not None diff --git a/tests/entrypoints/openai/rpc/__init__.py b/tests/mq_llm_engine/__init__.py similarity index 100% rename from tests/entrypoints/openai/rpc/__init__.py rename to tests/mq_llm_engine/__init__.py diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py new file mode 100644 index 0000000000000..782b508a57149 --- /dev/null +++ b/tests/mq_llm_engine/test_abort.py @@ -0,0 +1,67 @@ +"""Test that aborting is handled properly.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" +EXPECTED_TOKENS = 250 + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_id_to_be_aborted = "request-aborted" + request_ids_a = [f"request-a-{idx}" for idx in range(10)] + request_ids_b = [f"request-b-{idx}" for idx in range(10)] + + # Requests started before one to be aborted. + tasks = [] + for request_id in request_ids_a: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Aborted. + task_aborted = asyncio.create_task( + generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) + + # Requests started after one to be aborted. + for request_id in request_ids_b: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Actually abort. + await asyncio.sleep(0.5) + await client.abort(request_id_to_be_aborted) + + # Confirm that we got all the EXPECTED tokens from the requests. + for task in tasks: + count, request_id = await task + assert count == EXPECTED_TOKENS, ( + f"{request_id} generated only {count} tokens") + + # Cancel task (this will hang indefinitely if not). 
+ task_aborted.cancel() + + # Shutdown. + client.close() diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py new file mode 100644 index 0000000000000..49cfc5aa04c36 --- /dev/null +++ b/tests/mq_llm_engine/test_error_handling.py @@ -0,0 +1,244 @@ +"""Test that various errors are handled properly.""" + +import asyncio +import tempfile +import time +import uuid +from unittest.mock import Mock + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.entrypoints.openai.api_server import build_async_engine_client +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.lora.request import LoRARequest +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.execute_model = Mock( + side_effect=RAISED_ERROR(RAISED_VALUE)) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_evil_forward(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_forward) as engine: + + client = await engine.make_client() + + # Server should be healthy after initial probe. + await asyncio.sleep(2.0) + await client.check_health() + + # Throws an error in first forward pass. + with pytest.raises(RAISED_ERROR): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + # Engine is errored, should get ENGINE_DEAD_ERROR. + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + await asyncio.sleep(1.0) + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Shutdown. + client.close() + + +def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, + ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_health_check(tmp_socket): + with RemoteMQLLMEngine( + engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_model_executor_health) as engine: + + client = await engine.make_client() + assert client.is_running + + # Health probe should throw RAISED_ERROR. + await asyncio.sleep(15.) 
+ + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Generate call should throw ENGINE_DEAD_ERROR + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + + client.close() + + +def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during abort call. + engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: + + client = await engine.make_client() + assert client.is_running + + # Firsh check health should work. + await client.check_health() + + # Trigger an abort on the client side. + async def bad_abort_after_2s(): + await asyncio.sleep(2.0) + await client.abort(request_id="foo") + + # Trigger an abort in 2s from now. + abort_task = asyncio.create_task(bad_abort_after_2s()) + + # Exception in abort() will happen during this generation. + # This will kill the engine and should return ENGINE_DEAD_ERROR + # with reference to the original KeyError("foo") + with pytest.raises(MQEngineDeadError) as execinfo: + async for _ in client.generate( + inputs="Hello my name is", + sampling_params=SamplingParams(max_tokens=2000), + request_id=uuid.uuid4()): + pass + assert "KeyError" in repr(execinfo.value) + assert client.errored + + await abort_task + + # This should raise the original error. + with pytest.raises(RAISED_ERROR): + await client.check_health() + + client.close() + + +@pytest.mark.asyncio +async def test_bad_request(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + # Invalid request should fail, but not crash the server. + with pytest.raises(ValueError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-1", + lora_request=LoRARequest( + "invalid-lora", 1, + "invalid-path")): + pass + + # This request should be okay. + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-2"): + pass + + # Shutdown. + client.close() + + +@pytest.mark.asyncio +async def test_mp_crash_detection(monkeypatch): + + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + # When LLMEngine is loaded, it will crash. 
+ def mock_init(): + raise ValueError + + monkeypatch.setattr(LLMEngine, "__init__", mock_init) + + start = time.perf_counter() + async with build_async_engine_client(args): + pass + end = time.perf_counter() + + assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " + "if there is an error in the startup.") + + +@pytest.mark.asyncio +async def test_mp_cuda_init(): + # it should not crash, when cuda is initialized + # in the API server process + import torch + torch.cuda.init() + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + async with build_async_engine_client(args): + pass diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py new file mode 100644 index 0000000000000..630c112d0f0c9 --- /dev/null +++ b/tests/mq_llm_engine/test_load.py @@ -0,0 +1,57 @@ +"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +NUM_EXPECTED_TOKENS = 10 +NUM_REQUESTS = 10000 + +# Scenarios to test for num generated token. +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_load(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks = [] + for request_id in request_ids: + tasks.append( + asyncio.create_task( + generate(client, request_id, NUM_EXPECTED_TOKENS))) + + # Confirm that we got all the EXPECTED tokens from the requests. + failed_request_id = None + tokens = None + for task in tasks: + num_generated_tokens, request_id = await task + if (num_generated_tokens != NUM_EXPECTED_TOKENS + and failed_request_id is None): + failed_request_id = request_id + tokens = num_generated_tokens + + assert failed_request_id is None, ( + f"{failed_request_id} generated {tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + # Shutdown. + client.close() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py new file mode 100644 index 0000000000000..e27fd77923412 --- /dev/null +++ b/tests/mq_llm_engine/utils.py @@ -0,0 +1,78 @@ +import asyncio +import multiprocessing +from typing import Callable, Tuple, Union + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + + +async def generate( + client: MQLLMEngineClient, + request_id: str, + num_tokens: int, + return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]: + + final_output = None + count = 0 + async for out in client.generate( + request_id=request_id, + inputs="Hello my name is Robert and", + sampling_params=SamplingParams(max_tokens=num_tokens, + temperature=0)): + + count += 1 + final_output = out + await asyncio.sleep(0.) + + if return_output: + return final_output + + # Confirm we generated all the tokens we expected. 
+ return count, request_id + + +def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Run engine. + engine.start() + + +class RemoteMQLLMEngine: + + def __init__(self, + engine_args: AsyncEngineArgs, + ipc_path: str, + run_fn: Callable = run_normal) -> None: + + self.engine_args = engine_args + self.ipc_path = ipc_path + context = multiprocessing.get_context("spawn") + self.proc = context.Process(target=run_fn, + args=(engine_args, ipc_path)) + self.proc.start() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.kill() + + async def make_client(self) -> MQLLMEngineClient: + engine_config = self.engine_args.create_engine_config() + client = MQLLMEngineClient(self.ipc_path, engine_config) + while True: + try: + await client.setup() + break + except TimeoutError: + assert self.proc.is_alive() + return client diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 7f3fb595321ad..69ab67abdd12b 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,5 +1,12 @@ +import os + from ..utils import compare_two_settings +# --enforce-eager on TPU causes graph compilation +# this times out default Health Check in the MQLLMEngine, +# so we set the timeout here to 30s +os.environ["VLLM_RPC_TIMEOUT"] = "30000" + def test_custom_dispatcher(): compare_two_settings("google/gemma-2b", diff --git a/tests/utils.py b/tests/utils.py index f6c2be17ebdcf..81442cad78da2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -119,7 +119,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() try: - self.proc.wait(3) + self.proc.wait(8) except subprocess.TimeoutExpired: # force kill if needed self.proc.kill() diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 82cdd41ad497e..34e7e05341f02 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -601,9 +601,12 @@ def errored(self) -> bool: return self._errored_with is not None @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") def set_errored(self, exc: Exception) -> None: self._errored_with = exc diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bdf1af014342a..2743d5c7d2282 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1289,6 +1289,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # torch.distributed ops which may otherwise timeout, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. 
+ logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() return ctx.request_outputs diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000000000..ba5c6e15fc821 --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCHealthRequest: + pass + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, + RPCStartupRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. 
Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py new file mode 100644 index 0000000000000..18b620c74ddf9 --- /dev/null +++ b/vllm/engine/multiprocessing/client.py @@ -0,0 +1,452 @@ +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient: + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: EngineConfig): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + ) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for ack of check_health requests. + self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + if engine_args.pipeline_parallel_size > 1: + return True + + is_embedding = ModelConfig( + model=engine_args.model, + revision=engine_args.revision, + tokenizer=engine_args.model, + tokenizer_mode="auto", + trust_remote_code=engine_args.trust_remote_code, + quantization=engine_args.quantization, + seed=0, + dtype="auto").embedding_mode + + return is_embedding + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_check_health_loop(self, timeout: int): + """Background loop that continually probes the RPCServer for health. + + The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which + the MQLLMEngine server is blocking on. + + The Server replies on the HEALTH_SOCKET (rather than on the + OUTPUT_SOCKET such that the messages are not intermingled with + output streaming). + """ + + try: + while True: + if await self.health_socket.poll(timeout=timeout) == 0: + # Wakeup every N seconds and do a health probe. + await self._send_one_way_rpc_request( + RPCHealthRequest(), self.input_socket) + + # Wait for ack from the health socket. + await self._await_ack(error_message="Health check failed.", + socket=self.health_socket) + else: + # Server sent a health status message unprompted. + await self._check_success( + error_message="Health check failed.", + socket=self.health_socket) + + logger.debug("Health probe successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. 
+ if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + + if request_id is None: + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + else: + # Put each output into the appropriate steam. + for request_output in request_outputs: + queue = self.output_queues.get( + request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. 
+ frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats(self): + """Ignore do_log_stats (handled on MQLLMEngine polling)""" + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. + """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # 1) Create output queue for this requests. 
+ queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if sampling_params.logits_processors: + # Defensive shallow copy + sampling_params = copy.copy(sampling_params) + logits_processors = sampling_params.logits_processors + sampling_params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + + # 3) Send the RPCGenerateRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py new file mode 100644 index 0000000000000..70cd6e5cb6000 --- /dev/null +++ b/vllm/engine/multiprocessing/engine.py @@ -0,0 +1,321 @@ +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import AsyncEngineArgs, LLMEngine +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`LLMEngine.generate` is kicked off when a new + RPCGenerateRequest is received by the input_socket. 
+ + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + self.engine = LLMEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + executor_class = LLMEngine._get_executor_cls(engine_config) + + return cls( + ipc_path=ipc_path, + use_async_sockets=engine_config.model_config.use_async_output_proc, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
+ self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + + try: + return self.engine.step() + except SystemExit: + raise + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCGenerateRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + lprocs = cloudpickle.loads(frames[1].buffer) + request.sampling_params.logits_processors = lprocs + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCHealthRequest): + self._handle_health_request() + else: + raise ValueError("Unknown RPCRequest Type: {request}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + def _handle_generate_request(self, request: RPCGenerateRequest): + """Handle RPCGenerateRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. 
+ is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _handle_health_request(self): + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + + # Raises error if unhealthy. + self.engine.check_health() + self._send_healthy() + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send List of RequestOutput to RPCClient.""" + if outputs: + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + error_bytes = pickle.dumps(error) + self.health_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, + ipc_path: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm + raise KeyboardInterrupt("MQLLMEngine terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + engine.start() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 34ae79f5fa8df..70444faa670a2 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -14,8 +14,8 @@ @runtime_checkable -class AsyncEngineClient(Protocol): - """Protocol class for Clients to AsyncLLMEngine""" +class EngineClient(Protocol): + """Protocol class for Clients to Engine""" @property def is_running(self) -> bool: @@ -30,8 +30,8 @@ def errored(self) -> bool: ... @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" + def dead_error(self) -> BaseException: + ... 
def generate( self, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 47d227010c075..5dcf50bd1b0a1 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -1,21 +1,21 @@ import asyncio import signal from http import HTTPStatus -from typing import Any, Optional +from typing import Any import uvicorn from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.multiprocessing import MQEngineDeadError from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, limit_concurrency: Optional[int], - **uvicorn_kwargs: Any): +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -26,15 +26,6 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int], logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency %s. To avoid this " - "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = limit_concurrency - config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) _add_shutdown_handlers(app, server) @@ -63,7 +54,7 @@ async def dummy_shutdown() -> None: logger.debug( "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) - logger.info("Gracefully stopping http server") + logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() @@ -90,7 +81,7 @@ async def runtime_error_handler(request: Request, __): return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) @app.exception_handler(AsyncEngineDeadError) - async def engine_dead_handler(_, __): + async def async_engine_dead_handler(_, __): """Kill the server if the async engine is already dead. It will not handle any further requests.""" if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: @@ -99,3 +90,14 @@ async def engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. 
It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("MQLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b891debfd2b91..1b9eb30252417 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -26,7 +26,9 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -44,8 +46,6 @@ TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -67,29 +67,16 @@ _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str], - revision: Optional[str]) -> bool: - return ModelConfig(model=model_name, - revision=revision, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - quantization=quantization, - seed=0, - dtype="auto").embedding_mode - - @asynccontextmanager async def lifespan(app: FastAPI): try: if app.state.log_stats: - async_engine_client = app.state.engine_client + engine_client: EngineClient = app.state.engine_client async def _force_log(): while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() + await asyncio.sleep(10.) + await engine_client.do_log_stats() task = asyncio.create_task(_force_log()) _running_tasks.add(task) @@ -108,9 +95,9 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + args: Namespace) -> AsyncIterator[Optional[EngineClient]]: - # Context manager to handle async_engine_client lifecycle + # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit engine_args = AsyncEngineArgs.from_cli_args(args) @@ -123,19 +110,18 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[AsyncEngineClient]]: +) -> AsyncIterator[Optional[EngineClient]]: """ - Create AsyncEngineClient, either: + Create EngineClient, either: - in-process using the AsyncLLMEngine Directly - multiprocess using AsyncLLMEngine RPC Returns the Client or None if the creation failed. """ - # If manually triggered or embedding model, use AsyncLLMEngine in process. - # TODO: support embedding model via RPC. - if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization, engine_args.revision) + # Fall back + # TODO: fill out feature matrix. 
+ if (MQLLMEngineClient.is_unsupported_config(engine_args) or disable_frontend_multiprocessing): engine_config = engine_args.create_engine_config() uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), @@ -173,56 +159,60 @@ async def build_async_engine_client_from_engine_args( "and vLLM will properly handle cleanup.") # Select random path for IPC. - rpc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for RPC Path.", - rpc_path) - - # Build RPCClient, which conforms to AsyncEngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) - # Start RPCServer in separate process (holds the AsyncLLMEngine). - context = multiprocessing.get_context("spawn") + # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) - rpc_server_process.start() - logger.info("Started engine process with PID %d", - rpc_server_process.pid) + context = multiprocessing.get_context("spawn") + + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) + engine_process.start() + logger.info("Started engine process with PID %d", engine_process.pid) + + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) try: while True: try: - await rpc_client.setup() + await mp_engine_client.setup() break except TimeoutError: - if not rpc_server_process.is_alive(): - logger.error( - "RPCServer process died before responding " - "to readiness probe") + if not engine_process.is_alive(): + logger.error("Engine process died before responding " + "to readiness probe") yield None return - yield rpc_client # type: ignore[misc] + yield mp_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - rpc_client.close() + mp_engine_client.close() - # Wait for server process to join - rpc_server_process.join() + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. 
# See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) + multiprocess.mark_process_dead(engine_process.pid) router = APIRouter() @@ -270,7 +260,7 @@ def embedding(request: Request) -> OpenAIServingEmbedding: return request.app.state.openai_serving_embedding -def engine_client(request: Request) -> AsyncEngineClient: +def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client @@ -473,7 +463,7 @@ async def authentication(request: Request, call_next): def init_app_state( - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, state: State, args: Namespace, @@ -488,11 +478,11 @@ def init_app_state( else: request_logger = RequestLogger(max_log_len=args.max_log_len) - state.engine_client = async_engine_client + state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.openai_serving_chat = OpenAIServingChat( - async_engine_client, + engine_client, model_config, served_model_names, args.response_role, @@ -504,7 +494,7 @@ def init_app_state( enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser) state.openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -513,13 +503,13 @@ def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) state.openai_serving_embedding = OpenAIServingEmbedding( - async_engine_client, + engine_client, model_config, served_model_names, request_logger=request_logger, ) state.openai_serving_tokenization = OpenAIServingTokenization( - async_engine_client, + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -541,21 +531,20 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - async with build_async_engine_client(args) as async_engine_client: + async with build_async_engine_client(args) as engine_client: # If None, creation of the client failed and we exit. - if async_engine_client is None: + if engine_client is None: return app = build_app(args) - model_config = await async_engine_client.get_model_config() - init_app_state(async_engine_client, model_config, app.state, args) + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) temp_socket.close() shutdown_task = await serve_http( app, - limit_concurrency=async_engine_client.limit_concurrency, host=args.host, port=args.port, log_level=args.uvicorn_log_level, diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py deleted file mode 100644 index efc7e43afdcc9..0000000000000 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Mapping, Optional, Union - -from vllm.inputs import PromptInputs -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams - -# Success string used for RPC instructions. -VLLM_RPC_SUCCESS_STR = "SUCCESS" - -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 - -# HWM is set to Infinity. 
-VLLM_RPC_ZMQ_HWM = 0 - - -@dataclass -class RPCGenerateRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCUtilityRequest(Enum): - IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - - -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py deleted file mode 100644 index 9b88db746be5c..0000000000000 --- a/vllm/entrypoints/openai/rpc/client.py +++ /dev/null @@ -1,451 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Mapping, Optional -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -# yapf: disable -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SOCKET_LIMIT_CUTOFF, - VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -# yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS -from vllm.inputs import PromptInputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -logger = init_logger(__name__) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. 
- - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS - self._errored = False - - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - assert isinstance(socket_limit, int) - if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests vLLM can process. Launch " - "vLLM with --disable-frontend-multiprocessing and open a " - "GitHub issue so we can investigate.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - - # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - - # In process proxy to RPC Server (uses memory-based messaging). 
- self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - - # Asyncio background task for the proxy. - self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in vLLM w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 - - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" - while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() - self.context.destroy() - - @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(VLLM_RPC_ZMQ_HWM) - try: - socket.connect(INPROC_PROXY_PATH) - yield socket - finally: - socket.close(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.to_proxy_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart((cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - # Await the data from the Server. 
- frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data - - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): - """Send one-way RPC request to trigger an action.""" - - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) - - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - return pickle.loads(frame.buffer) - - # Make a new socket connection. - if socket is None: - with self.to_proxy_socket() as socket: - response = await do_rpc_call(socket, request) - - # Use existing socket connection. - else: - response = await do_rpc_call(socket, request) - - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, Exception): - logger.error(error_message) - raise response - raise ValueError(error_message) - - async def get_tokenizer(self, lora_request: LoRARequest): - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self): - """Wait for the RPCServer to start up.""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _get_model_config_rpc(self) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") - - async def _get_decoding_config_rpc(self) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") - - async def _get_parallel_config_rpc(self) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") - - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server") - - async def _get_lora_config_rpc(self) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") - - async def 
_is_tracing_enabled_rpc(self) -> bool: - """Get is_tracing_enabled flag from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, - expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") - - async def abort(self, request_id: str): - """Send an ABORT_REQUEST signal to the RPC Server""" - - # Suppress timeouts as well. - # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") - - async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") - - @property - def is_running(self) -> bool: - return not self._errored - - @property - def is_stopped(self) -> bool: - return self._errored - - @property - def errored(self) -> bool: - return self._errored - - async def generate( - self, - inputs: PromptInputs, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - finished = False - try: - with self.to_proxy_socket() as socket: - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - # Stream back the results from the RPC Server. - while not finished: - message = await socket.recv(copy=False) - assert isinstance(message, Frame) - request_output = pickle.loads(message.buffer) - - if isinstance(request_output, Exception): - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health(socket=socket) - except Exception as e: - self._errored = True - logger.exception(repr(e)) - - # NB: do before raising here so that the flag is set - # by the time the caller receives this exception - raise request_output - - finished = request_output.finished - yield request_output - - finally: - # Request was canceled by the client. 
- if not finished and not self._errored: - await self.abort(request_id) - - async def check_health(self, socket: Optional[Socket] = None) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py deleted file mode 100644 index 460ff0636b6e9..0000000000000 --- a/vllm/entrypoints/openai/rpc/server.py +++ /dev/null @@ -1,243 +0,0 @@ -import asyncio -import pickle -import signal -from typing import Any, Coroutine, Union - -import cloudpickle -import uvloop -import zmq -import zmq.asyncio -from typing_extensions import Never -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger(__name__) - -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args( - async_engine_args, usage_context=usage_context) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.DEALER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - # Clear the engine reference so that it can be GC'ed. 
- del self.engine - - async def get_config(self, identity, request): - try: - config: CONFIG_TYPE - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - config = await self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - config = await self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - config = await self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - config = await self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - config = await self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - - await self.socket.send_multipart((identity, pickle.dumps(config)), - copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def is_tracing_enabled(self, identity): - """Send the is_tracing_enabled flag""" - tracing_flag = await self.engine.is_tracing_enabled() - - await self.socket.send_multipart( - (identity, pickle.dumps(tracing_flag))) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def abort(self, identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - try: - # Abort the request in the llm engine. - await self.engine.abort(request.request_id) - result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR - except Exception as e: - result = e - await self.socket.send_multipart((identity, pickle.dumps(result))) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - try: - results_generator = self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - trace_headers=generate_request.trace_headers, - prompt_adapter_request=generate_request.prompt_adapter_request) - - async for request_output in results_generator: - await self.socket.send_multipart( - (identity, pickle.dumps(request_output)), copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def start_profile(self, identity): - logger.info("Starting profiler...") - await self.engine.start_profile() - logger.info("Profiler started.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - async def stop_profile(self, identity): - logger.info("Stopping profiler...") - await self.engine.stop_profile() - logger.info("Profiler stopped.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - def _make_handler_coro(self, identity, - message: Frame) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCGenerateRequest): - return 
self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - return self.get_config(identity, request) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: - return self.check_health(identity) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - return self.is_tracing_enabled(identity) - elif request == RPCUtilityRequest.START_PROFILE: - return self.start_profile(identity) - elif request == RPCUtilityRequest.STOP_PROFILE: - return self.stop_profile(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") - - async def run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart(copy=False) - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - -async def run_server(server: AsyncEngineRPCServer): - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. - def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("vLLM ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. 
- server.cleanup() - - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - - def signal_handler(*_) -> None: - # Interrupt server on sigterm while initializing - raise KeyboardInterrupt("AsyncEngineRPCServer terminated") - - signal.signal(signal.SIGTERM, signal_handler) - - server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) - uvloop.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d28362a12abdb..b84898dc39b0f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,7 +9,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, apply_mistral_chat_template, @@ -45,7 +45,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -57,7 +57,7 @@ def __init__(self, return_tokens_as_token_ids: bool = False, enable_auto_tools: bool = False, tool_parser: Optional[str] = None): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -105,6 +105,12 @@ async def create_chat_completion( logger.error("Error with model %s", error_check_ret) return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + try: ( lora_request, @@ -112,8 +118,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) @@ -207,8 +212,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -216,7 +221,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.async_engine_client.generate( + result_generator = self.engine_client.generate( engine_inputs, sampling_params, request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 42142efb5f23e..14fa60243c584 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -52,7 +52,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -78,6 +78,12 @@ async def create_completion( if error_check_ret is not None: return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + # Return error for unsupported features. 
if request.suffix is not None: return self.create_error_response( @@ -95,8 +101,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -124,8 +129,8 @@ async def create_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -133,7 +138,7 @@ async def create_completion( raw_request.headers): log_tracing_disabled_warning() - generator = self.async_engine_client.generate( + generator = self.engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 12ec6be03cd62..f111a3a8277b5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -8,7 +8,7 @@ from typing_extensions import assert_never from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -118,8 +118,7 @@ async def create_embedding( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() @@ -144,7 +143,7 @@ async def create_embedding( "Prompt adapter is not supported " "for embedding models") - generator = self.async_engine_client.encode( + generator = self.engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ac74527441cd9..72f9381abc7db 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -64,7 +64,7 @@ class OpenAIServing: def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -75,7 +75,7 @@ def __init__( ): super().__init__() - self.async_engine_client = async_engine_client + self.engine_client = 
engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -159,7 +159,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.async_engine_client.get_decoding_config() + decoding_config = await self.engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 6e802b71ae2b4..8f8862897fc4e 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (apply_hf_chat_template, apply_mistral_chat_template, load_chat_template, @@ -29,7 +29,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -37,7 +37,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -66,7 +66,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) prompt: Union[str, List[int]] if isinstance(request, TokenizeChatRequest): @@ -132,7 +132,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index 6edb06ecd2e20..43c7aa8af85b2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,7 +57,7 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False - VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 + VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False @@ -393,8 +393,8 @@ def get_default_config_root(): # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_GET_DATA_TIMEOUT_MS": - lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), # a list of plugin names to load, separated by commas. 
# if this is not set, it means all plugins will be loaded diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 7380b73ad6548..9ad240ef60820 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -106,6 +106,7 @@ def _init_executor(self) -> None: )) for rank in range(1, world_size) ] + self.worker_monitor = None if world_size != 1 or is_async: if is_async: async_worker_list = self.workers + [self.driver_worker] diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index aa2a16c04d08d..5bef76b90d332 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -168,6 +168,8 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], self.tasks[task_id] = future try: self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise except BaseException as e: del self.tasks[task_id] raise ChildProcessError("worker died") from e @@ -222,6 +224,8 @@ def _run_worker_process( try: executor = getattr(worker, method) output = executor(*args, **kwargs) + except SystemExit: + raise except KeyboardInterrupt: break except BaseException as e: From a8c1d161a7d87dbc6c7cccfce303dcbe2e4ed6be Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:38:43 -0400 Subject: [PATCH 030/116] [Core] *Prompt* logprobs support in Multi-step (#8199) --- tests/conftest.py | 84 +++++++++++------- tests/models/utils.py | 108 +++++++++++++++++++++-- tests/multi_step/test_correctness_llm.py | 92 +++++++++++++++++++ tests/utils.py | 3 +- vllm/worker/multi_step_model_runner.py | 72 ++++++++++----- 5 files changed, 300 insertions(+), 59 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e9c7fc7bf9c67..c2616bcf7091c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,8 @@ BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass +from tests.models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs) from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -33,7 +35,6 @@ to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput -from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, identity, is_cpu) @@ -469,7 +470,7 @@ def generate_greedy_logprobs_limit( audios: Optional[PromptAudioInput] = None, videos: Optional[List[np.ndarray]] = None, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] @@ -525,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( max_tokens: int, num_logprobs: int, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: ''' Greedy logprobs generation for vLLM encoder/decoder models ''' @@ -653,14 +654,16 @@ def generate( @staticmethod def _final_steps_generate_w_logprobs( req_outputs: List[RequestOutput], - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + ) -> List[TokensTextLogprobsPromptLogprobs]: + outputs: List[TokensTextLogprobsPromptLogprobs] = [] for req_output in req_outputs: + assert 
len(req_output.outputs) > 0 for sample in req_output.outputs: output_str = sample.text output_ids = list(sample.token_ids) output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs)) + outputs.append((output_ids, output_str, output_logprobs, + req_output.prompt_logprobs)) return outputs def generate_w_logprobs( @@ -670,7 +673,8 @@ def generate_w_logprobs( images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: assert sampling_params.logprobs is not None if images is not None: @@ -695,13 +699,20 @@ def generate_w_logprobs( req_outputs = self.model.generate(inputs, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_encoder_decoder_w_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: SamplingParams, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: ''' Logprobs generation for vLLM encoder/decoder models ''' @@ -709,7 +720,12 @@ def generate_encoder_decoder_w_logprobs( assert sampling_params.logprobs is not None req_outputs = self.model.generate(encoder_decoder_prompts, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_greedy( self, @@ -727,44 +743,48 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, + num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, stop_token_ids: Optional[List[int]] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - logprobs=num_logprobs, - stop_token_ids=stop_token_ids) - outputs = self.generate_w_logprobs(prompts, - greedy_logprobs_params, - images=images, - audios=audios, - videos=videos) - - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + stop_token_ids=stop_token_ids) + + return self.generate_w_logprobs(prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos) def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - 
greedy_logprobs_params = SamplingParams(temperature=0.0, - use_beam_search=False, - max_tokens=max_tokens, - logprobs=num_logprobs) + num_prompt_logprobs: Optional[int] = None, + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + use_beam_search=False, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + ) ''' Greedy logprobs generation for vLLM encoder/decoder models ''' - outputs = self.generate_encoder_decoder_w_logprobs( + return self.generate_encoder_decoder_w_logprobs( encoder_decoder_prompts, greedy_logprobs_params) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] - def generate_beam_search( self, prompts: List[str], diff --git a/tests/models/utils.py b/tests/models/utils.py index 93ec03995094b..8e31a1d6eefed 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import Logprob, SampleLogprobs +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -34,20 +34,47 @@ def check_outputs_equal( assert output_ids_0 == output_ids_1, fail_msg +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * List of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]]] -# Allow for tokens to be represented as str's rather than IDs +# Allow for tokens to be represented as str's rather than IDs; +# tuple of +# * Token string representations list +# * String +# * Optional list of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], List[Dict[str, Logprob]]]]] +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * Optional list of top sample logprobs for each sampled token +# * Optional list of top prompt logprobs for each prompt token +# +# Allows prompt logprobs to be requested. +TokensTextLogprobsPromptLogprobs = Tuple[ + List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]], + Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]] + def check_logprobs_close( *, - outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], - outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], + outputs_0_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], + outputs_1_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, @@ -57,6 +84,18 @@ def check_logprobs_close( """Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. + How sample logprobs are compared: + * `always_check_logprobs == True`: set of highest-logprob token ids + must match between seq0 and seq1 at all sampled token offsets + * `always_check_logprobs == False`: highest-logprob token ids are + only compared at sampled token offsets for which generated token + ids don't match + + Prompt logprobs must be provided either for both input sequences, or + for neither. 
If prompt logprobs are provided, then highest-logprob + prompt token ids must match between seq0 and seq1 at all prompt token + offsets. + Args: outputs_0_lst: First sequence to compare outputs_0_lst: Second sequence to compare @@ -78,8 +117,65 @@ def check_logprobs_close( for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): - output_ids_0, output_str_0, logprobs_0 = outputs_0 - output_ids_1, output_str_1, logprobs_1 = outputs_1 + assert len(outputs_0) == len(outputs_1) + if len(outputs_0) == 3: + assert len(outputs_1) == 3 + # Break out tokens, text & sample logprobs + # (prompt logprobs were not provided) + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + elif len(outputs_0) == 4: + assert len(outputs_1) == 4 + # Break out tokens, text, sample logprobs & prompt logprobs + ( + output_ids_0, + output_str_0, + logprobs_0, + prompt_logprobs_0, + ) = outputs_0 + ( + output_ids_1, + output_str_1, + logprobs_1, + prompt_logprobs_1, + ) = outputs_1 + + # Test prompt logprobs closeness + if (prompt_logprobs_0 is not None + and prompt_logprobs_1 is not None): + # Both sequences' prompt logprobs lists are not `None`` + # (although individual list elements may be `None`); + # for each token's logprobs: + for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate( + zip(prompt_logprobs_0, prompt_logprobs_1)): + fail_msg = ( + f"Prompt logprobs test:" + f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}" + f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}") + + if logprobs_elem_0 is None: + # If the seq 0 token's logprobs are `None`, + # the seq 1 token's logprobs must be `None` + assert logprobs_elem_1 is None, fail_msg + else: + # If the seq 0 token's logprobs are not `None`, + # the seq 1 token's logprobs must not be `None` + assert logprobs_elem_1 is not None, fail_msg + # Logprobs check: top-k token choices must be the same + assert (set(logprobs_elem_0.keys()) == set( + logprobs_elem_1.keys())), fail_msg + else: + # Both sequence logprobs lists must be `None` + fail_msg = (f"Prompt logprobs test:" + f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" + f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}") + + assert (prompt_logprobs_0 is None + and prompt_logprobs_1 is None), fail_msg + else: + raise ValueError(f"Outputs tuple must have 3 or 4 elements but " + f"{len(outputs_0)} elements were provided: " + f"{outputs_0}") if logprobs_0 is None: logprobs_0 = [None] * len(output_ids_0) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 24ebb60a9cbfd..c5dc81cc25622 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -100,3 +100,95 @@ def test_multi_step_llm( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) +def test_multi_step_llm_w_prompt_logprobs( + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], +) -> None: + """Test prompt 
logprobs with multi-step scheduling via sync LLM Engine. + + Set up a vLLM engine instance w/ single-step scheduling as a ground-truth + reference. + + Prompt them with the same example prompts. + + Validate: + * All generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + num_prompt_logprobs: number of logprobs to return for each prompt token; + note that this argument is not supported by the + OpenAI completions endpoint. + """ + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + ) as vllm_model: + single_step_vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + check_logprobs_close( + outputs_0_lst=single_step_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/utils.py b/tests/utils.py index 81442cad78da2..43825e8138362 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -493,6 +493,7 @@ async def completions_with_server_args( ''' outputs = None + max_wait_seconds = 240 * 3 # 240 is default with RemoteOpenAIServer(model_name, server_cli_args, max_wait_seconds=max_wait_seconds) as server: @@ -503,7 +504,7 @@ async def completions_with_server_args( stream=False, max_tokens=5, logprobs=num_logprobs) - assert outputs is not None + assert outputs is not None, "Completion API call failed." return outputs diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b900eb5a610ff..ebcafbbab119a 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -614,34 +614,66 @@ def _pythonize_sampler_output( frozen_model_input = model_input.frozen_model_input assert frozen_model_input.sampling_metadata is not None + sampling_metadata = frozen_model_input.sampling_metadata # samples generation should have been skipped assert not output.outputs pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] - # CPU GPU sync - pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + # We guarantee output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However we should check whether logprobs pythonization may + # be skipped entirely, i.e. 
because no logprobs were requested + # or pythonization was not deferred. To that end, + # + # * `prompt_logprobs_are_requested_for_prefill` signals that + # there are *any* prefill-phase requests which specify that + # prompt logprobs should be returned. + # + # * `any_logprobs_are_requested` signals that there are any + # requests which (1) specify that sample logprobs should be + # returned, or (2) are in the prefill phase AND specify that + # prompt logprobs should be returned. + # + # Later on, these flags cause adjustments to the pythonization + # process to accommodate logprobs. + + seq_groups = sampling_metadata.seq_groups + prompt_logprobs_are_requested_for_prefill = any([ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ]) + any_logprobs_are_requested = ( + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + + if prompt_logprobs_are_requested_for_prefill: + # CPU GPU sync, after gathering *only* sampled tokens (since + # requesting prompt logprobs leads `sampled_token_ids` to + # include prompt token ids in addition to sampled token ids.) + sample_idx_tensor = torch.tensor( + [sdx for sg in seq_groups for sdx in sg.sample_indices]) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + else: + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, + non_blocking=False) # this will not block as the tensors are already on CPU samples_list = pinned_buffer.tolist() - sampling_metadata = frozen_model_input.sampling_metadata - skip_sampler_cpu_output = ( frozen_model_input.sampling_metadata.skip_sampler_cpu_output) - # We are guaranteed output tensors are ready, so it is safe to - # pythonize the sampler output & obtain CPU-side logprobs. - # - # However this computation may be skipped entirely - # if no pythonization was deferred. - seq_groups = sampling_metadata.seq_groups - logprobs_are_requested = any([ - sg.sampling_params.logprobs is not None - or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups - ]) + # *Don't* skip logprobs pythonization *if*: + # * Any requests require logprobs to be returned in this + # iteration AND + # * These requests are being scheduled in a fashion which + # defers pythonization (i.e. multi-step scheduling.) 
do_pythonize_logprobs = (skip_sampler_cpu_output - and logprobs_are_requested) + and any_logprobs_are_requested) ( prompt_logprobs, sample_logprobs, @@ -666,7 +698,7 @@ def _pythonize_sampler_output( prompt_logprobs[sgdx], sample_logprobs[sgdx], ) - elif logprobs_are_requested: + elif any_logprobs_are_requested: ( group_prompt_logprobs, group_sample_logprobs, @@ -696,7 +728,7 @@ def _pythonize_sampler_output( seq_output.parent_seq_id = seq_ids[parent_id] seq_output.output_token = next_token_id - if logprobs_are_requested: + if any_logprobs_are_requested: seq_output.logprobs = group_sample_logprobs[tdx] else: logprobs = next(iter(seq_output.logprobs.values())) @@ -714,7 +746,7 @@ def _pythonize_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, (group_sample_logprobs[tdx] - if logprobs_are_requested else { + if any_logprobs_are_requested else { next_token_id: Logprob(logprob=float('inf'), rank=None, @@ -722,12 +754,12 @@ def _pythonize_sampler_output( }))) if cache is not None: completion_seq_group_output.prompt_logprobs = \ - group_prompt_logprobs if logprobs_are_requested else None + group_prompt_logprobs if any_logprobs_are_requested else None output.outputs.append(completion_seq_group_output) else: output.outputs.append( CompletionSequenceGroupOutput( seq_outputs, (group_prompt_logprobs - if logprobs_are_requested else None))) + if any_logprobs_are_requested else None))) assert len(output.outputs) > 0 From d65798f78c76f03f068fc2f69a68cff430ee6b6f Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 18 Sep 2024 12:10:27 -0400 Subject: [PATCH 031/116] [Core] zmq: bind only to 127.0.0.1 for local-only usage (#8543) Signed-off-by: Russell Bryant --- .../device_communicators/shm_broadcast.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index d4847542688c0..b507cd2e1cddb 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -196,7 +196,9 @@ def __init__( # see http://api.zeromq.org/3-3:zmq-setsockopt for more details self.local_socket.setsockopt(XPUB_VERBOSE, True) local_subscribe_port = get_open_port() - self.local_socket.bind(f"tcp://*:{local_subscribe_port}") + socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" + logger.debug("Binding to %s", socket_addr) + self.local_socket.bind(socket_addr) self.current_idx = 0 @@ -212,7 +214,8 @@ def __init__( self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() - self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") + socket_addr = f"tcp://*:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) else: remote_subscribe_port = None @@ -255,8 +258,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket = context.socket(SUB) self.local_socket.setsockopt_string(SUBSCRIBE, "") - self.local_socket.connect( - f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") + socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) self.remote_socket = None else: @@ -270,8 +274,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") - self.remote_socket.connect( - 
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") + socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) return self From e18749ff09c277f7cdab278895ebdd9b1041b6e8 Mon Sep 17 00:00:00 2001 From: "Geun, Lim" Date: Thu, 19 Sep 2024 02:04:00 +0900 Subject: [PATCH 032/116] [Model] Support Solar Model (#8386) Co-authored-by: Michael Goin --- docs/source/models/supported_models.rst | 4 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/solar.py | 580 ++++++++++++++++++++ vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/solar.py | 245 +++++++++ 6 files changed, 834 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/solar.py create mode 100644 vllm/transformers_utils/configs/solar.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 3dcc242803752..745b4b8e2e0eb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -179,6 +179,10 @@ Decoder-only Language Models - Starcoder2 - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. - + * - :code:`SolarForCausalLM` + - EXAONE-3 + - :code:`upstage/solar-pro-preview-instruct`, etc. + - * - :code:`XverseForCausalLM` - Xverse - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 41c8e754377c7..591007e787f47 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -60,6 +60,7 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "SolarForCausalLM": ("solar", "SolarForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py new file mode 100644 index 0000000000000..16e576d0ac29c --- /dev/null +++ b/vllm/model_executor/models/solar.py @@ -0,0 +1,580 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Solar model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.utils import (PPMissingLayer, + is_pp_missing_parameter, + make_layers) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + + +class SolarMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SolarAttention(nn.Module): + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
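+            # e.g. 8 total KV heads with tp_size=2 give each rank 4 KV heads;
+            # every rank must own a whole number of heads, hence this check.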
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class SolarDecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] \ + = config.original_max_position_embeddings + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = SolarAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = SolarMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + 
eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class SolarModel(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SolarDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + bskcn_h_1 = None + bskcn_h_2 = None + bskcn_r_1 = None + bskcn_r_2 = None + bskcn_tv = (self.config.bskcn_tv[0] + if self.training else self.config.bskcn_tv[1]) + + for i in range(self.start_layer, self.end_layer): + if i in self.config.bskcn_1: + bskcn_h_1 = hidden_states.clone() + bskcn_r_1 = residual.clone() + if i in self.config.bskcn_2: + bskcn_h_2 = hidden_states.clone() + bskcn_r_2 = residual.clone() + if i in self.config.bskcn_3: + hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv) + if i in self.config.bskcn_4: + hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv) + layer = self.layers[i] + hidden_states, residual = 
layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class SolarForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = SolarModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, 
shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3c269bc10cdf8..1744935d624fb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -24,7 +24,7 @@ JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig, - UltravoxConfig) + SolarConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -50,6 +50,7 @@ "exaone": ExaoneConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, + "solar": SolarConfig, "ultravox": UltravoxConfig, # Granite can be removed from here once we have upgraded to # transformers 4.45+ diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8381c5227584e..ea4fc8ad21f35 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -13,6 +13,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -27,6 +28,7 @@ "ExaoneConfig", "MLPSpeculatorConfig", "NemotronConfig", + "SolarConfig", "UltravoxConfig", # Granite can be removed from here once we have upgraded to # transformers 4.45+ diff --git a/vllm/transformers_utils/configs/solar.py b/vllm/transformers_utils/configs/solar.py new file mode 100644 index 0000000000000..d5113bf01695a --- /dev/null +++ b/vllm/transformers_utils/configs/solar.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Solar model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class SolarConfig(PretrainedConfig): + r""" + This is the configuration class to store + the configuration of a [`SolarModel`]. + It is used to instantiate an LLaMA model + according to the specified arguments, + defining the model architecture. + Instantiating a configuration with the + defaults will yield a similar + configuration to that of the LLaMA-7B. + Configuration objects inherit from [`PretrainedConfig`] + and can be used to control the model outputs. + Read the documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. + Defines the number of different tokens + that can be represented by the `inputs_ids` + passed when calling [`SolarModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer + in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that + should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, + the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) + otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed + by meanpooling all the original heads within that group. + For more details checkout [this paper] + (https://arxiv.org/pdf/2305.13245.pdf). + If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) + in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Solar 1 supports up to 2048 tokens, + Solar 2 up to 4096, CodeSolar up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of + the truncated_normal_initializer for initializing + all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return + the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank + used during pretraining. + Please refer to [this + document](https://huggingface.co/docs/ + transformers/main/ + perf_train_gpu_many#tensor-parallelism) + to understand more about it. 
This value is + necessary to ensure exact reproducibility + of the pretraining results. + Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for + the RoPE embeddings. + Currently supports two scaling + strategies: linear and dynamic. + Their scaling factor must be a float greater than 1. + The expected format is + `{"type": strategy name, "factor": scaling factor}`. + When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/ + dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking + API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value + and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj + layers in the MLP layers. + sliding_window (`int`, *optional*, defaults to 2047): + Sliding window attention window size. If not specified, + will default to `2047`. + ```python + >>> from transformers import SolarModel, SolarConfig + >>> # Initializing a Solar-pro style configuration + >>> configuration = SolarConfig() + >>> # Initializing a model from the Solar-pro style configuration + >>> model = SolarModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "solar" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=2047, + bskcn_1=None, + bskcn_2=None, + bskcn_3=None, + bskcn_4=None, + bskcn_tv=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.sliding_window = sliding_window + self.bskcn_1 = bskcn_1 if 
bskcn_1 is not None else [12, 20, 32, 44] + self.bskcn_2 = bskcn_2 if bskcn_2 is not None else [20, 32] + self.bskcn_3 = bskcn_3 if bskcn_3 is not None else [16, 24, 36, 48] + self.bskcn_4 = bskcn_4 if bskcn_4 is not None else [28, 40] + self.bskcn_tv = bskcn_tv if bskcn_tv is not None else [0.9, 0.8] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if (not isinstance(self.rope_scaling, dict) + or len(self.rope_scaling) != 2): + raise ValueError( + "`rope_scaling` must be a dictionary with two fields," + " `type` and `factor`, " + f"got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", + "dynamic", + ]: + raise ValueError(f"`rope_scaling`'s type field must be one of " + f"['linear', 'dynamic'], got {rope_scaling_type}") + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1," + f" got {rope_scaling_factor}") From b3195bc9e4d57b6107af2222afea26c51475e262 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 18 Sep 2024 13:41:08 -0400 Subject: [PATCH 033/116] [AMD][ROCm]Quantization methods on ROCm; Fix _scaled_mm call (#8380) Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Co-authored-by: Michael Goin --- vllm/config.py | 5 +- .../schemes/compressed_tensors_w8a8_fp8.py | 29 +++++++++-- .../layers/quantization/fbgemm_fp8.py | 15 +++++- .../layers/quantization/utils/w8a8_utils.py | 49 +++++++++++-------- 4 files changed, 71 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9d42b75c1c462..7a15606836dcc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -255,7 +255,10 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["awq", "gptq", "fp8"] + rocm_supported_quantization = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8" + ] optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 8a3d24e2fd258..5931ec36c97d5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) + apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter, 
PerTensorScaleParameter) +from vllm.utils import is_hip __all__ = ["CompressedTensorsW8A8Fp8"] @@ -39,16 +41,37 @@ def process_weights_after_loading(self, layer) -> None: logical_widths=layer.logical_widths, ) + if is_hip(): + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) # If channelwise, scales are already lined up, so just transpose. elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight + + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: raise ValueError(f"Unknown quantization strategy {self.strategy}") diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index eb59344f36d2e..f26907176ad1a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear) + apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter) from vllm.platforms import current_platform +from vllm.utils import is_hip logger = init_logger(__name__) @@ -125,8 +126,18 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(layer.weight.data, requires_grad=False) weight = layer.weight - layer.weight = Parameter(weight.t(), requires_grad=False) + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=None) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) if self.quant_config.use_marlin: prepare_fp8_layer_for_marlin(layer) # Activations not quantized for marlin. diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index d86fea63d8a1b..fb263d121fe55 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,11 +6,9 @@ from vllm.platforms import current_platform from vllm.utils import is_hip -# scaled_mm in pytorch on rocm has a bug that requires always -# providing scaling factor for result. This value is created -# as global value to avoid multiple tensor allocations, and -# can be removed once pytorch fixes the bug. 
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None +# Input scaling factors are no longer optional in _scaled_mm starting +# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale +TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None def cutlass_fp8_supported() -> bool: @@ -131,19 +129,17 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output = torch._scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - scale_result=TORCH_SCALED_MM_SCALE_RESULT, - bias=bias) - # Since in torch 2.5, scaled_mm only returns single value - # This should be removed when vllm-nvidia also moves to 2.5 - if is_hip(): - return torch.narrow(output, 0, 0, input.shape[0]) - return torch.narrow(output[0], 0, 0, input.shape[0]) + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + return torch.narrow(output[0], 0, 0, input.shape[0]) + return torch.narrow(output, 0, 0, input.shape[0]) else: # Fallback for channelwise case, where we use unfused DQ @@ -161,12 +157,23 @@ def apply_fp8_linear( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=torch.float32) + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] # Unpad (undo num_token_padding) output = torch.narrow(output, 0, 0, input.shape[0]) x_scale = torch.narrow(x_scale, 0, 0, input.shape[0]) From db9120cdedba5033037432775417df0b6117495d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 18 Sep 2024 16:05:06 -0400 Subject: [PATCH 034/116] [Kernel] Change interface to Mamba selective_state_update for continuous batching (#8039) --- tests/kernels/test_mamba_ssm.py | 146 ++++++++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 31 +++- 2 files changed, 174 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index f582445692344..366475222a68e 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype): assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == 
torch.bfloat16: + rtol, atol = 7e-2, 7e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 16 + + total_entries = 10 * batch_size + state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + assert torch.allclose(state[state_indices, :], + state_ref, + rtol=rtol, + atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("tie_hdim", [False, True]) +@pytest.mark.parametrize("ngroups", [1, 2, 4]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 4096]) +def test_selective_state_update_with_heads_with_batch_indices( + dim, dstate, ngroups, has_z, tie_hdim, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-1, 1e-1 + # set seed + torch.random.manual_seed(0) + batch_size = 16 + headdim = 64 + nheads = dim // headdim + + total_entries = 10 * batch_size + state = torch.randn(total_entries, + nheads, + headdim, + dstate, + dtype=itype, + device=device) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + + x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) + if not tie_hdim: + dt = torch.randn(batch_size, + nheads, + headdim, + device=device, + dtype=itype) + dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 + A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 + D = torch.randn(nheads, headdim, device=device) + else: + dt = repeat(torch.randn(batch_size, nheads, device=device, + dtype=itype), + "b h -> b h p", + p=headdim) + dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, + "h -> h p", + p=headdim) + A = repeat(-torch.rand(nheads, device=device) - 1.0, + "h -> h p n", + p=headdim, + n=dstate) + D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) + B = torch.randn(batch_size, ngroups, dstate, device=device) + C = torch.randn(batch_size, ngroups, dstate, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - 
out_ref).abs().mean().item()}") + assert torch.allclose(state[state_indices, :], + state_ref, + rtol=rtol, + atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 869c69214caf2..a0bed07ac6193 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py import torch import triton @@ -27,6 +28,10 @@ def softplus(dt): {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({ + "HAS_STATE_BATCH_INDICES": + lambda args: args["state_batch_indices_ptr"] is not None +}) @triton.heuristics( {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) @triton.jit @@ -42,6 +47,7 @@ def _selective_scan_update_kernel( D_ptr, z_ptr, out_ptr, + state_batch_indices_ptr, # Matrix dimensions batch, nheads, @@ -85,12 +91,24 @@ def _selective_scan_update_kernel( HAS_DT_BIAS: tl.constexpr, HAS_D: tl.constexpr, HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) - state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate + # is taken from the state_batch_indices_ptr Otherwise, the state coordinate + # is the same as the batch id. 
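At the Python API level, this is what makes continuous batching work for Mamba-style layers: every running request owns a fixed row in a pre-allocated state pool, and only the rows of the requests scheduled in the current step are read and updated in place. A minimal sketch mirroring the first test above (sizes and slot numbers are illustrative only):

import torch

from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update

total_slots, dim, dstate = 256, 2048, 16
device, dtype = "cuda", torch.float16

# One persistent SSM state row per request slot.
state_pool = torch.zeros(total_slots, dim, dstate, dtype=dtype, device=device)

# Slots of the requests scheduled in this decode step.
slots = torch.tensor([7, 42, 3], dtype=torch.int32, device=device)
batch = slots.numel()

x = torch.randn(batch, dim, dtype=dtype, device=device)
dt = torch.randn(batch, dim, dtype=dtype, device=device)
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch, dstate, device=device)
C = torch.randn(batch, dstate, device=device)

# Updates state_pool[slots] in place; unscheduled rows are left untouched.
out = selective_state_update(state_pool, x, dt, A, B, C,
                             dt_softplus=True,
                             state_batch_indices=slots)

Because the pool is addressed through the slot indices, requests can join and leave the batch between steps without any state copying.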
+ if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr) + state_ptr += (state_batch_idx * stride_state_batch + + pid_h * stride_state_head) + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head if HAS_DT_BIAS: @@ -177,7 +195,8 @@ def selective_state_update(state, D=None, z=None, dt_bias=None, - dt_softplus=False): + dt_softplus=False, + state_batch_indices=None): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -211,7 +230,10 @@ def selective_state_update(state, z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) - batch, nheads, dim, dstate = state.shape + + _, nheads, dim, dstate = state.shape + batch = x.shape[0] + assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, dim, dstate) @@ -225,6 +247,8 @@ def selective_state_update(state, assert z.shape == x.shape if dt_bias is not None: assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape == (batch, ) out = torch.empty_like(x) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else @@ -249,6 +273,7 @@ def selective_state_update(state, D, z, out, + state_batch_indices, batch, nheads, dim, From d9cd78eb718c233ebc5b84377fc2226af7ef0fa2 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 18 Sep 2024 21:17:55 +0100 Subject: [PATCH 035/116] [BugFix] Nonzero exit code if MQLLMEngine startup fails (#8572) --- vllm/entrypoints/openai/api_server.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1b9eb30252417..fd6f36e8768dd 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Optional, Set +from typing import AsyncIterator, Set import uvloop from fastapi import APIRouter, FastAPI, Request @@ -95,7 +95,7 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[EngineClient]]: + args: Namespace) -> AsyncIterator[EngineClient]: # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit @@ -110,7 +110,7 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[EngineClient]]: +) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: - in-process using the AsyncLLMEngine Directly @@ -188,10 +188,8 @@ async def build_async_engine_client_from_engine_args( break except TimeoutError: if not engine_process.is_alive(): - logger.error("Engine process died before responding " - "to readiness probe") - yield None - return + raise RuntimeError( + "Engine process failed to start") from None yield mp_engine_client # type: ignore[misc] finally: @@ -532,10 +530,6 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) async with build_async_engine_client(args) as engine_client: - # If None, creation 
of the client failed and we exit. - if engine_client is None: - return - app = build_app(args) model_config = await engine_client.get_model_config() From 0d47bf3bf40edfe9fcfd7e5cd909388497535bc5 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 18 Sep 2024 16:10:01 -0600 Subject: [PATCH 036/116] [Bugfix] add `dead_error` property to engine client (#8574) Signed-off-by: Joe Runde --- vllm/engine/multiprocessing/client.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 18b620c74ddf9..2cb4de79131f1 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -380,6 +380,13 @@ def is_stopped(self) -> bool: def errored(self) -> bool: return self._errored_with is not None + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + async def generate( self, inputs: PromptInputs, From 4c34ce8916da0e4967eadefcb7f91eb58dd7ac61 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 18 Sep 2024 21:42:49 -0400 Subject: [PATCH 037/116] [Kernel] Remove marlin moe templating on thread_m_blocks (#8573) Co-authored-by: lwilkinson@neuralmagic.com --- csrc/moe/marlin_moe_ops.cu | 79 ++++++++++++++------------------------ 1 file changed, 28 insertions(+), 51 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 666d87eb92595..49cc03f827f68 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1342,9 +1342,6 @@ __device__ inline void MarlinMoESingle( template shared @@ -1459,9 +1456,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, template shared @@ -1515,26 +1509,24 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; -#define __CALL_IF_MOE(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1711,31 +1703,16 @@ exec_config_t 
determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, From 3118f63385c0d767fba8b6d2039fc35440678da9 Mon Sep 17 00:00:00 2001 From: sroy745 <142070531+sroy745@users.noreply.github.com> Date: Wed, 18 Sep 2024 19:24:15 -0700 Subject: [PATCH 038/116] [Bugfix] [Encoder-Decoder] Bugfix for encoder specific metadata construction during decode of encoder-decoder models. 
(#8545) --- .../test_encoder_decoder_model_runner.py | 88 +++++++++++++------ vllm/worker/enc_dec_model_runner.py | 12 +-- 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index c0654712b71b5..27cdf5f339ede 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size): "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -def test_prepare_decode(batch_size): +@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) +def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): ''' Test the ability of the encoder/decoder model runner subclass to produce decode-phase model inputs & attention metadata. @@ -288,6 +289,7 @@ def test_prepare_decode(batch_size): Arguments: * batch_size + * multiple_seqs_per_seq_group * backend_name: The attention backend under test * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) ''' @@ -305,22 +307,29 @@ def test_prepare_decode(batch_size): seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1]} + block_tables = { + 0: [1], + 1: [3] + } if multiple_seqs_per_seq_group else { + 0: [1] + } cross_block_table = [2] for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - seq_data={0: seq_data}, + seq_data={ + 0: seq_data, + 1: seq_data + } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, @@ -328,6 +337,10 @@ def test_prepare_decode(batch_size): ) assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) + seq_lens.extend( + [seq_len for _ in range(len(seq_group_metadata.seq_data))]) + encoder_seq_lens.extend( + [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) # Build # * Decoder model inputs @@ -398,19 +411,24 @@ def test_prepare_decode(batch_size): # Verify block tables are correct for prompts # - Decoder self-attention - expected = torch.tensor( - [block_tables[0] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) + flattened_block_tables = [ + block_table for block_table in block_tables.values() + ] + expected = torch.tensor(flattened_block_tables * + len(seq_group_metadata_list), + dtype=torch.int32, + device=model_runner.device) assert torch.equal( attn_metadata.block_tables, expected, ) # - Encoder/decoder cross-attention - expected = torch.tensor( - [cross_block_table for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device) + expected = torch.tensor([ + cross_block_table for seq_group_metadata in seq_group_metadata_list + for _ in range(len(seq_group_metadata.seq_data)) + ], + dtype=torch.int32, + device=model_runner.device) assert torch.equal( attn_metadata.cross_block_tables, 
expected, @@ -474,7 +492,8 @@ def test_prepare_decode(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) +def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): """ Tests that for encoder-decoder models with CUDA Graph capture and replay enabled, the tensors used during the decode phase are correctly padded @@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, enforce_eager=False, ) - + block_tables = { + 0: [1], + 1: [3] + } if multiple_seqs_per_seq_group else { + 0: [1] + } seq_lens: List[int] = [] encoder_seq_lens: List[int] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = [] - block_tables = {0: [1]} + cross_block_table = [2] + expanded_batch_size = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) encoder_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, - seq_data={0: seq_data}, + seq_data={ + 0: seq_data, + 1: seq_data + } if multiple_seqs_per_seq_group else {0: seq_data}, sampling_params=SamplingParams(temperature=0), block_tables=block_tables, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, ) assert seq_group_metadata.token_chunk_size == 1 + seq_lens.extend( + [seq_len for _ in range(len(seq_group_metadata.seq_data))]) + encoder_seq_lens.extend( + [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) + expanded_batch_size = expanded_batch_size + len( + seq_group_metadata.seq_data) seq_group_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) @@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. - graph_batch_size = _get_graph_batch_size(batch_size) - cuda_graph_pad_size = graph_batch_size - batch_size + graph_batch_size = _get_graph_batch_size(expanded_batch_size) + cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( itertools.repeat(1, cuda_graph_pad_size)) @@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify block tables are correct for prompts # - Decoder self-attention. Pad the block tables as expected. - expected = [block_tables[0] for _ in range(batch_size)] - expected.extend([[] for _ in range(cuda_graph_pad_size)]) + flattened_block_tables = [ + block_table for _ in range(len(seq_group_metadata_list)) + for block_table in block_tables.values() + ] + flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( - expected, + flattened_block_tables, max_len=64, pad=0, dtype=torch.int32, @@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size): ) # - Encoder/decoder cross-attention. Pad the cross-attention block tables # as expected. 
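All of the padding arithmetic in these tests follows from one rule: with CUDA graphs enabled, the decode batch is counted per sequence (not per sequence group) and then rounded up to the nearest captured graph size. Below is a small sketch of that rounding, assuming the usual 1/2/4-then-multiples-of-8 capture sizes; the helper name is illustrative rather than vLLM's private _get_graph_batch_size.

_BATCH_SIZE_ALIGNMENT = 8  # assumed capture granularity beyond batch size 4

def graph_batch_size(expanded_batch_size: int) -> int:
    # expanded_batch_size = sum of len(seq_group_metadata.seq_data) over groups
    if expanded_batch_size <= 2:
        return expanded_batch_size
    if expanded_batch_size <= 4:
        return 4
    return ((expanded_batch_size + _BATCH_SIZE_ALIGNMENT - 1)
            // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

For example, three groups with two sequences each give an expanded batch of 6 and a graph batch of 8, so cuda_graph_pad_size == 2 dummy entries are appended to seq_lens, encoder_seq_lens and both block-table lists.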
- expected = [cross_block_table for _ in range(len(seq_group_metadata_list))] + expected = [ + cross_block_table for seq_group_metadata in seq_group_metadata_list + for _ in range(len(seq_group_metadata.seq_data)) + ] expected.extend([[] for _ in range(cuda_graph_pad_size)]) expected = make_tensor_with_pad( expected, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 09dab0135f390..709efdc8b9d57 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -435,18 +435,18 @@ def _prepare_encoder_model_input_tensors( encoder_input_tokens_tensor = self._empty_long_tensor() encoder_input_positions_tensor = self._empty_long_tensor() cross_slot_mapping_tensor = self._empty_long_tensor() - # Extract cross-attention block tables & # seq len from each sequence group metadata. # Cross-attention block tables are empty # during vLLM memory profiling. cross_block_tables = [] for seq_group_metadata in seq_group_metadata_list: - encoder_seq_lens.append( - seq_group_metadata.encoder_seq_data.get_len()) - cross_block_table = seq_group_metadata.cross_block_table - cross_block_tables.append([] if ( - cross_block_table is None) else cross_block_table) + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) if (model_input.attn_metadata is not None and model_input.attn_metadata.use_cuda_graph): From 02c9afa2d04a85269faa2760e9af30527a61d7f6 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Wed, 18 Sep 2024 21:14:28 -0700 Subject: [PATCH 039/116] Revert "[Misc][Bugfix] Disable guided decoding for mistral tokenizer" (#8593) --- .../guided_decoding/__init__.py | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index f4fe8a7307c04..7161e83952a3d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -6,7 +6,6 @@ from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.sampling_params import LogitsProcessor -from vllm.transformers_utils.tokenizer import MistralTokenizer async def get_guided_decoding_logits_processor( @@ -16,23 +15,12 @@ async def get_guided_decoding_logits_processor( request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'outlines' is currently not supported " - "for Mistral tokenizer. Please consider contributing to the " - "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'lm-format-enforcer' is currently not " - "supported for Mistral tokenizer. 
Please consider contributing " - "to the 'lm-format-enforcer' project if you are interested " - "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( @@ -49,23 +37,12 @@ def get_local_guided_decoding_logits_processor( # request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'outlines' is currently not supported " - "for Mistral tokenizer. Please consider contributing to the " - "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_options, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': - if isinstance(tokenizer, MistralTokenizer): - raise NotImplementedError( - "Guided decoding with 'lm-format-enforcer' is currently not " - "supported for Mistral tokenizer. Please consider contributing " - "to the 'lm-format-enforcer' project if you are interested " - "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( From c52ec5f03471008fa1312d82fb17d40b95a3ca5d Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Wed, 18 Sep 2024 22:24:24 -0700 Subject: [PATCH 040/116] [Bugfix] fixing sonnet benchmark bug in benchmark_serving.py (#8616) --- benchmarks/benchmark_serving.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3ace910a6cac6..a407a263120bb 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -626,9 +626,9 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt, prompt_len, output_len) + input_requests = [(prompt, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] else: assert ( tokenizer.chat_template or tokenizer.default_chat_template @@ -641,9 +641,9 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt_formatted, prompt_len, output_len) + input_requests = [(prompt_formatted, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] elif args.dataset_name == "hf": input_requests = sample_hf_requests( @@ -963,4 +963,4 @@ def main(args: argparse.Namespace): ) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file From 855c8ae2c9a4085b1ebd66d9a978fb23f47f822c Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 19 Sep 2024 13:33:20 +0800 Subject: [PATCH 041/116] [MISC] remove engine_use_ray in benchmark_throughput.py (#8615) --- benchmarks/benchmark_throughput.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 3f531ee82cc94..e1a5d4ee28ea1 100644 --- 
a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -191,7 +191,6 @@ async def run_vllm_async( use_v2_block_manager=use_v2_block_manager, disable_async_output_proc=disable_async_output_proc, worker_use_ray=False, - engine_use_ray=False, disable_log_requests=True, ) From 76515f303b44cb3ffc6de63c49148d5081a77119 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 19 Sep 2024 17:51:06 +0100 Subject: [PATCH 042/116] [Frontend] Use MQLLMEngine for embeddings models too (#8584) --- vllm/engine/multiprocessing/__init__.py | 7 +- vllm/engine/multiprocessing/client.py | 106 +++++++++++++++++------- vllm/engine/multiprocessing/engine.py | 23 ++--- 3 files changed, 90 insertions(+), 46 deletions(-) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index ba5c6e15fc821..700332864d17a 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -2,6 +2,7 @@ from enum import Enum from typing import List, Mapping, Optional, Union +from vllm import PoolingParams from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError): @dataclass -class RPCGenerateRequest: +class RPCProcessRequest: inputs: PromptInputs - sampling_params: SamplingParams + params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None @@ -55,7 +56,7 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 2cb4de79131f1..aa9dbbd448af2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -11,6 +11,7 @@ from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket +from vllm import PoolingParams from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block @@ -19,8 +20,8 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCGenerateRequest, - RPCHealthRequest, RPCStartupRequest, + RPCError, RPCHealthRequest, + RPCProcessRequest, RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -111,20 +112,8 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): @staticmethod def is_unsupported_config(engine_args: AsyncEngineArgs): - if engine_args.pipeline_parallel_size > 1: - return True - - is_embedding = ModelConfig( - model=engine_args.model, - revision=engine_args.revision, - tokenizer=engine_args.model, - tokenizer_mode="auto", - trust_remote_code=engine_args.trust_remote_code, - quantization=engine_args.quantization, - seed=0, - dtype="auto").embedding_mode - - return is_embedding + # Pipeline parallel not yet supported + return engine_args.pipeline_parallel_size > 1 @contextmanager def get_data_socket(self) -> Iterator[Socket]: @@ -382,12 +371,9 @@ def errored(self) -> bool: @property def dead_error(self) -> BaseException: - if self._errored_with is not None: - return ENGINE_DEAD_ERROR(self._errored_with) - else: - return 
ENGINE_DEAD_ERROR() + return ENGINE_DEAD_ERROR(self._errored_with) - async def generate( + def generate( self, inputs: PromptInputs, sampling_params: SamplingParams, @@ -396,6 +382,67 @@ async def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + """ + return self._process_request(inputs, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request) + + def encode( + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. + """ + return self._process_request(inputs, pooling_params, request_id, + lora_request, trace_headers) + + async def _process_request( + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ + EmbeddingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. 
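Because encode() now funnels into the same RPCProcessRequest path as generate(), an embedding request behaves exactly like a completion request from the caller's point of view. A minimal usage sketch, assuming an already-started client obtained from build_async_engine_client; the last item yielded for a request id is the finished EmbeddingRequestOutput:

import uuid

from vllm import PoolingParams

async def embed(client, text: str):
    # client: an MQLLMEngineClient (any EngineClient behaves the same way).
    final = None
    async for output in client.encode(text,
                                      PoolingParams(),
                                      request_id=str(uuid.uuid4())):
        final = output
    # EmbeddingRequestOutput.outputs.embedding holds the pooled vector.
    return final.outputs.embedding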
@@ -410,19 +457,19 @@ async def generate( try: # 2) Detach logits processors so that they can be pickled # separately (may require cloudpickle which is slower) - if sampling_params.logits_processors: + if isinstance(params, SamplingParams) and params.logits_processors: # Defensive shallow copy - sampling_params = copy.copy(sampling_params) - logits_processors = sampling_params.logits_processors - sampling_params.logits_processors = None + params = copy.copy(params) + logits_processors = params.logits_processors + params.logits_processors = None lp_bytes = cloudpickle.dumps(logits_processors) else: lp_bytes = None request_bytes = pickle.dumps( - RPCGenerateRequest( + RPCProcessRequest( inputs=inputs, - sampling_params=sampling_params, + params=params, request_id=request_id, lora_request=lora_request, trace_headers=trace_headers, @@ -452,8 +499,3 @@ async def generate( await self.abort(request_id) finally: self.output_queues.pop(request_id) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 70cd6e5cb6000..f4ca231570853 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -6,7 +6,7 @@ import cloudpickle import zmq -from vllm import AsyncEngineArgs, LLMEngine +from vllm import AsyncEngineArgs, LLMEngine, SamplingParams from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) # yapf conflicts with isort for this block @@ -15,8 +15,8 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCGenerateRequest, - RPCHealthRequest, RPCStartupRequest, + RPCError, RPCHealthRequest, + RPCProcessRequest, RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.logger import init_logger @@ -39,8 +39,8 @@ class MQLLMEngine: in concurrnet manner. It runs a background loop and uses zeromq to receive new requests and stream outputs incrementally via ipc. - The :class:`LLMEngine.generate` is kicked off when a new - RPCGenerateRequest is received by the input_socket. + The :class:`LLMEngine` generate or encode process is kicked off when a new + RPCProcessRequest is received by the input_socket. 
The self.engine_loop checks the input_socket for new requests, adds them to the LLMEngine if there are any, calls the internal @@ -213,12 +213,13 @@ def handle_new_input(self): frames = self.input_socket.recv_multipart(copy=False) request = pickle.loads(frames[0].buffer) - if isinstance(request, RPCGenerateRequest): + if isinstance(request, RPCProcessRequest): if len(frames) > 1: # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) lprocs = cloudpickle.loads(frames[1].buffer) - request.sampling_params.logits_processors = lprocs - self._handle_generate_request(request) + request.params.logits_processors = lprocs + self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) elif isinstance(request, RPCHealthRequest): @@ -231,8 +232,8 @@ def handle_new_input(self): self._send_unhealthy(e) raise e - def _handle_generate_request(self, request: RPCGenerateRequest): - """Handle RPCGenerateRequest by adding it to the LLMEngine.""" + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" request_id = request.request_id if self._errored_with is not None: @@ -245,7 +246,7 @@ def _handle_generate_request(self, request: RPCGenerateRequest): self.engine.add_request( request_id=request_id, inputs=request.inputs, - params=request.sampling_params, + params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, prompt_adapter_request=request.prompt_adapter_request) From 9cc373f39036af789fb1ffc1e06b23766996d3f4 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Thu, 19 Sep 2024 12:37:57 -0500 Subject: [PATCH 043/116] [Kernel][Amd] Add fp8 kv cache support for rocm custom paged attention (#8577) --- csrc/rocm/attention.cu | 240 +++++++++++++------- csrc/rocm/ops.h | 3 +- csrc/rocm/torch_bindings.cpp | 3 +- tests/kernels/test_attention.py | 251 ++++++--------------- vllm/_custom_ops.py | 4 +- vllm/attention/backends/rocm_flash_attn.py | 28 +-- 6 files changed, 246 insertions(+), 283 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 8fa7c862fbfa8..b48348a515c8d 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -18,8 +18,11 @@ #include #include #include +#include "cuda_compat.h" #include +#include "../attention/dtype_fp8.cuh" +#include "../quantization/fp8/amd/quant_utils.cuh" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ defined(__gfx941__) || defined(__gfx942__)) @@ -38,7 +41,6 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define WARP_SIZE 64 #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support @@ -60,6 +62,8 @@ typedef struct _B16x8 { _B16x4 xy[2]; } _B16x8; +using _B8x8 = uint2; + ////// Non temporal load stores /////// template @@ -168,18 +172,40 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } } +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, + const float scale) { + union alignas(16) { + uint4 u4; + _B16x8 u16x8; + vllm::bf16_8_t b16x8; + } tmp; + if constexpr (std::is_same::value) { + tmp.u4 = vllm::fp8::scaled_convert(input, scale); + return tmp.u16x8; + } else if constexpr (std::is_same::value) { + tmp.b16x8 = vllm::fp8::scaled_convert( + input, scale); + return tmp.u16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + /////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) -template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -192,10 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif - int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -222,12 +245,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( constexpr int x = 16 / sizeof(scalar_t); constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; + _B8x8 Klocalb8[KHELOOP]; constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; // v head_size dimension is distributed across lanes constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 // 8xtokens _B16x8 Vlocal[VHELOOP][VTLOOP]; + _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; #pragma unroll @@ -279,6 +304,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( (vblock_idx <= last_ctx_block) ? 
vblock_idx : last_ctx_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; } + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; @@ -298,17 +324,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Qlocal[QHLOOP - 1].xy[1] = {0}; } - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + - wg_start_kv_head_idx * kv_head_stride; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; const int physical_block_offset = local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset // is already cast as _H8 - - const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + } else { + constexpr int X = 16 / sizeof(cache_t); + const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + for (int d = 0; d < KHELOOP; d++) { + const int head_elem = d * 8; + const int offset1 = head_elem / X; + const int offset2 = head_elem % X; + const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2; + Klocalb8[d] = *reinterpret_cast(k_ptr3); + } } float alibi_slope[QHLOOP]; @@ -322,30 +360,66 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); - // iterate over each v block + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B16x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + } else { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) #pragma unroll - 
for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + const _B8x8 Vlocalb8 = v_ptrh8be[d]; + Vlocal[h][b * BLOCK_SIZE / 8 + d] = + scaled_convert_b8x8(Vlocalb8, v_scale); + } } } } + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = + scaled_convert_b8x8(Klocalb8[d], k_scale); + } + } + #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], @@ -794,14 +868,16 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -814,10 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif - int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { UNREACHABLE_CODE } @@ -839,26 +912,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ + paged_attention_ll4mi_QKV_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale); -template +template void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, - int max_context_len, -#if 0 - torch::Tensor& qk_out, - torch::Tensor& softmax_out, -#endif - const c10::optional& alibi_slopes) { - + int 
max_context_len, const c10::optional& alibi_slopes, + float k_scale, float v_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -878,14 +949,10 @@ void paged_attention_custom_launcher( float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); -#if 0 - T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); - T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); -#endif const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = @@ -972,32 +1039,32 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes); +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); -#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ - CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ case 32: \ - CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ switch (head_size) { \ case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ break; \ case 128: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ break; \ default: \ TORCH_CHECK(false, "Unsupported head size: ", head_size); \ @@ -1020,19 +1087,34 @@ void paged_attention( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { - assert(kv_cache_dtype == "auto"); + const std::string& kv_cache_dtype, double k_scale, double v_scale) { const int head_size = query.size(2); - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, + vllm::Fp8KVCacheDataType::kAuto); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, + vllm::Fp8KVCacheDataType::kAuto); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (query.dtype() == 
at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); } } #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 4a07a3f1775bd..9f085115a3956 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -10,4 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, double k_scale, + double v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 082e314587908..a283d4263d293 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -26,7 +26,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor context_lens, int block_size," " int max_context_len," " Tensor? alibi_slopes," - " str kv_cache_dtype) -> ()"); + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 4bd6f7863a658..ecab512cba16f 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -31,8 +31,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256 - ] if not is_hip() else [64, 80, 96, 112, 128] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] @@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention( output[i].copy_(out, non_blocking=True) -@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize( + "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"]) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -137,7 +137,8 @@ def test_paged_attention( seed: int, device: str, ) -> None: - if kv_cache_dtype == "fp8" and head_size % 16: + if ((kv_cache_dtype == "fp8" and head_size % 16) + or (version == "rocm" and head_size not in (64, 128))): pytest.skip() seed_everything(seed) @@ -206,7 +207,7 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0])) - elif version == "v2": + elif version in ("v2", "rocm"): num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -219,32 +220,61 @@ def test_paged_attention( dtype=torch.float32, ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - 
k_scale, - v_scale, - ) - - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, key_cache, - value_cache, num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0])) + + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._rocm_C.paged_attention, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale), + cond=(head_size == HEAD_SIZES[0])) else: raise AssertionError(f"Unknown version: {version}") @@ -328,162 +358,6 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) -@pytest.mark.parametrize("version", ["rocm"]) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", ["auto"]) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(not is_hip(), reason="only for rocm") -def test_paged_attention_rocm( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - kv_cache_dtype: str, - seed: int, - device: str, -) -> None: - seed_everything(seed) - torch.set_default_device(device) - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - #context_lens = [8192 for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) - #print('>>> ctx lens', context_lens) - - # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int) - - # Create the KV caches. 
- key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) - key_cache, value_cache = key_caches[0], value_caches[0] - - # TODO(charlifu) enable fp8 kv cache - # Using default kv_scale - # kv_scale = 1.0 - - # Call the paged attention kernel. - output = torch.empty_like(query) - PARTITION_SIZE_ROCM = 256 - num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // - PARTITION_SIZE_ROCM) - assert PARTITION_SIZE_ROCM % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - if version == "rocm": - ops.paged_attention_rocm( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - ) - else: - raise AssertionError(f"Unknown version: {version}") - - # Run the reference implementation. - if kv_cache_dtype == "fp8": - # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) - key_cache = dequantized_key_cache - - value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) - value_cache = dequantized_value_cache - - ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - alibi_slopes, - ) - - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if is_hip() else 1e-3 - rtol = get_default_rtol(output) if is_hip() else 1e-5 - - # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, - # so we use a relaxed tolerance for the test. - atol, rtol = 1e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 2e-4, 1e-5 - if use_alibi: - if dtype == torch.half: - atol, rtol = 5e-4, 1e-5 - if dtype == torch.bfloat16: - atol, rtol = 1e-3, 1e-5 - if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 - assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) - - # TODO(woosuk): Add tests for USE_ALIBI=True. 
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -491,7 +365,8 @@ def test_paged_attention_rocm( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(is_hip(), reason="skip for rocm") +@pytest.mark.skipif(is_hip(), + reason="Xformers backend is not supported on ROCm.") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ff5aa8bee3c27..678700055c992 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -146,12 +146,14 @@ def paged_attention_rocm( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, + k_scale: float, + v_scale: float, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype) + kv_cache_dtype, k_scale, v_scale) # pos encoding ops diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6bd276ade1d41..70e6857584ace 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -17,8 +17,8 @@ logger = init_logger(__name__) -_PARTITION_SIZE = 256 -ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName +_PARTITION_SIZE_ROCM = 512 +_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName class ROCmFlashAttentionBackend(AttentionBackend): @@ -489,14 +489,15 @@ def forward( num_seqs, num_heads, head_size = decode_query.shape block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads - use_custom = use_rocm_custom_paged_attention( - decode_query.dtype, head_size, block_size, self.kv_cache_dtype, - gqa_ratio, decode_meta.max_decode_seq_len) + use_custom = _use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len) if use_custom: max_seq_len = decode_meta.max_decode_seq_len - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - assert _PARTITION_SIZE % block_size == 0 + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), dtype=output.dtype, @@ -524,6 +525,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, + k_scale, + v_scale, ) else: output[num_prefill_tokens:] = PagedAttention.forward_decode( @@ -580,12 +583,11 @@ def _sdpa_attention( return output -def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, kv_cache_dtype: str, - gqa_ratio: int, max_seq_len: int) -> bool: +def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, gqa_ratio: int, + max_seq_len: int) -> bool: # rocm custom page attention not support on navi (gfx1*) - return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and kv_cache_dtype == "auto" and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) From e42c634acbd1b86b5becca51e8b8108a32a438d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=8F=E4=B8=80?= Date: Fri, 
20 Sep 2024 02:28:25 +0800
Subject: [PATCH 044/116] [Core] simplify logits resort in _apply_top_k_top_p
 (#8619)

---
 vllm/model_executor/layers/sampler.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 487f5a3d2a441..2ca86a4653cf4 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -433,12 +433,9 @@ def _apply_top_k_top_p(
     logits_sort.masked_fill_(top_p_mask, -float("inf"))
 
     # Re-sort the probabilities.
-    src = torch.arange(logits_idx.shape[-1],
-                       device=logits_idx.device).expand_as(logits_idx)
-    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
-                                                           index=logits_idx,
-                                                           src=src)
-    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
+    logits = torch.empty_like(logits_sort).scatter_(dim=-1,
+                                                    index=logits_idx,
+                                                    src=logits_sort)
     return logits
 
 

From ea4647b7d77c4738c5ed2ab77a2c9f5ad335f6fb Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 20 Sep 2024 03:15:55 +0800
Subject: [PATCH 045/116] [Doc] Add documentation for GGUF quantization (#8618)

---
 docs/source/index.rst             |  1 +
 docs/source/quantization/gguf.rst | 73 +++++++++++++++++++++++++++++++
 2 files changed, 74 insertions(+)
 create mode 100644 docs/source/quantization/gguf.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4b817c4ba9498..79f723eace762 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -107,6 +107,7 @@ Documentation
    quantization/supported_hardware
    quantization/auto_awq
    quantization/bnb
+   quantization/gguf
    quantization/int8
    quantization/fp8
    quantization/fp8_e5m2_kvcache
diff --git a/docs/source/quantization/gguf.rst b/docs/source/quantization/gguf.rst
new file mode 100644
index 0000000000000..9f00dc5563909
--- /dev/null
+++ b/docs/source/quantization/gguf.rst
@@ -0,0 +1,73 @@
+.. _gguf:
+
+GGUF
+==================
+
+.. warning::
+
+   Please note that GGUF support in vLLM is highly experimental and under-optimized at the moment; it might be incompatible with other features. Currently, you can use GGUF as a way to reduce the memory footprint. If you encounter any issues, please report them to the vLLM team.
+
+.. warning::
+
+   Currently, vLLM only supports loading single-file GGUF models. If you have a GGUF model split into multiple files, you can use the `gguf-split `_ tool to merge them into a single-file model.
+
+To run a GGUF model with vLLM, you can download and use a local GGUF model from `TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF `_ with the following command:
+
+.. code-block:: console
+
+   $ wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf
+   $ # We recommend using the tokenizer from the base model to avoid a slow and error-prone tokenizer conversion.
+   $ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0
+
+You can also add ``--tensor-parallel-size 2`` to enable tensor-parallel inference with 2 GPUs:
+
+.. code-block:: console
+
+   $ # We recommend using the tokenizer from the base model to avoid a slow and error-prone tokenizer conversion.
+   $ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tensor-parallel-size 2
+
+.. warning::
+
+   We recommend using the tokenizer from the base model instead of the GGUF model, because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with a large vocab size.
+ +You can also use the GGUF model directly through the LLM entrypoint: + +.. code-block:: python + + from vllm import LLM, SamplingParams + + # In this script, we demonstrate how to pass input to the chat method: + conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Create an LLM. + llm = LLM(model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", + tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.chat(conversation, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") From 9e99407e3ccbb290bae77af230da38c70a52a055 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 19 Sep 2024 12:16:28 -0700 Subject: [PATCH 046/116] Create SECURITY.md (#8642) --- SECURITY.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000..d9a392158472d --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,12 @@ +# Security Policy + +## Reporting a Vulnerability + +If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. +We will investigate all legitimate reports and do our best to quickly fix the problem. + +Please report security issues using https://github.com/vllm-project/vllm/security/advisories/new + +--- +Please see PyTorch Security for more information how to securely interact with models: https://github.com/pytorch/pytorch/blob/main/SECURITY.md +This document mostly references the recommendation from PyTorch, thank you! From 6cb748e190a94e20987314025614b8bd806602f2 Mon Sep 17 00:00:00 2001 From: "Alexey Kondratiev(AMD)" <143633163+alexeykondrat@users.noreply.github.com> Date: Thu, 19 Sep 2024 16:06:32 -0400 Subject: [PATCH 047/116] [CI/Build] Re-enabling Entrypoints tests on ROCm, excluding ones that fail (#8551) --- .buildkite/run-amd-test.sh | 9 +++++++++ .buildkite/test-pipeline.yaml | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 9274a30e04325..45b20c9447c7d 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -94,6 +94,15 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_sampler.py" fi +#ignore certain Entrypoints tests +if [[ $commands == *" entrypoints/openai "* ]]; then + commands=${commands//" entrypoints/openai "/" entrypoints/openai \ + --ignore=entrypoints/openai/test_accuracy.py \ + --ignore=entrypoints/openai/test_audio.py \ + --ignore=entrypoints/openai/test_encoder_decoder.py \ + --ignore=entrypoints/openai/test_oot_registration.py "} +fi + PARALLEL_JOB_COUNT=8 # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. 
if [[ $commands == *"--shard-id="* ]]; then diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 37207b677a1ee..379a67c4c8cf8 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -84,7 +84,7 @@ steps: - label: Entrypoints Test # 20min working_dir: "/vllm-workspace/tests" fast_check: true - #mirror_hardwares: [amd] + mirror_hardwares: [amd] source_file_dependencies: - vllm/ commands: From de6f90a13d7b98c4958ba107ec16cb6f95efb10f Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Thu, 19 Sep 2024 18:36:30 -0400 Subject: [PATCH 048/116] [Misc] guard against change in cuda library name (#8609) --- cmake/utils.cmake | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 1ea6d2b0f090e..730517a20129a 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -350,13 +350,14 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_include_directories(${GPU_MOD_NAME} PRIVATE csrc ${GPU_INCLUDE_DIRECTORIES}) - # TODO: is torch_python_LIBRARY needed? - target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY} - ${GPU_LIBRARIES}) + target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of # dependencies that are not necessary and may not be installed. if (GPU_LANGUAGE STREQUAL "CUDA") + if ("${CUDA_CUDA_LIB}" STREQUAL "") + set(CUDA_CUDA_LIB "${CUDA_CUDA_LIBRARY}") + endif() target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB} ${CUDA_LIBRARIES}) else() From 18ae428a0d8792d160d811a9cd5bb004d68ea8bd Mon Sep 17 00:00:00 2001 From: Amit Garg Date: Thu, 19 Sep 2024 17:54:02 -0700 Subject: [PATCH 049/116] [Bugfix] Fix Phi3.5 mini and MoE LoRA inference (#8571) --- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/phi3.py | 17 +++++++++++++++++ vllm/model_executor/models/phimoe.py | 4 ++++ 3 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/phi3.py diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 591007e787f47..7427060922281 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -50,7 +50,7 @@ "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), - "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py new file mode 100644 index 0000000000000..02b2ff01c3832 --- /dev/null +++ b/vllm/model_executor/models/phi3.py @@ -0,0 +1,17 @@ +# coding=utf-8 +# Adapted from llama.py +"""Inference-only Phi3 model code inherit from Llama.py""" + +from vllm.model_executor.models.llama import LlamaForCausalLM + + +class Phi3ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 5036f55803c20..a3555a294bb66 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -491,6 +491,10 @@ class 
PhiMoEForCausalLM(nn.Module, SupportsLoRA): "o_proj", "embed_tokens", "lm_head", + "w1", + "w2", + "w3", + "gate", ] embedding_modules = { "embed_tokens": "input_embeddings", From 9e5ec35b1f8239453b1aaab28e7a02307db4ab1f Mon Sep 17 00:00:00 2001 From: William Lin Date: Thu, 19 Sep 2024 20:49:54 -0700 Subject: [PATCH 050/116] [bugfix] [AMD] add multi-step advance_step to ROCmFlashAttentionMetadata (#8474) --- vllm/attention/backends/rocm_flash_attn.py | 58 +++++++++++++++++++++- vllm/worker/multi_step_model_runner.py | 2 +- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 70e6857584ace..5560f44be4196 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,6 +1,6 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -15,6 +15,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 512 @@ -180,6 +183,59 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata + def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. 
Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + class ROCmFlashAttentionMetadataBuilder( CommonMetadataBuilder[ROCmFlashAttentionMetadata]): diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index ebcafbbab119a..c7295f872f70f 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -29,7 +29,7 @@ logger = init_logger(__name__) -MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"] +MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] def seq_output_builder(): From 260d40b5ea48df9421325388abcc8d907a560fc5 Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Thu, 19 Sep 2024 23:20:56 -0700 Subject: [PATCH 051/116] [Core] Support Lora lineage and base model metadata management (#6315) --- docs/source/models/lora.rst | 64 +++++++++++++ tests/entrypoints/openai/test_cli_args.py | 91 +++++++++++++++++++ tests/entrypoints/openai/test_lora_lineage.py | 83 +++++++++++++++++ tests/entrypoints/openai/test_models.py | 6 +- tests/entrypoints/openai/test_serving_chat.py | 6 +- .../entrypoints/openai/test_serving_engine.py | 5 +- vllm/entrypoints/openai/api_server.py | 14 ++- vllm/entrypoints/openai/cli_args.py | 27 +++++- vllm/entrypoints/openai/run_batch.py | 9 +- vllm/entrypoints/openai/serving_chat.py | 11 ++- vllm/entrypoints/openai/serving_completion.py | 9 +- vllm/entrypoints/openai/serving_embedding.py | 6 +- vllm/entrypoints/openai/serving_engine.py | 43 ++++++--- .../openai/serving_tokenization.py | 7 +- vllm/lora/request.py | 1 + 15 files changed, 337 insertions(+), 45 deletions(-) create mode 100644 tests/entrypoints/openai/test_cli_args.py create mode 100644 tests/entrypoints/openai/test_lora_lineage.py diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index b3821ebdfceca..ef0177eaf2162 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -159,3 +159,67 @@ Example request to unload a LoRA adapter: -d '{ "lora_name": "sql_adapter" }' + + +New format for `--lora-modules` +------------------------------- + +In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example: + +.. code-block:: bash + + --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/ + +This would only include the `name` and `path` for each LoRA module, but did not provide a way to specify a `base_model_name`. +Now, you can specify a base_model_name alongside the name and path using JSON format. For example: + +.. code-block:: bash + + --lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}' + +To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case. 
+ + +Lora model lineage in model card +-------------------------------- + +The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this: + +- The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter. +- The `root` field points to the artifact location of the lora adapter. + +.. code-block:: bash + + $ curl http://localhost:8000/v1/models + + { + "object": "list", + "data": [ + { + "id": "meta-llama/Llama-2-7b-hf", + "object": "model", + "created": 1715644056, + "owned_by": "vllm", + "root": "~/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9/", + "parent": null, + "permission": [ + { + ..... + } + ] + }, + { + "id": "sql-lora", + "object": "model", + "created": 1715644056, + "owned_by": "vllm", + "root": "~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/", + "parent": meta-llama/Llama-2-7b-hf, + "permission": [ + { + .... + } + ] + } + ] + } diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py new file mode 100644 index 0000000000000..8ee7fb8b2c6bf --- /dev/null +++ b/tests/entrypoints/openai/test_cli_args.py @@ -0,0 +1,91 @@ +import json +import unittest + +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.utils import FlexibleArgumentParser + +LORA_MODULE = { + "name": "module2", + "path": "/path/to/module2", + "base_model_name": "llama" +} + + +class TestLoraParserAction(unittest.TestCase): + + def setUp(self): + # Setting up argparse parser for tests + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + self.parser = make_arg_parser(parser) + + def test_valid_key_value_format(self): + # Test old format: name=path + args = self.parser.parse_args([ + '--lora-modules', + 'module1=/path/to/module1', + ]) + expected = [LoRAModulePath(name='module1', path='/path/to/module1')] + self.assertEqual(args.lora_modules, expected) + + def test_valid_json_format(self): + # Test valid JSON format input + args = self.parser.parse_args([ + '--lora-modules', + json.dumps(LORA_MODULE), + ]) + expected = [ + LoRAModulePath(name='module2', + path='/path/to/module2', + base_model_name='llama') + ] + self.assertEqual(args.lora_modules, expected) + + def test_invalid_json_format(self): + # Test invalid JSON format input, missing closing brace + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + '{"name": "module3", "path": "/path/to/module3"' + ]) + + def test_invalid_type_error(self): + # Test type error when values are not JSON or key=value + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + 'invalid_format' # This is not JSON or key=value format + ]) + + def test_invalid_json_field(self): + # Test valid JSON format but missing required fields + with self.assertRaises(SystemExit): + self.parser.parse_args([ + '--lora-modules', + '{"name": "module4"}' # Missing required 'path' field + ]) + + def test_empty_values(self): + # Test when no LoRA modules are provided + args = self.parser.parse_args(['--lora-modules', '']) + self.assertEqual(args.lora_modules, []) + + def test_multiple_valid_inputs(self): + 
# Test multiple valid inputs (both old and JSON format) + args = self.parser.parse_args([ + '--lora-modules', + 'module1=/path/to/module1', + json.dumps(LORA_MODULE), + ]) + expected = [ + LoRAModulePath(name='module1', path='/path/to/module1'), + LoRAModulePath(name='module2', + path='/path/to/module2', + base_model_name='llama') + ] + self.assertEqual(args.lora_modules, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/entrypoints/openai/test_lora_lineage.py b/tests/entrypoints/openai/test_lora_lineage.py new file mode 100644 index 0000000000000..ab39684c2f31a --- /dev/null +++ b/tests/entrypoints/openai/test_lora_lineage.py @@ -0,0 +1,83 @@ +import json + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +# downloading lora to test lora requests +from huggingface_hub import snapshot_download + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically this needs Mistral-7B-v0.1 as base, but we're not testing +# generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def server_with_lora_modules_json(zephyr_lora_files): + # Define the json format LoRA module configurations + lora_module_1 = { + "name": "zephyr-lora", + "path": zephyr_lora_files, + "base_model_name": MODEL_NAME + } + + lora_module_2 = { + "name": "zephyr-lora2", + "path": zephyr_lora_files, + "base_model_name": MODEL_NAME + } + + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--lora-modules", + json.dumps(lora_module_1), + json.dumps(lora_module_2), + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "64", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_for_lora_lineage(server_with_lora_modules_json): + async with server_with_lora_modules_json.get_async_client( + ) as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, + zephyr_lora_files): + models = await client_for_lora_lineage.models.list() + models = models.data + served_model = models[0] + lora_models = models[1:] + assert served_model.id == MODEL_NAME + assert served_model.root == MODEL_NAME + assert served_model.parent is None + assert all(lora_model.root == zephyr_lora_files + for lora_model in lora_models) + assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) + assert lora_models[0].id == "zephyr-lora" + assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 5cd570f43e1a7..ae5bf404d3d2b 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -51,12 +51,14 @@ async def client(server): @pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): +async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): models = await client.models.list() models = models.data served_model = models[0] lora_models = models[1:] assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in 
models) + assert served_model.root == MODEL_NAME + assert all(lora_model.root == zephyr_lora_files + for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" assert lora_models[1].id == "zephyr-lora2" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index de2a932199a01..db31745cc102e 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -7,10 +7,12 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] @dataclass @@ -37,7 +39,7 @@ async def _async_serving_chat_init(): serving_completion = OpenAIServingChat(engine, model_config, - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, lora_modules=None, @@ -58,7 +60,7 @@ def test_serving_chat_should_set_correct_max_tokens(): serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, lora_modules=None, diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 6d9e620b4af7d..6199a75b5b4f8 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -8,9 +8,10 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing MODEL_NAME = "meta-llama/Llama-2-7b" +BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] LORA_LOADING_SUCCESS_MESSAGE = ( "Success: LoRA adapter '{lora_name}' added successfully.") LORA_UNLOADING_SUCCESS_MESSAGE = ( @@ -25,7 +26,7 @@ async def _async_serving_engine_init(): serving_engine = OpenAIServing(mock_engine_client, mock_model_config, - served_model_names=[MODEL_NAME], + BASE_MODEL_PATHS, lora_modules=None, prompt_adapters=None, request_logger=None) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index fd6f36e8768dd..5078a2654eb22 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -50,6 +50,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.logger import init_logger @@ -476,13 +477,18 @@ def init_app_state( else: request_logger = RequestLogger(max_log_len=args.max_log_len) + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, 
- served_model_names, + base_model_paths, args.response_role, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, @@ -494,7 +500,7 @@ def init_app_state( state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, @@ -503,13 +509,13 @@ def init_app_state( state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, - served_model_names, + base_model_paths, request_logger=request_logger, ) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, - served_model_names, + base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, chat_template=args.chat_template, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index bbb0823de9a51..9d3071a97fbe6 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -31,8 +31,23 @@ def __call__( lora_list: List[LoRAModulePath] = [] for item in values: - name, path = item.split('=') - lora_list.append(LoRAModulePath(name, path)) + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) setattr(namespace, self.dest, lora_list) @@ -95,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, nargs='+', action=LoRAParserAction, - help="LoRA module configurations in the format name=path. " - "Multiple modules can be specified.") + help="LoRA module configurations in either 'name=path' format" + "or JSON format. 
" + "Example (old format): 'name=path' " + "Example (new format): " + "'{\"name\": \"name\", \"local_path\": \"path\", " + "\"base_model_name\": \"id\"}'") parser.add_argument( "--prompt-adapters", type=nullable_str, diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index b745410fe6b3b..f5249a0c447b3 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,6 +20,7 @@ # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -196,6 +197,10 @@ async def main(args): engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) model_config = await engine.get_model_config() + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] if args.disable_log_requests: request_logger = None @@ -206,7 +211,7 @@ async def main(args): openai_serving_chat = OpenAIServingChat( engine, model_config, - served_model_names, + base_model_paths, args.response_role, lora_modules=None, prompt_adapters=None, @@ -216,7 +221,7 @@ async def main(args): openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, - served_model_names, + base_model_paths, request_logger=request_logger, ) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b84898dc39b0f..1ee4b3ce17cfa 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -23,7 +23,8 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing, PromptAdapterPath, TextTokensPrompt) @@ -47,7 +48,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], response_role: str, *, lora_modules: Optional[List[LoRAModulePath]], @@ -59,7 +60,7 @@ def __init__(self, tool_parser: Optional[str] = None): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=prompt_adapters, request_logger=request_logger, @@ -262,7 +263,7 @@ async def chat_completion_stream_generator( conversation: List[ConversationMessage], tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" first_iteration = True @@ -596,7 +597,7 @@ async def chat_completion_full_generator( tokenizer: AnyTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name created_time = int(time.time()) final_res: Optional[RequestOutput] = None diff --git a/vllm/entrypoints/openai/serving_completion.py 
b/vllm/entrypoints/openai/serving_completion.py index 14fa60243c584..9abd74d0561d0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -20,7 +20,8 @@ CompletionStreamResponse, ErrorResponse, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing, PromptAdapterPath) from vllm.logger import init_logger @@ -45,7 +46,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], @@ -54,7 +55,7 @@ def __init__( ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=prompt_adapters, request_logger=request_logger, @@ -89,7 +90,7 @@ async def create_completion( return self.create_error_response( "suffix is not currently supported") - model_name = self.served_model_names[0] + model_name = self.base_model_paths[0].name request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f111a3a8277b5..5d95e1369b884 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -14,7 +14,7 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.logger import init_logger from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput from vllm.utils import merge_async_iterators, random_uuid @@ -73,13 +73,13 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, request_logger: Optional[RequestLogger], ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=None, prompt_adapters=None, request_logger=request_logger) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 72f9381abc7db..9c4e8d8bb671a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -39,6 +39,12 @@ logger = init_logger(__name__) +@dataclass +class BaseModelPath: + name: str + model_path: str + + @dataclass class PromptAdapterPath: name: str @@ -49,6 +55,7 @@ class PromptAdapterPath: class LoRAModulePath: name: str path: str + base_model_name: Optional[str] = None AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, @@ -66,7 +73,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], @@ -79,17 +86,20 @@ def __init__( self.model_config = model_config self.max_model_len = model_config.max_model_len - self.served_model_names = served_model_names + self.base_model_paths = base_model_paths self.lora_id_counter = AtomicCounter(0) self.lora_requests = [] if 
lora_modules is not None: self.lora_requests = [ - LoRARequest( - lora_name=lora.name, - lora_int_id=i, - lora_path=lora.path, - ) for i, lora in enumerate(lora_modules, start=1) + LoRARequest(lora_name=lora.name, + lora_int_id=i, + lora_path=lora.path, + base_model_name=lora.base_model_name + if lora.base_model_name + and self._is_model_supported(lora.base_model_name) + else self.base_model_paths[0].name) + for i, lora in enumerate(lora_modules, start=1) ] self.prompt_adapter_requests = [] @@ -112,21 +122,23 @@ def __init__( async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ - ModelCard(id=served_model_name, + ModelCard(id=base_model.name, max_model_len=self.max_model_len, - root=self.served_model_names[0], + root=base_model.model_path, permission=[ModelPermission()]) - for served_model_name in self.served_model_names + for base_model in self.base_model_paths ] lora_cards = [ ModelCard(id=lora.lora_name, - root=self.served_model_names[0], + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, permission=[ModelPermission()]) for lora in self.lora_requests ] prompt_adapter_cards = [ ModelCard(id=prompt_adapter.prompt_adapter_name, - root=self.served_model_names[0], + root=self.base_model_paths[0].name, permission=[ModelPermission()]) for prompt_adapter in self.prompt_adapter_requests ] @@ -169,7 +181,7 @@ async def _check_model( self, request: AnyRequest, ) -> Optional[ErrorResponse]: - if request.model in self.served_model_names: + if self._is_model_supported(request.model): return None if request.model in [lora.lora_name for lora in self.lora_requests]: return None @@ -187,7 +199,7 @@ def _maybe_get_adapters( self, request: AnyRequest ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ None, PromptAdapterRequest]]: - if request.model in self.served_model_names: + if self._is_model_supported(request.model): return None, None for lora in self.lora_requests: if request.model == lora.lora_name: @@ -480,3 +492,6 @@ async def unload_lora_adapter( if lora_request.lora_name != lora_name ] return f"Success: LoRA adapter '{lora_name}' removed successfully." 
+ + def _is_model_supported(self, model_name): + return any(model.name == model_name for model in self.base_model_paths) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 8f8862897fc4e..6d9a1ae088079 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -16,7 +16,8 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, OpenAIServing) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -31,7 +32,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - served_model_names: List[str], + base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], @@ -39,7 +40,7 @@ def __init__( ): super().__init__(engine_client=engine_client, model_config=model_config, - served_model_names=served_model_names, + base_model_paths=base_model_paths, lora_modules=lora_modules, prompt_adapters=None, request_logger=request_logger) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 47a59d80d3a45..c4b26dc92c6f4 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -28,6 +28,7 @@ class LoRARequest( lora_path: str = "" lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) def __post_init__(self): if 'lora_local_path' in self.__struct_fields__: From 3b63de9353ce51ba6c1c167ae8d4b87b8bcf9c9e Mon Sep 17 00:00:00 2001 From: Niklas Muennighoff Date: Fri, 20 Sep 2024 09:31:41 -0700 Subject: [PATCH 052/116] [Model] Add OLMoE (#7922) --- docs/source/models/supported_models.rst | 4 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/olmoe.py | 409 ++++++++++++++++++++++++ 3 files changed, 414 insertions(+) create mode 100644 vllm/model_executor/models/olmoe.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 745b4b8e2e0eb..9e0303e1dab6c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -127,6 +127,10 @@ Decoder-only Language Models - Nemotron-3, Nemotron-4, Minitron - :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. - ✅︎ + * - :code:`OLMoEForCausalLM` + - OLMoE + - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc. + - * - :code:`OLMoForCausalLM` - OLMo - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. 
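For readers who want to try the OLMoE support added in this patch, a minimal offline-inference sketch using vLLM's standard LLM entrypoint could look like the following; the checkpoint name is taken from the table above, while the prompt and sampling settings are purely illustrative:

    from vllm import LLM, SamplingParams

    # Load the instruct checkpoint listed in the supported-models table above.
    llm = LLM(model="allenai/OLMoE-1B-7B-0924-Instruct")

    # Illustrative sampling settings; adjust for your use case.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

    # Generate one completion and print the generated text.
    outputs = llm.generate(["The capital of France is"], sampling_params)
    print(outputs[0].outputs[0].text)
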
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 7427060922281..bee312a14f440 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -46,6 +46,7 @@ "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py new file mode 100644 index 0000000000000..c76e5e86c89d8 --- /dev/null +++ b/vllm/model_executor/models/olmoe.py @@ -0,0 +1,409 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import print_warning_once + + +class OlmoeMoE(nn.Module): + """A tensor-parallel MoE implementation for Olmoe that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + return final_hidden_states.view(orig_shape) + + +class OlmoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 4096, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.q_norm = RMSNorm(hidden_size, eps=1e-5) + self.k_norm = RMSNorm(hidden_size, eps=1e-5) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: 
Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = OlmoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + + self.mlp = OlmoeMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class OlmoeModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + OlmoeDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], attn_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class OlmoeForCausalLM(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = OlmoeModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: 
AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). 
" + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 2940afa04e39fa9f248c565687d9a2acf7401355 Mon Sep 17 00:00:00 2001 From: "Alexey Kondratiev(AMD)" <143633163+alexeykondrat@users.noreply.github.com> Date: Fri, 20 Sep 2024 13:27:44 -0400 Subject: [PATCH 053/116] [CI/Build] Removing entrypoints/openai/test_embedding.py test from ROCm build (#8670) --- .buildkite/run-amd-test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 45b20c9447c7d..df201cdc7c554 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -100,6 +100,7 @@ if [[ $commands == *" entrypoints/openai "* ]]; then --ignore=entrypoints/openai/test_accuracy.py \ --ignore=entrypoints/openai/test_audio.py \ --ignore=entrypoints/openai/test_encoder_decoder.py \ + --ignore=entrypoints/openai/test_embedding.py \ --ignore=entrypoints/openai/test_oot_registration.py "} fi From b28298f2f4bd4ec6d1020c10b923a9eb7993dc89 Mon Sep 17 00:00:00 2001 From: saumya-saran Date: Fri, 20 Sep 2024 12:46:02 -0700 Subject: [PATCH 054/116] [Bugfix] Validate SamplingParam n is an int (#8548) --- vllm/sampling_params.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5edbc8e424e81..86e80ae5e224d 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -273,9 +273,14 @@ def __post_init__(self) -> None: self._all_stop_token_ids = set(self.stop_token_ids) def _verify_args(self) -> None: + if not isinstance(self.n, int): + raise ValueError(f"n must be an int, but is of " + f"type {type(self.n)}") if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") - assert isinstance(self.best_of, int) + if not isinstance(self.best_of, int): + raise ValueError(f'best_of must be an int, but is of ' + f'type {type(self.best_of)}') if self.best_of < self.n: raise ValueError(f"best_of must be greater than or equal to n, " f"got n={self.n} and best_of={self.best_of}.") From 035fa895ecedea87810889aabbe50ba8a2ad7d5d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 04:52:19 +0800 Subject: [PATCH 055/116] [Misc] Show AMD GPU topology in `collect_env.py` (#8649) --- collect_env.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/collect_env.py b/collect_env.py index 839d54172e775..c5cd8c315e749 100644 --- a/collect_env.py +++ b/collect_env.py @@ -285,9 +285,14 @@ def summarize_vllm_build_flags(): def get_gpu_topo(run_lambda): + output = None + if get_platform() == 'linux': - return run_and_read_all(run_lambda, 'nvidia-smi topo -m') - return None + output = run_and_read_all(run_lambda, 'nvidia-smi topo -m') + if output is None: + output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') + + return output # example outputs of CPU infos From 2874bac618052a079efd837fc82cf3f3519079c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Sat, 21 Sep 2024 05:00:45 +0800 Subject: [PATCH 056/116] [Bugfix] Config got an unexpected keyword argument 'engine' (#8556) --- vllm/entrypoints/api_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 6127177b4d889..f3e80cab62a34 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -121,7 +121,6 @@ async def 
run_server(args: Namespace, shutdown_task = await serve_http( app, - engine=engine, host=args.host, port=args.port, log_level=args.log_level, From b4e4eda92e1d3a013fc4007db64b69d8604264ff Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 20 Sep 2024 23:33:03 +0200 Subject: [PATCH 057/116] [Bugfix][Core] Fix tekken edge case for mistral tokenizer (#8640) --- .../decoder_only/language/test_mistral.py | 26 ++++++++++++++- vllm/transformers_utils/tokenizers/mistral.py | 32 +++++++++++++++++-- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 26f90456849f1..174b905d9cbb9 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -4,7 +4,7 @@ """ import pytest -from vllm import SamplingParams +from vllm import LLM, SamplingParams from ...utils import check_logprobs_close @@ -16,6 +16,10 @@ ] SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) +SYMBOLIC_LANG_PROMPTS = [ + "勇敢な船乗りについての詩を書く", # japanese + "寫一首關於勇敢的水手的詩", # chinese +] # for function calling TOOLS = [{ @@ -131,6 +135,26 @@ def test_mistral_format( ) +@pytest.mark.parametrize("model", MODELS[1:]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("prompt", SYMBOLIC_LANG_PROMPTS) +def test_mistral_symbolic_languages( + model: str, + dtype: str, + prompt: str, +) -> None: + prompt = "hi" + msg = {"role": "user", "content": prompt} + llm = LLM(model=model, + dtype=dtype, + max_model_len=8192, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") + outputs = llm.chat([msg], sampling_params=SAMPLING_PARAMS) + assert "�" not in outputs[0].outputs[0].text.strip() + + @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling def test_mistral_function_calling( diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 7a228a3efa6e8..788133059f12d 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -175,10 +175,29 @@ def apply_chat_template(self, def convert_tokens_to_string(self, tokens: List[str]) -> str: if isinstance(self.tokenizer, Tekkenizer): - return "".join(t for t in tokens - if t not in self.tokenizer._all_special_tokens) + tokens = [ + t for t in tokens + if t not in self.tokenizer._all_special_tokens + ] + + if any(isinstance(t, bytes) for t in tokens): + # we need to encode and decode all tokens again + shift = self.tokenizer.num_special_tokens + byte_tokens = [ + t.encode("utf-8") if not isinstance(t, bytes) else t + for t in tokens + ] + ids = [ + self.tokenizer._tekken_token2id_nospecial[t] + shift + for t in byte_tokens + ] + decoded = self.tokenizer.decode(ids) + else: + decoded = "".join(tokens) else: - return self.tokenizer.decode(tokens) # type: ignore[arg-type] + decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type] + + return decoded def decode(self, ids: Union[List[int], int]) -> str: if isinstance(ids, int): @@ -200,4 +219,11 @@ def convert_ids_to_tokens( self.tokenizer) tokens = [self.tokenizer.id_to_piece(id) for id in ids] + + if any(t.strip() == "�" for t in tokens): + # if any stripped decoded token is undefined + # because it's invalid unicode then pass bytes + # See: https://github.com/vllm-project/vllm/pull/8640 + tokens = 
[self.tokenizer.id_to_byte_piece(id) for id in ids] + return tokens From 7c8566aa4ff16b79a576436fbb50f03643febf07 Mon Sep 17 00:00:00 2001 From: omrishiv <327609+omrishiv@users.noreply.github.com> Date: Fri, 20 Sep 2024 15:04:37 -0700 Subject: [PATCH 058/116] [Doc] neuron documentation update (#8671) Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> --- docs/source/getting_started/neuron-installation.rst | 4 ++-- docs/source/index.rst | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/getting_started/neuron-installation.rst b/docs/source/getting_started/neuron-installation.rst index 0816524468cab..a9ed4d7fa2cd7 100644 --- a/docs/source/getting_started/neuron-installation.rst +++ b/docs/source/getting_started/neuron-installation.rst @@ -3,8 +3,8 @@ Installation with Neuron ======================== -vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK. -At the moment Paged Attention is not supported in Neuron SDK, but naive continuous batching is supported in transformers-neuronx. +vLLM 0.3.3 onwards supports model inferencing and serving on AWS Trainium/Inferentia with Neuron SDK with continuous batching. +Paged Attention and Chunked Prefill are currently in development and will be available soon. Data types currently supported in Neuron SDK are FP16 and BF16. Requirements diff --git a/docs/source/index.rst b/docs/source/index.rst index 79f723eace762..803d412befb09 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -43,7 +43,7 @@ vLLM is flexible and easy to use with: * Tensor parallelism and pipeline parallelism support for distributed inference * Streaming outputs * OpenAI-compatible API server -* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. +* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators. 
* Prefix caching support * Multi-lora support From 7f9c8902e3d50a9d715b38e0531280a58d2bbe14 Mon Sep 17 00:00:00 2001 From: omrishiv <327609+omrishiv@users.noreply.github.com> Date: Fri, 20 Sep 2024 15:19:44 -0700 Subject: [PATCH 059/116] [Hardware][AWS] update neuron to 2.20 (#8676) Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com> --- Dockerfile.neuron | 4 ++-- requirements-neuron.txt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Dockerfile.neuron b/Dockerfile.neuron index f0c3479625a70..647ed99a41e70 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,5 +1,5 @@ # default base image -ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.19.1-ubuntu20.04" +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.0-ubuntu20.04" FROM $BASE_IMAGE @@ -20,7 +20,7 @@ RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas RUN python3 -m pip install sentencepiece transformers==4.36.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U COPY ./vllm /app/vllm/vllm COPY ./setup.py /app/vllm/setup.py diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 92b705b4b2d67..148fdbe0d6310 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -2,6 +2,6 @@ -r requirements-common.txt # Dependencies for Neuron devices -transformers-neuronx >= 0.9.0 -torch-neuronx >= 2.1.0 +transformers-neuronx >= 0.12.0 +torch-neuronx >= 2.1.2 neuronx-cc From 0f961b3ce9ac3d3fd13e201c4358884bc094905e Mon Sep 17 00:00:00 2001 From: zyddnys Date: Fri, 20 Sep 2024 18:48:32 -0400 Subject: [PATCH 060/116] [Bugfix] Fix incorrect llava next feature size calculation (#8496) --- vllm/model_executor/models/llava_next.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c6bd46dd7eda9..d550a249ee822 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -87,17 +87,19 @@ def _get_llava_next_num_unpadded_features( current_height = npatches * num_patch_height current_width = npatches * num_patch_width - aspect_ratio = original_width / original_height + original_aspect_ratio = original_width / original_height current_aspect_ratio = current_width / current_height - if aspect_ratio > current_aspect_ratio: - new_height = (original_height * current_width) // original_width + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) padding = (current_height - new_height) // 2 - current_height -= padding * 2 + current_height -= 2 * padding else: - new_width = (original_width * current_height) // original_height + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) padding = (current_width - new_width) // 2 - current_width -= padding * 2 + current_width -= 2 * padding unpadded_features = current_height * current_width newline_features = current_height From 0057894ef7f8db0d51385aa7254219d7fbd6c784 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 
10:00:54 +0800 Subject: [PATCH 061/116] [Core] Rename `PromptInputs` and `inputs`(#8673) --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/mq_llm_engine/test_error_handling.py | 12 +-- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- vllm/engine/async_llm_engine.py | 24 +++--- vllm/engine/llm_engine.py | 9 +- vllm/engine/multiprocessing/__init__.py | 4 +- vllm/engine/multiprocessing/client.py | 20 ++--- vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 80 +++++++++-------- vllm/inputs/__init__.py | 6 +- vllm/inputs/data.py | 26 +++--- vllm/inputs/parse.py | 22 ++--- vllm/inputs/preprocess.py | 86 +++++++++---------- 18 files changed, 157 insertions(+), 162 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a39d1cf842f06..eadf994cacd34 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 241b2ccd0991e..e112b43aade5e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 9adf82d43f3e0..0d47281db485e 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptInputs +.. autodata:: vllm.inputs.PromptType .. 
autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 08db891665044..ca5b125369c85 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 49cfc5aa04c36..7c466c92d5293 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -165,7 +165,7 @@ async def bad_abort_after_2s(): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - inputs="Hello my name is", + prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=2000), request_id=uuid.uuid4()): pass @@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. 
- async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index e27fd77923412..3ffa126070ca0 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - inputs="Hello my name is Robert and", + prompt="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 0895c571d1d89..59af68fb493e5 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version__", "LLM", "ModelRegistry", - "PromptInputs", + "PromptType", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 34e7e05341f02..f108751056ab5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: async def add_request_async( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -420,7 +420,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType): async def add_request( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -797,7 +797,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - inputs=inputs, + prompt=prompt, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +808,7 @@ async def add_request( async def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,8 +822,7 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. 
sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -881,7 +880,7 @@ async def generate( """ async for output in await self.add_request( request_id, - inputs, + prompt, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -891,7 +890,7 @@ async def generate( async def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -904,8 +903,7 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -959,7 +957,7 @@ async def encode( """ async for output in await self.add_request( request_id, - inputs, + prompt, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2743d5c7d2282..39409757d3812 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs) + InputRegistry, LLMInputs, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -680,7 +680,7 @@ def stop_remote_worker_execution_loop(self) -> None: def add_request( self, request_id: str, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -695,8 +695,7 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. 
@@ -736,7 +735,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 700332864d17a..09aa279f1e22c 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -3,7 +3,7 @@ from typing import List, Mapping, Optional, Union from vllm import PoolingParams -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - inputs: PromptInputs + prompt: PromptType params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index aa9dbbd448af2..71099115ea125 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -25,7 +25,7 @@ RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -375,7 +375,7 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -389,8 +389,7 @@ def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -399,13 +398,13 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(inputs, sampling_params, request_id, + return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -418,8 +417,7 @@ def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -430,12 +428,12 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. 
""" - return self._process_request(inputs, pooling_params, request_id, + return self._process_request(prompt, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -468,7 +466,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - inputs=inputs, + prompt=prompt, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index f4ca231570853..788c1573ae255 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -245,7 +245,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - inputs=request.inputs, + prompt=request.prompt, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 70444faa670a2..d0bbeb357b506 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptInputs +from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generates outputs for a request""" + """Generate outputs for a request.""" ... def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 248b070611cd2..c7548ca4bcfbd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,7 +10,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -258,8 +258,8 @@ def generate( @overload def generate( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -276,7 +276,7 @@ def generate( ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -294,7 +294,9 @@ def generate( into a single list and pass it to this method. Args: - inputs: A list of inputs to generate completions for. + prompts: The prompts to the LLM. 
You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -320,12 +322,13 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -340,7 +343,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -396,9 +399,9 @@ def chat( conversation, mm_data = parse_chat_messages(messages, model_config, tokenizer) - prompt: Union[str, List[int]] + prompt_data: Union[str, List[int]] if isinstance(tokenizer, MistralTokenizer): - prompt = apply_mistral_chat_template( + prompt_data = apply_mistral_chat_template( tokenizer, messages=messages, chat_template=chat_template, @@ -406,7 +409,7 @@ def chat( tools=tools, ) else: - prompt = apply_hf_chat_template( + prompt_data = apply_hf_chat_template( tokenizer, conversation=conversation, chat_template=chat_template, @@ -414,17 +417,17 @@ def chat( tools=tools, ) - inputs: PromptInputs - if is_list_of(prompt, int): - inputs = TokensPrompt(prompt_token_ids=prompt) + prompt: PromptType + if is_list_of(prompt_data, int): + prompt = TokensPrompt(prompt_token_ids=prompt_data) else: - inputs = TextPrompt(prompt=prompt) + prompt = TextPrompt(prompt=prompt_data) if mm_data is not None: - inputs["multi_modal_data"] = mm_data + prompt["multi_modal_data"] = mm_data return self.generate( - inputs, + prompt, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -494,8 +497,8 @@ def encode( @overload def encode( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -512,7 +515,7 @@ def encode( ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -528,9 +531,9 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptInputs` - for more details about the format of each input. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. 
@@ -553,19 +556,20 @@ def encode( ) if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -609,9 +613,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - inputs: List[PromptInputs] = [] + parsed_prompts: List[PromptType] = [] for i in range(num_requests): - item: PromptInputs + item: PromptType if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -620,24 +624,24 @@ def _convert_v1_inputs( else: raise AssertionError - inputs.append(item) + parsed_prompts.append(item) - return inputs + return parsed_prompts def _validate_and_add_requests( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[PromptType, Sequence[PromptType]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: - if isinstance(inputs, (str, dict)): + if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - inputs = [inputs] + prompts = [prompts] - num_requests = len(inputs) + num_requests = len(prompts) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -654,9 +658,9 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
- for i, request_inputs in enumerate(inputs): + for i, prompt in enumerate(prompts): self._add_request( - request_inputs, + prompt, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, @@ -665,7 +669,7 @@ def _validate_and_add_requests( def _add_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -673,7 +677,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - inputs, + prompt, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0b08e9691f915..ba1bef1ab3ecc 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptInputs", - "SingletonPromptInputs", + "PromptType", + "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155b..e072bb65714b9 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] +SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptInputs` may be employed +A prompt of type :class:`SingletonPromptType` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,12 +55,12 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) @@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptInputs` schemas, and are not + :class:`SingletonPromptType` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. @@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptInputs` instances. + :class:`SingletonPromptType` instances. 
""" encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) -_T2 = TypeVar("_T2", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) def build_explicit_enc_dec_prompt( diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index ac9d355c64c80..e5fa1e4184277 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(inputs, str): - return ParsedStrPrompt(type="str", content=inputs) - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if "prompt_token_ids" in prompt: return ParsedTokensPrompt(type="tokens", - content=inputs) # type: ignore - elif "prompt" in inputs: - return ParsedTextPrompt(type="text", content=inputs) + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(inputs, dict) and "encoder_prompt" in inputs + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt def is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index be2aa5f8cb7d0..1f1b048d37e9b 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, + SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -206,7 +206,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -216,7 +216,7 @@ def _extract_prompt_components( Arguments: * request_id - * inputs: single encoder or decoder input prompt + * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -226,24 +226,24 @@ def _extract_prompt_components( * 
multi_modal_data ''' - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -251,33 +251,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -285,7 +285,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -311,7 +311,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: ''' @@ -339,7 +339,7 @@ def _process_encoder_decoder_prompt( Arguments: - * inputs: an input prompt + * prompt: an input prompt * request_id Returns: @@ -350,13 +350,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -365,7 +365,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, ) @@ -375,20 +375,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: 
PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -401,7 +401,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, ) @@ -425,7 +425,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -436,7 +436,7 @@ def _process_decoder_only_prompt( Arguments: - * inputs: input prompt + * prompt: input prompt * request_id * lora_request * prompt_adapter_request @@ -447,7 +447,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -459,14 +459,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -478,7 +478,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -488,17 +488,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -506,7 +506,7 @@ def preprocess( async def preprocess_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -516,17 +516,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, From d4bf085ad064ba68a77862e2022f37c33a66e94a 
Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Sat, 21 Sep 2024 10:03:55 +0800 Subject: [PATCH 062/116] [MISC] add support custom_op check (#8557) Co-authored-by: youkaichao --- vllm/distributed/parallel_state.py | 49 ++++++++++++++++-------------- vllm/utils.py | 6 ++++ 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index df07842edfa56..d3ac4eb78b155 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -36,6 +36,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import supports_custom_op @dataclass @@ -95,32 +96,33 @@ def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) # type: ignore -@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"]) -def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - group._all_reduce(tensor) +if supports_custom_op(): + @torch.library.custom_op("vllm::inplace_all_reduce", + mutates_args=["tensor"]) + def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce(tensor) -@inplace_all_reduce.register_fake -def _(tensor: torch.Tensor, group_name: str) -> None: - return - - -@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) -def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce(tensor) + @inplace_all_reduce.register_fake + def _(tensor: torch.Tensor, group_name: str) -> None: + return + @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) + def outplace_all_reduce(tensor: torch.Tensor, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce(tensor) -@outplace_all_reduce.register_fake -def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: - return torch.empty_like(tensor) + @outplace_all_reduce.register_fake + def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) class GroupCoordinator: @@ -335,6 +337,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ + if not supports_custom_op(): + return self._all_reduce(input_) + if self.tpu_communicator is not None and \ not self.tpu_communicator.disabled: # TPU handles Dynamo with its own logic. diff --git a/vllm/utils.py b/vllm/utils.py index 060b387ec7834..43b64263d645a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1245,6 +1245,12 @@ def supports_dynamo() -> bool: return base_torch_version >= Version("2.4.0") +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. 
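The helper added just below only probes for `torch.library.custom_op`; the parallel_state.py hunk above uses it to register the all-reduce custom ops solely when the running PyTorch exposes that API. A minimal sketch of the same gating pattern, using a hypothetical op name and function that are not part of this patch:

import torch


def _add_one_impl(x: torch.Tensor) -> torch.Tensor:
    return x + 1


if hasattr(torch.library, "custom_op"):  # the check supports_custom_op() performs
    @torch.library.custom_op("demo::add_one", mutates_args=[])
    def add_one(x: torch.Tensor) -> torch.Tensor:
        return _add_one_impl(x)

    @add_one.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(x)
else:
    # Older PyTorch (< 2.4.0): fall back to the plain eager implementation.
    add_one = _add_one_impl
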
+def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + class AtomicCounter: """An atomic, thread-safe counter""" From 0455c46ed434d70f0a6219204e89ee04f1d01336 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 10:30:39 +0800 Subject: [PATCH 063/116] [Core] Factor out common code in `SequenceData` and `Sequence` (#8675) --- tests/samplers/test_sampler.py | 27 +++----- tests/spec_decode/utils.py | 12 +--- tests/test_logits_processor.py | 8 +-- tests/test_sequence.py | 7 +-- .../test_encoder_decoder_model_runner.py | 22 +++---- tests/worker/test_model_runner.py | 16 ++--- vllm/inputs/registry.py | 8 +-- vllm/sequence.py | 61 +++++++++++-------- 8 files changed, 64 insertions(+), 97 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 19a5ca5e27502..308b708feab71 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,6 +1,5 @@ import itertools import random -from array import array from typing import Dict, List, Optional, Tuple from unittest.mock import Mock, patch @@ -12,8 +11,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import Counter, is_pin_memory_available @@ -59,9 +57,7 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -205,9 +201,8 @@ def create_sampling_params(min_tokens, return sampling_params def create_sequence_data(num_input=3, num_generated=0): - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, - random.choices(range(0, VOCAB_SIZE), k=num_input))) + seq_data = SequenceData.from_seqs( + random.choices(range(0, VOCAB_SIZE), k=num_input)) if num_generated > 0: seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), k=num_generated) @@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams( temperature=1, top_k=top_k, @@ -699,11 +690,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: - SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params[i], block_tables={0: [1]}, )) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 9075a433eb66e..f17e872881633 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,4 +1,3 @@ -from array import array from itertools import count from typing import Callable, Dict, List, 
Optional from typing import Sequence as GenericSequence @@ -11,8 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine @@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts( request_id=str(i), is_prompt=len(cont_token_ids) == 0, seq_data={ - i: - SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]), - _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, - cont_token_ids[:]), - ), + i: SequenceData.from_seqs(prompt_token_ids[:], + cont_token_ids[:]), }, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 1ce49a50688ae..39c1c38151fd0 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -1,5 +1,4 @@ import random -from array import array from typing import Tuple from unittest.mock import patch @@ -9,8 +8,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import is_pin_memory_available @@ -71,9 +69,7 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 348ba7dd41d99..30e53a180ea31 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,10 +1,7 @@ -from array import array - import pytest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, SequenceData, +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4])) + seq_data = SequenceData.from_seqs([1, 2, 3, 4]) assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 27cdf5f339ede..3dccc1b325d95 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,13 +1,11 @@ import itertools -from array import array from typing import List import pytest import torch from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, 
SequenceData, SequenceGroupMetadata from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import _get_graph_batch_size @@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 42b2337f46914..fe97199bac62d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,4 +1,3 @@ -from array import array from typing import List import pytest @@ -8,8 +7,7 @@ init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, 
range(context_len))) + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)) - seq_data = SequenceData(prompt_toks) + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f72..a0f02ba29e219 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,5 +1,4 @@ import functools -from array import array from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, @@ -22,10 +21,6 @@ C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) -# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE. -# We cannot import it here because of circular dependencies. -VLLM_TOKEN_ID_ARRAY_TYPE = "l" - @dataclass(frozen=True) class InputContext: @@ -130,8 +125,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) + dummy_seq_data = SequenceData.from_counts({0: seq_len}) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data diff --git a/vllm/sequence.py b/vllm/sequence.py index 07ceccf123541..f849211c317ca 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from array import array from collections import defaultdict from dataclasses import dataclass +from functools import cached_property, reduce from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union, cast @@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct, # It is used to compute mrope_position_ids. 
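The hunk below adds two constructors to `SequenceData` (`from_seqs`, plus `from_counts`, which a later patch in this series renames to `from_token_counts`), so call sites no longer need to build `array(VLLM_TOKEN_ID_ARRAY_TYPE, ...)` by hand. A small usage sketch, with illustrative values only:

from vllm.sequence import SequenceData

# Explicit prompt token ids, optionally with pre-existing output token ids.
prefill_data = SequenceData.from_seqs([1, 2, 3, 4])
decode_data = SequenceData.from_seqs([1, 2, 3, 4], [5, 6])

# Run-length construction from {token_id: count}, e.g. 128 dummy padding tokens.
dummy_data = SequenceData.from_counts({0: 128})
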
_mrope_position_delta: Optional[int] = None + @staticmethod + def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData": + if len(counts_by_token) == 0: + return SequenceData.from_seqs([]) + + arrs = [ + array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + for token_id, count in counts_by_token.items() + ] + + return SequenceData(reduce(array.__add__, arrs)) + + @staticmethod + def from_seqs( + prompt_token_ids: GenericSequence[int], + output_token_ids: Optional[GenericSequence[int]] = None, + ) -> "SequenceData": + prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + prompt_token_ids) + + if output_token_ids is None: + return SequenceData(prompt_token_ids_arr) + + output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + output_token_ids) + + return SequenceData(prompt_token_ids_arr, + _output_token_ids=output_token_ids_arr) + def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" @@ -370,8 +400,6 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request self.from_decoder_prompt = from_decoder_prompt - self._prompt: Optional[str] = None - self._prompt_token_ids: Optional[List[int]] = None # For decoder-only models, a Sequence is constructed # from an LLMInputs instance (the `inputs` arg.) @@ -400,8 +428,7 @@ def __init__( f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") - self.data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids)) + self.data = SequenceData.from_seqs(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -422,37 +449,23 @@ def __init__( def n_blocks(self) -> int: return (self.get_len() + self.block_size - 1) // self.block_size - @property + @cached_property def prompt(self) -> Optional[str]: - if self._prompt is not None: - # Reuse precomputed prompt string - return self._prompt - - # Select decoder or encoder input prompt str, - # as appropriate + # Select decoder or encoder input prompt str, as appropriate prompt_key: str = ("prompt" if self.from_decoder_prompt else "encoder_prompt") - # Cache prompt - self._prompt = cast(Optional[str], self.inputs.get(prompt_key)) - return self._prompt + return cast(Optional[str], self.inputs.get(prompt_key)) - @property + @cached_property def prompt_token_ids(self) -> List[int]: - if self._prompt_token_ids is not None: - # Reuse precomputed prompt token ids - return self._prompt_token_ids - - # Select decoder or encoder input prompt - # token ids, as appropriate + # Select decoder or encoder input prompt token ids, as appropriate prompt_token_ids_key: str = ("prompt_token_ids" if self.from_decoder_prompt else "encoder_prompt_token_ids") # Cache computed prompt token ids - self._prompt_token_ids = cast(List[int], - self.inputs.get(prompt_token_ids_key)) - return self._prompt_token_ids + return cast(List[int], self.inputs.get(prompt_token_ids_key)) @property def multi_modal_data(self) -> "MultiModalDataDict": From 0faab90eb006c677add65cd4c2d0f740a63e064d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 20 Sep 2024 19:55:33 -0700 Subject: [PATCH 064/116] [beam search] add output for manually checking the correctness (#8684) --- tests/samplers/test_beam_search.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 64f3ce94b7a83..98a02dec895d2 100644 --- a/tests/samplers/test_beam_search.py +++ 
b/tests/samplers/test_beam_search.py @@ -11,7 +11,7 @@ # 3. Use the model "huggyllama/llama-7b". MAX_TOKENS = [128] BEAM_WIDTHS = [4] -MODELS = ["facebook/opt-125m"] +MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] @pytest.mark.parametrize("model", MODELS) @@ -37,8 +37,15 @@ def test_beam_search_single_input( beam_width, max_tokens) for i in range(len(example_prompts)): - hf_output_ids, _ = hf_outputs[i] - vllm_output_ids, _ = vllm_outputs[i] + hf_output_ids, hf_output_texts = hf_outputs[i] + vllm_output_ids, vllm_output_texts = vllm_outputs[i] + for i, (hf_text, + vllm_text) in enumerate(zip(hf_output_texts, + vllm_output_texts)): + print(f">>>{i}-th hf output:") + print(hf_text) + print(f">>>{i}-th vllm output:") + print(vllm_text) assert len(hf_output_ids) == len(vllm_output_ids) for j in range(len(hf_output_ids)): assert hf_output_ids[j] == vllm_output_ids[j], ( From 71c60491f287d8a23bed1743513b4b3e7927c69e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 21 Sep 2024 02:27:10 -0400 Subject: [PATCH 065/116] [Kernel] Build flash-attn from source (#8245) --- .github/workflows/scripts/build.sh | 1 + .gitignore | 5 ++ CMakeLists.txt | 98 ++++++++++++++++++++------- Dockerfile | 3 + cmake/utils.cmake | 2 +- requirements-cuda.txt | 1 - setup.py | 38 ++++++++--- vllm/attention/backends/flash_attn.py | 9 ++- vllm/attention/selector.py | 8 +-- 9 files changed, 124 insertions(+), 41 deletions(-) diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 0a759d303238b..cd617e9f19fb2 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -15,5 +15,6 @@ $python_executable -m pip install -r requirements-cuda.txt export MAX_JOBS=1 # Make sure release wheels are built for the following architectures export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" +export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" # Build $python_executable setup.py bdist_wheel --dist-dir=dist diff --git a/.gitignore b/.gitignore index 761b00ac3bc48..bc7236ea18698 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # vllm commit id, generated by setup.py vllm/commit_id.py +# vllm-flash-attn built from source +vllm/vllm_flash_attn/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -12,6 +15,8 @@ __pycache__/ # Distribution / packaging .Python build/ +cmake-build-*/ +CMakeUserPresets.json develop-eggs/ dist/ downloads/ diff --git a/CMakeLists.txt b/CMakeLists.txt index c8f19de94e59b..e0716af6fff4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,16 @@ cmake_minimum_required(VERSION 3.26) +# When building directly using CMake, make sure you run the install step +# (it places the .so files in the correct location). +# +# Example: +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. .. +# cmake --build . --target install +# +# If you want to only build one target, make sure to install it manually: +# cmake --build . --target _C +# cmake --install . --component _C project(vllm_extensions LANGUAGES CXX) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) @@ -13,6 +24,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # Suppress potential warnings about unused manually-specified variables set(ignoreMe "${VLLM_PYTHON_PATH}") +# Prevent installation of dependencies (cutlass) by default. +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) + # # Supported python versions. 
These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. @@ -70,19 +84,6 @@ endif() find_package(Torch REQUIRED) # -# Add the `default` target which detects which extensions should be -# built based on platform/architecture. This is the same logic that -# setup.py uses to select which extensions should be built and should -# be kept in sync. -# -# The `default` target makes direct use of cmake easier since knowledge -# of which extensions are supported has been factored in, e.g. -# -# mkdir build && cd build -# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. -# cmake --build . --target default -# -add_custom_target(default) message(STATUS "Enabling core extension.") # Define _core_C extension @@ -100,8 +101,6 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -add_dependencies(default _core_C) - # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -167,6 +166,8 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() +include(FetchContent) + # # Define other extension targets # @@ -190,7 +191,6 @@ set(VLLM_EXT_SRC "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") - include(FetchContent) SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") FetchContent_Declare( cutlass @@ -283,6 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") csrc/quantization/machete/machete_pytorch.cu) endif() +message(STATUS "Enabling C extension.") define_gpu_extension_target( _C DESTINATION vllm @@ -313,6 +314,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/moe/marlin_moe_ops.cu") endif() +message(STATUS "Enabling moe extension.") define_gpu_extension_target( _moe_C DESTINATION vllm @@ -323,7 +325,6 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) - if(VLLM_GPU_LANG STREQUAL "HIP") # # _rocm_C extension @@ -343,16 +344,63 @@ if(VLLM_GPU_LANG STREQUAL "HIP") WITH_SOABI) endif() +# vllm-flash-attn currently only supported on CUDA +if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") + return() +endif () -if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") - message(STATUS "Enabling C extension.") - add_dependencies(default _C) +# +# Build vLLM flash attention from source +# +# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. +# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. +# They should be identical but if they aren't, this is a massive footgun. +# +# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. +# To only install vllm-flash-attn, use --component vllm_flash_attn_c. +# If no component is specified, vllm-flash-attn is still installed. - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) +# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. +# This is to enable local development of vllm-flash-attn within vLLM. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. 
+if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) + set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) endif() -if(VLLM_GPU_LANG STREQUAL "HIP") - message(STATUS "Enabling rocm extension.") - add_dependencies(default _rocm_C) +if(VLLM_FLASH_ATTN_SRC_DIR) + FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}) +else() + FetchContent_Declare( + vllm-flash-attn + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd + GIT_PROGRESS TRUE + ) endif() + +# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. +set(VLLM_PARENT_BUILD ON) + +# Make sure vllm-flash-attn install rules are nested under vllm/ +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c) +install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c) + +# Fetch the vllm-flash-attn library +FetchContent_MakeAvailable(vllm-flash-attn) +message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") + +# Restore the install prefix +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c) + +# Copy over the vllm-flash-attn python files +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm/vllm_flash_attn + COMPONENT vllm_flash_attn_c + FILES_MATCHING PATTERN "*.py" +) + +# Nothing after vllm-flash-attn, see comment about macros above diff --git a/Dockerfile b/Dockerfile index 001068b4b36ca..30e27620574a0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,6 +48,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # see https://github.com/pytorch/pytorch/pull/123243 ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} +# Override the arch list for flash-attn to reduce the binary size +ARG vllm_fa_cmake_gpu_arches='80-real;90-real' +ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} #################### BASE BUILD IMAGE #################### #################### WHEEL BUILD IMAGE #################### diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 730517a20129a..10fa0a25bde15 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -364,5 +364,5 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) endif() - install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION}) + install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) endfunction() diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 5b811703a55e7..3b3c2f876919e 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,4 +8,3 @@ torch == 2.4.0 # These must be updated alongside torch torchvision == 0.19 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0 -vllm-flash-attn == 2.6.1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0 diff --git a/setup.py b/setup.py index 7da9115440433..cc559f26c6f3f 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,7 @@ import subprocess import sys import warnings +from pathlib import Path from shutil import which from typing import Dict, List @@ -152,15 +153,8 @@ def configure(self, ext: CMakeExtension) -> None: default_cfg = "Debug" if self.debug else "RelWithDebInfo" cfg = envs.CMAKE_BUILD_TYPE or default_cfg - # where .so files will be written, should be the same for all extensions - # that use the same CMakeLists.txt. - outdir = os.path.abspath( - os.path.dirname(self.get_ext_fullpath(ext.name))) - cmake_args = [ '-DCMAKE_BUILD_TYPE={}'.format(cfg), - '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir), - '-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp), '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] @@ -224,10 +218,12 @@ def build_extensions(self) -> None: os.makedirs(self.build_temp) targets = [] + target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."), + "vllm_flash_attn.") # Build all the extensions for ext in self.extensions: self.configure(ext) - targets.append(remove_prefix(ext.name, "vllm.")) + targets.append(target_name(ext.name)) num_jobs, _ = self.compute_num_jobs() @@ -240,6 +236,28 @@ def build_extensions(self) -> None: subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) + # Install the libraries + for ext in self.extensions: + # Install the extension into the proper location + outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute() + + # Skip if the install directory is the same as the build directory + if outdir == self.build_temp: + continue + + # CMake appends the extension prefix to the install path, + # and outdir already contains that prefix, so we need to remove it. 
+ prefix = outdir + for i in range(ext.name.count('.')): + prefix = prefix.parent + + # prefix here should actually be the same for all components + install_args = [ + "cmake", "--install", ".", "--prefix", prefix, "--component", + target_name(ext.name) + ] + subprocess.check_call(install_args, cwd=self.build_temp) + def _no_device() -> bool: return VLLM_TARGET_DEVICE == "empty" @@ -467,6 +485,10 @@ def _read_requirements(filename: str) -> List[str]: if _is_hip(): ext_modules.append(CMakeExtension(name="vllm._rocm_C")) +if _is_cuda(): + ext_modules.append( + CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c")) + if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bf883987bd80b..084e8113cd421 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -19,8 +19,13 @@ from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) -from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func -from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache +# yapf: disable +from vllm.vllm_flash_attn import ( + flash_attn_varlen_func as _flash_attn_varlen_func) +from vllm.vllm_flash_attn import ( + flash_attn_with_kvcache as _flash_attn_with_kvcache) + +# yapf: enable @torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index fbda263ba8e08..30aa7cb311afb 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -244,8 +244,7 @@ def which_attn_to_use( # FlashAttn is valid for the model, checking if the package is installed. if selected_backend == _Backend.FLASH_ATTN: try: - import vllm_flash_attn # noqa: F401 - + import vllm.vllm_flash_attn # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) @@ -258,8 +257,9 @@ def which_attn_to_use( except ImportError: logger.info( "Cannot use FlashAttention-2 backend because the " - "vllm_flash_attn package is not found. " - "`pip install vllm-flash-attn` for better performance.") + "vllm.vllm_flash_attn package is not found. 
" + "Make sure that vllm_flash_attn was built and installed " + "(on by default).") selected_backend = _Backend.XFORMERS return selected_backend From 5e85f4f82a5b6eaad6869198d6ac76a0c12cf6d0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 14:28:56 +0800 Subject: [PATCH 066/116] [VLM] Use `SequenceData.from_token_counts` to create dummy data (#8687) --- vllm/inputs/registry.py | 2 +- vllm/model_executor/models/blip.py | 13 +++++------ vllm/model_executor/models/blip2.py | 13 +++++------ vllm/model_executor/models/chameleon.py | 13 +++++------ vllm/model_executor/models/clip.py | 12 +++++----- vllm/model_executor/models/minicpmv.py | 7 ++---- vllm/model_executor/models/pixtral.py | 14 +++++------- vllm/model_executor/models/qwen.py | 10 ++++----- vllm/model_executor/models/qwen2_vl.py | 21 ++++++++--------- vllm/model_executor/models/siglip.py | 12 +++++----- vllm/model_executor/models/ultravox.py | 30 ++++++++++++++++++------- vllm/sequence.py | 6 ++--- 12 files changed, 73 insertions(+), 80 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index a0f02ba29e219..2df61a9149629 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -125,7 +125,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData.from_counts({0: seq_len}) + dummy_seq_data = SequenceData.from_token_counts((0, seq_len)) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 583d5d217903b..e943427eda8e1 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,6 +1,5 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" -from array import array from typing import Optional, Union import torch @@ -19,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceData try: from xformers import ops as xops @@ -53,6 +52,7 @@ def get_max_blip_image_tokens( def dummy_seq_data_for_blip( hf_config: Union[BlipVisionConfig, Blip2VisionConfig], seq_len: int, + num_images: int, *, image_token_id: int, image_feature_size_override: Optional[int] = None, @@ -62,11 +62,10 @@ def dummy_seq_data_for_blip( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size) - return SequenceData(token_ids) + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) def dummy_image_for_blip( diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 39f2b2d853a6b..37fabf3f3f9a8 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,4 +1,3 @@ -from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -18,8 +17,7 @@ from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, 
IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) @@ -429,11 +427,10 @@ def dummy_seq_data_for_blip2( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) def dummy_data_for_blip2(ctx: InputContext, seq_len: int, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 47e020e8ecb73..51a61485caf65 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,4 +1,3 @@ -from array import array from functools import cached_property from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict) @@ -32,8 +31,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal @@ -72,11 +70,10 @@ def dummy_seq_data_for_chameleon( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) def dummy_image_for_chameleon( diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 078928f281c26..a7754f70e2786 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,6 +1,5 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from array import array from typing import Iterable, List, Optional, Tuple, Union import torch @@ -20,7 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceData try: from xformers import ops as xops @@ -62,11 +61,10 @@ def dummy_seq_data_for_clip( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) def dummy_image_for_clip( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f0fc950defed7..5579205832aa8 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,7 
+23,6 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re -from array import array from functools import partial from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, TypedDict) @@ -56,8 +55,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from .idefics2_vision_model import Idefics2VisionTransformer @@ -259,8 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len - return SequenceData(token_ids) + return SequenceData.from_token_counts((0, seq_len)) def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int): diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 682b78bbed093..aa92e62a30d3f 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,4 +1,3 @@ -from array import array from dataclasses import dataclass, fields from itertools import tee from typing import Iterable, List, Mapping, Optional, Tuple, Union @@ -24,8 +23,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal from .utils import init_vllm_registered_model @@ -63,13 +61,11 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, image_feature_size = (size**2) // (patch_size**2) num_image_tokens = image_feature_size * num_images + seq_data = SequenceData.from_token_counts( + (image_token_id, num_image_tokens), + (0, seq_len - num_image_tokens), + ) - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * num_image_tokens - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - num_image_tokens) - - seq_data = SequenceData(token_ids) mm_data = {"image": num_images * [image]} return seq_data, mm_data diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 18bc6b303f485..e62a841485f2d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -7,7 +7,6 @@ import math import re -from array import array from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -45,8 +44,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of from .utils import flatten_bn, is_pp_missing_parameter, make_layers @@ -819,7 +817,7 @@ def dummy_data_for_qwen( # The presence of a visual config indicates this is a multimodal model. # If we don't have it, the model is considered an LLM for warmup purposes. 
if not hasattr(hf_config, "visual"): - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)) + seq_data = SequenceData.from_token_counts((0, seq_len)) mm_data = None return seq_data, mm_data @@ -846,11 +844,13 @@ def dummy_data_for_qwen( if len(toks) < seq_len: toks += [0] * (seq_len - len(toks)) + seq_data = SequenceData.from_seqs(toks) + # Build the input images; width/height doesn't actually matter here since # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data + return seq_data, mm_data @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a9a0329e99f08..1011c9256793e 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from array import array from functools import lru_cache, partial from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union) @@ -66,8 +65,7 @@ from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.platforms import current_platform -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import get_processor logger = init_logger(__name__) @@ -681,15 +679,14 @@ def dummy_data_for_qwen2_vl( "--limit-mm-per-prompt.") hf_config = ctx.get_hf_config(Qwen2VLConfig) - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [hf_config.vision_start_token_id]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [hf_config.image_token_id]) * max_llm_image_tokens - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [hf_config.vision_end_token_id]) - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - max_llm_image_tokens - 2) - dummy_seqdata = SequenceData(token_ids) + + dummy_seqdata = SequenceData.from_token_counts( + (hf_config.vision_start_token_id, 1), + (hf_config.image_token_id, max_llm_image_tokens), + (hf_config.vision_end_token_id, 1), + (0, seq_len - max_llm_image_tokens - 2), + ) + dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index f7976eba7420b..5b332fa1a24d7 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,7 +2,6 @@ within a vision language model.""" import math -from array import array from typing import Iterable, List, Optional, Tuple, Union import torch @@ -24,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceData try: from xformers import ops as xops @@ -67,11 +66,10 @@ def dummy_seq_data_for_siglip( else: image_feature_size = image_feature_size_override - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [image_token_id]) * image_feature_size - token_ids += 
array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - image_feature_size) - return SequenceData(token_ids) + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) def dummy_image_for_siglip( diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 416fabda831a2..87f59f487f87b 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -77,15 +77,11 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) -def dummy_data_for_ultravox( +def dummy_seq_data_for_ultravox( ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int], + audio_count: int, ): - feature_extractor = whisper_feature_extractor(ctx) - - audio_count = mm_counts["audio"] - audio_placeholder = array( VLLM_TOKEN_ID_ARRAY_TYPE, [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx) @@ -96,10 +92,28 @@ def dummy_data_for_ultravox( other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids)) + return SequenceData(audio_token_ids + other_token_ids) + + +def dummy_audio_for_ultravox( + ctx: InputContext, + audio_count: int, +): + feature_extractor = whisper_feature_extractor(ctx) audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) - mm_dict = {"audio": [audio_and_sr] * audio_count} + return {"audio": [audio_and_sr] * audio_count} + + +def dummy_data_for_ultravox( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +): + audio_count = mm_counts["audio"] + seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) + mm_dict = dummy_audio_for_ultravox(ctx, audio_count) - return (SequenceData(audio_token_ids + other_token_ids), mm_dict) + return (seq_data, mm_dict) def input_mapper_for_ultravox(ctx: InputContext, data: object): diff --git a/vllm/sequence.py b/vllm/sequence.py index f849211c317ca..d8e54ff1fc708 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -171,13 +171,13 @@ class SequenceData(msgspec.Struct, _mrope_position_delta: Optional[int] = None @staticmethod - def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData": - if len(counts_by_token) == 0: + def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": + if len(token_counts) == 0: return SequenceData.from_seqs([]) arrs = [ array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - for token_id, count in counts_by_token.items() + for token_id, count in token_counts ] return SequenceData(reduce(array.__add__, arrs)) From 4dfdf4319676c3dca72cdfba20470ac76d0cadf4 Mon Sep 17 00:00:00 2001 From: Andy Dai <76841985+Imss27@users.noreply.github.com> Date: Sat, 21 Sep 2024 00:24:12 -0700 Subject: [PATCH 067/116] [Doc] Fix typo in AMD installation guide (#8689) --- docs/source/getting_started/amd-installation.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 9648d07d2790c..d169fe676dc94 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -83,7 +83,7 @@ Option 2: Build from source For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`. -Alternatively, you can install PyTorch using PyTorch wheels. 
You can check PyTorch installation guild in PyTorch `Getting Started `_ +Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch `Getting Started `_ 1. Install `Triton flash attention for ROCm `_ @@ -104,7 +104,7 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. $ cd vllm $ pip install -U -r requirements-rocm.txt - $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation + $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation .. tip:: From ec4aaad8124baadc7954e30c612ca9444b22d7e7 Mon Sep 17 00:00:00 2001 From: rasmith Date: Sat, 21 Sep 2024 04:20:54 -0500 Subject: [PATCH 068/116] [Kernel][Triton][AMD] Remove tl.atomic_add from awq_gemm_kernel, 2-5x speedup MI300, minor improvement for MI250 (#8646) --- .../layers/quantization/awq_triton.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index d0b210c3a2747..bbb7fc8ad5087 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -209,12 +209,9 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, c = accumulator.to(c_ptr.type.element_ty) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :] + c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if SPLIT_K == 1: - tl.store(c_ptrs, c, mask=c_mask) - else: - tl.atomic_add(c_ptrs, c, mask=c_mask) + tl.store(c_ptrs, c, mask=c_mask) # qweights - [K , M // 8], int32 @@ -295,7 +292,9 @@ def awq_gemm_triton(input: torch.Tensor, split_k_iters, ) - result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) + result = torch.zeros((split_k_iters, M, N), + dtype=scales.dtype, + device=input.device) # A = input, B = qweight, C = result # A = M x K, B = K x N, C = M x N @@ -313,4 +312,6 @@ def awq_gemm_triton(input: torch.Tensor, BLOCK_SIZE_K=block_size_k, SPLIT_K=split_k_iters) + result = result.sum(0) + return result From 9dc7c6c7f332ac6c08311c7a946c6945e0782701 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Sat, 21 Sep 2024 16:09:39 -0500 Subject: [PATCH 069/116] [dbrx] refactor dbrx experts to extend FusedMoe class (#8518) --- vllm/model_executor/models/dbrx.py | 120 ++++++++++++----------------- 1 file changed, 51 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 6160197dc19de..397a46a486f72 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -7,9 +7,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import fused_moe + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -22,7 +21,6 @@ DEFAULT_VOCAB_PADDING_SIZE, 
ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -54,13 +52,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return router_logits -class DbrxExperts(nn.Module): - """A tensor-parallel MoE implementation for DBRX. - - Each expert's weights are sharded across all ranks and a fused MoE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. - """ +class DbrxExperts(FusedMoE): def __init__( self, @@ -68,49 +60,24 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, ): - super().__init__() + super().__init__( + num_experts=config.ffn_config.moe_num_experts, + top_k=config.ffn_config.moe_top_k, + hidden_size=config.d_model, + intermediate_size=config.ffn_config.ffn_hidden_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=get_tensor_model_parallel_world_size(), + ) + self.config = config self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.ffn_config.moe_num_experts - self.top_k = config.ffn_config.moe_top_k self.d_model = config.d_model - self.intermediate_size = (config.ffn_config.ffn_hidden_size // + self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // self.tp_size) - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - self.router = DbrxRouter(config, self.params_dtype) - self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.d_model, - device="cuda", - dtype=self.params_dtype, - )) - self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.d_model, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype, - )) - - set_weight_attrs( - self.ws, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2s, - { - "weight_loader": self.weight_loader, - }, - ) - + # Define custom weight loader for dbrx model def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str): tp_rank = get_tensor_model_parallel_rank() @@ -140,26 +107,40 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, ).transpose(1, 2) param_data[:] = loaded_weight[:, :, shard] + +class DbrxMoE(nn.Module): + """A tensor-parallel MoE implementation for DBRX. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
+ """ + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.d_model = config.d_model + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.router = DbrxRouter(config, self.params_dtype) + + self.experts = DbrxExperts(config=config, + quant_config=quant_config, + params_dtype=self.params_dtype) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.d_model) # router_logits: (num_tokens, n_experts) router_logits = self.router(hidden_states) - final_hidden_states = fused_moe( - hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) class DbrxAttention(nn.Module): @@ -288,7 +269,7 @@ def __init__( super().__init__() self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, quant_config) - self.ffn = DbrxExperts(config, quant_config) + self.ffn = DbrxMoE(config, quant_config) def forward( self, @@ -409,9 +390,10 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + expert_params_mapping = [( - "ws" if weight_name in ["w1", "v1"] else "w2s", - f"experts.mlp.{weight_name}", + "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight", + f"mlp.{weight_name}", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: From d66ac62854e04c8fda83506dc93ef7971ebf593a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 21 Sep 2024 19:45:02 -0400 Subject: [PATCH 070/116] [Kernel][Bugfix] Delete some more useless code in marlin_moe_ops.cu (#8643) --- csrc/moe/marlin_moe_ops.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 49cc03f827f68..293a6fad72c2f 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1704,9 +1704,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } #define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ \ __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ From 13d88d4137f97b8cf3c79f39d7df5e4c8348603a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 22 Sep 2024 12:33:27 +0800 Subject: [PATCH 071/116] [Bugfix] Refactor composite weight loading logic (#8656) --- vllm/model_executor/models/internvl.py | 16 ++++----- vllm/model_executor/models/llava.py | 16 ++++----- vllm/model_executor/models/llava_next.py | 20 ++++------- .../model_executor/models/llava_next_video.py | 17 ++++----- vllm/model_executor/models/paligemma.py | 14 +++----- vllm/model_executor/models/ultravox.py | 12 +++---- vllm/model_executor/models/utils.py | 36 ++++++++++++++++++- 7 files changed, 
70 insertions(+), 61 deletions(-) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 507d7014714a2..005a24f10aa17 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -4,7 +4,6 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- -import itertools import re from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -33,8 +32,8 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) from .interfaces import SupportsMultiModal -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) +from .utils import (flatten_bn, group_weights_with_prefix, + init_vllm_registered_model, merge_multimodal_embeddings) IMG_START = '' IMG_END = '' @@ -518,21 +517,18 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_model") - self.vision_model.load_weights(vit_weights) + self.vision_model.load_weights(weights_group["vision_model"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "mlp1") mlp_params_dict = dict(self.mlp1.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["mlp1"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7a6c991fb133a..69eb177a7dea8 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,3 @@ -import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -26,8 +25,8 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) +from .utils import (flatten_bn, group_weights_with_prefix, + init_vllm_registered_model, merge_multimodal_embeddings) class LlavaImagePixelInputs(TypedDict): @@ -393,21 +392,18 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", 
default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index d550a249ee822..96034b254e49b 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,4 +1,3 @@ -import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -30,8 +29,8 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, - merge_multimodal_embeddings) +from .utils import (flatten_bn, group_weights_with_prefix, + init_vllm_registered_model, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -637,25 +636,21 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( - weights, 4) + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load newline - newline_weights = filter_weights(newline_weights, "image_newline") - for name, loaded_weight in newline_weights: + for name, loaded_weight in weights_group["image_newline"]: assert name == "" param = self.image_newline weight_loader = getattr(param, "weight_loader", @@ -663,5 +658,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 7fe85e5e4ab3d..a8b5176dc43cf 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,4 +1,3 @@ -import itertools import math from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -30,7 +29,7 @@ from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip) -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -449,23 +448,19 @@ def sample( return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # prepare weight iterators - vit_weights, mlp_weights, newline_weights, 
llm_weights = itertools.tee( - weights, 4) + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) # load vision encoder - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 5fd39b5e35be6..68b6d0cf808e1 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,4 +1,3 @@ -import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -23,7 +22,7 @@ from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) -from .utils import filter_weights, merge_multimodal_embeddings +from .utils import group_weights_with_prefix, merge_multimodal_embeddings logger = init_logger(__name__) @@ -286,21 +285,18 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + weights_group = group_weights_with_prefix(weights) # load vision tower - vit_weights = filter_weights(vit_weights, "vision_tower") - self.vision_tower.load_weights(vit_weights) + self.vision_tower.load_weights(weights_group["vision_tower"]) # load mlp projector - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) - for name, loaded_weight in mlp_weights: + for name, loaded_weight in weights_group["multi_modal_projector"]: param = mlp_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 87f59f487f87b..b89c9dafd9cd8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" -import itertools import math from array import array from functools import lru_cache @@ -29,7 +28,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.utils import (filter_weights, flatten_bn, +from vllm.model_executor.models.utils import (flatten_bn, + 
group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -467,11 +467,10 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components - projector_weights, llm_weights = itertools.tee(weights, 2) + weights_group = group_weights_with_prefix(weights) # load projector weights - projector_weights = filter_weights(projector_weights, - "multi_modal_projector") + projector_weights = weights_group["multi_modal_projector"] projector_params_dict = dict( self.multi_modal_projector.named_parameters()) for name, loaded_weight in projector_weights: @@ -481,5 +480,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # load llm backbone - llm_weights = filter_weights(llm_weights, "language_model") - self.language_model.load_weights(llm_weights) + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 8b80dda96db49..38d6a4653ebd6 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,3 +1,5 @@ +import itertools +from collections import UserDict from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, Union, overload) @@ -16,7 +18,23 @@ from vllm.utils import is_pin_memory_available -def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): +class WeightsGroup(UserDict): + """ + Wraps grouped weights dictionary for a more informative error message + when attempting to access a weight component that does not exist. + """ + + def __getitem__(self, key: str) -> int: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"There is no weights named with the prefix: {key}. " + f"Available prefix: {set(self.keys())}") + raise KeyError(msg) from exc + + +def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], + prefix: str) -> Iterable[Tuple[str, torch.Tensor]]: """ Helper function to load weights for inner vLLM models. @@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): yield name, loaded_weight +def group_weights_with_prefix( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]: + """ + Helper function to group weights with prefix + """ + init_weights, repeated_weights = itertools.tee(weights, 2) + weights_prefix = {name.split(".")[0] for name, _ in init_weights} + repeated_weights = itertools.tee(repeated_weights, len(weights_prefix)) + + return WeightsGroup({ + prefix: filter_weights(component, prefix) + for component, prefix in zip(repeated_weights, weights_prefix) + }) + + def init_vllm_registered_model( hf_config: PretrainedConfig, cache_config: Optional[CacheConfig], From 0e40ac9b7b5d953dfe38933bc7d2fb0a6c8da53c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 21 Sep 2024 23:24:58 -0700 Subject: [PATCH 072/116] [ci][build] fix vllm-flash-attn (#8699) --- CMakeLists.txt | 3 +++ setup.py | 15 +++++++++++++++ vllm/vllm_flash_attn/.gitkeep | 0 3 files changed, 18 insertions(+) create mode 100644 vllm/vllm_flash_attn/.gitkeep diff --git a/CMakeLists.txt b/CMakeLists.txt index e0716af6fff4f..03937e4e0658b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -382,6 +382,9 @@ endif() # Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. 
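
The group_weights_with_prefix helper added to utils.py above replaces the per-model itertools.tee plus filter_weights boilerplate. A self-contained sketch of the same grouping idea follows; the helper name group_by_prefix and the parameter names are invented for illustration and are not the vLLM API:

```python
import itertools
from typing import Dict, Iterable, Iterator, Tuple

import torch


def _filter(weights: Iterable[Tuple[str, torch.Tensor]],
            prefix: str) -> Iterator[Tuple[str, torch.Tensor]]:
    # Keep only weights under `prefix` and strip the prefix from the name.
    for name, tensor in weights:
        if name.startswith(prefix + "."):
            yield name[len(prefix) + 1:], tensor


def group_by_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterator[Tuple[str, torch.Tensor]]]:
    # First pass discovers the top-level prefixes; the remaining copies are
    # handed out lazily, one per component.
    names_iter, data_iter = itertools.tee(weights, 2)
    prefixes = sorted({name.split(".")[0] for name, _ in names_iter})
    copies = itertools.tee(data_iter, len(prefixes))
    return {p: _filter(c, p) for p, c in zip(prefixes, copies)}


weights = [
    ("vision_tower.patch_embed.weight", torch.zeros(2)),
    ("language_model.embed_tokens.weight", torch.zeros(2)),
]
groups = group_by_prefix(weights)
print([name for name, _ in groups["vision_tower"]])  # ['patch_embed.weight']
```

With the grouped iterators, each composite model can hand weights_group["language_model"] straight to the inner model's load_weights instead of teeing and filtering by hand.
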
set(VLLM_PARENT_BUILD ON) +# Ensure the vllm/vllm_flash_attn directory exists before installation +install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c) + # Make sure vllm-flash-attn install rules are nested under vllm/ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c) install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) diff --git a/setup.py b/setup.py index cc559f26c6f3f..60e31af0a8d39 100644 --- a/setup.py +++ b/setup.py @@ -258,6 +258,21 @@ def build_extensions(self) -> None: ] subprocess.check_call(install_args, cwd=self.build_temp) + def run(self): + # First, run the standard build_ext command to compile the extensions + super().run() + + # copy vllm/vllm_flash_attn/*.py from self.build_lib to current + # directory so that they can be included in the editable build + import glob + files = glob.glob( + os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py")) + for file in files: + dst_file = os.path.join("vllm/vllm_flash_attn", + os.path.basename(file)) + print(f"Copying {file} to {dst_file}") + self.copy_file(file, dst_file) + def _no_device() -> bool: return VLLM_TARGET_DEVICE == "empty" diff --git a/vllm/vllm_flash_attn/.gitkeep b/vllm/vllm_flash_attn/.gitkeep new file mode 100644 index 0000000000000..e69de29bb2d1d From 06ed2815e2be50e527839c7ab09ce2639b7910b6 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 22 Sep 2024 20:24:21 +0800 Subject: [PATCH 073/116] [Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407) --- vllm/model_executor/models/blip.py | 61 ++++++++- vllm/model_executor/models/blip2.py | 121 +++++++----------- vllm/model_executor/models/chameleon.py | 3 - vllm/model_executor/models/clip.py | 11 +- vllm/model_executor/models/fuyu.py | 3 - vllm/model_executor/models/llava_next.py | 8 -- .../model_executor/models/llava_next_video.py | 3 - vllm/model_executor/models/minicpmv.py | 3 - vllm/model_executor/models/siglip.py | 11 +- vllm/model_executor/models/ultravox.py | 3 - 10 files changed, 113 insertions(+), 114 deletions(-) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index e943427eda8e1..7c8e76461dd67 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,6 +1,6 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" -from typing import Optional, Union +from typing import Iterable, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -342,6 +343,10 @@ def __init__(self, num_hidden_layers_override: Optional[int] = None): super().__init__() + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 + self.config = config self.embeddings = BlipVisionEmbeddings(config) @@ -350,11 +355,61 @@ def __init__(self, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, ) - self.post_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + + if len(self.encoder.layers) > config.num_hidden_layers: 
+ raise ValueError( + f"The original encoder only has {config.num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." + ) + elif len(self.encoder.layers) == config.num_hidden_layers: + self.post_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + else: + # post_layernorm is unused when we extract intermediate features + # In this case, we can skip it to conserve memory + self.post_layernorm = None def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: hidden_states = self.embeddings(pixel_values) hidden_states = self.encoder(inputs_embeds=hidden_states) + if self.post_layernorm is None: + return hidden_states + return self.post_layernorm(hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] + params_dict = dict(self.named_parameters()) + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in BlipVisionModel + if (name.startswith("post_layernorm") + and self.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 37fabf3f3f9a8..b28d7699afa01 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -10,11 +10,9 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SequenceData @@ -22,12 +20,8 @@ from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) from .interfaces import SupportsMultiModal -from .utils import merge_multimodal_embeddings - -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} +from .utils import (group_weights_with_prefix, init_vllm_registered_model, + merge_multimodal_embeddings) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -491,9 +485,6 @@ def __init__(self, super().__init__() - # currently all existing BLIP-2 models have `tie_word_embeddings` - # enabled - assert config.tie_word_embeddings 
self.config = config self.multimodal_config = multimodal_config @@ -514,17 +505,8 @@ def __init__(self, bias=True, ) - self.quant_config = quant_config - - self.language_model = OPTModel(config.text_config, cache_config, - quant_config) - - self.unpadded_vocab_size = config.text_config.vocab_size - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size) - self.sampler = Sampler() - - def get_lm_head(self): - return self.language_model.decoder.embed_tokens + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -653,7 +635,8 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, @@ -663,11 +646,11 @@ def forward( else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds) return hidden_states @@ -676,56 +659,46 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.get_lm_head(), hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # only doing this for language model part for now. 
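
The BLIP-2 changes above swap the hard-coded OPTModel, logits processor, and sampler for an inner model built via init_vllm_registered_model, so logits computation and sampling are simply delegated. A minimal sketch of that delegation pattern, with illustrative class and method names only:

```python
import torch
import torch.nn as nn


class ToyLanguageModel(nn.Module):

    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)


class ToyVLMWrapper(nn.Module):

    def __init__(self, hidden: int = 8, vocab: int = 32):
        super().__init__()
        # The wrapper owns only the multimodal pieces (e.g. a projector).
        self.projector = nn.Linear(hidden, hidden)
        self.language_model = ToyLanguageModel(hidden, vocab)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # No local lm_head/sampler: defer to the inner language model so any
        # registered text backbone can be dropped in.
        return self.language_model.compute_logits(hidden_states)


wrapper = ToyVLMWrapper()
print(wrapper.compute_logits(torch.randn(1, 8)).shape)  # torch.Size([1, 32])
```
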
- stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - - for name, loaded_weight in weights: - if "lm_head.weight" in name: - continue - if "rotary_emb.inv_freq" in name: - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" in name: - if self.vision_model is not None: - # BlipVisionModel does not need sharding - use_default_weight_loading = True - else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) + + # load vision encoder + self.vision_model.load_weights(weights_group["vision_model"]) + + # load query tokens + for name, loaded_weight in weights_group["query_tokens"]: + assert name == "" + param = self.query_tokens + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load qformer + qformer_params_dict = dict(self.qformer.named_parameters()) + for name, loaded_weight in weights_group["qformer"]: + param = qformer_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load mlp projector + mlp_params_dict = dict(self.language_projection.named_parameters()) + for name, loaded_weight in weights_group["language_projection"]: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 51a61485caf65..973e47f5f0ccd 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -12,7 +12,6 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -36,8 +35,6 @@ from .interfaces import SupportsMultiModal -logger = init_logger(__name__) - # These configs are not part of the model config but the preprocessor # and processor files, so we hardcode them in the model file for now. 
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a7754f70e2786..c353635404d9a 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -391,6 +391,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, num_hidden_layers_override: Optional[int] = None): super().__init__() + tp_size = get_tensor_model_parallel_world_size() num_heads = config.num_attention_heads self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 @@ -400,10 +401,6 @@ def __init__(self, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override) - @property - def _require_post_layernorm(self) -> bool: - return self.vision_model.post_layernorm is not None - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.vision_model(pixel_values) @@ -425,12 +422,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: # post_layernorm is not needed in CLIPVisionModel - if ("vision_model.post_layernorm" in name - and not self._require_post_layernorm): + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): continue # omit layers when num_hidden_layers_override is set - if "vision_model.encoder.layers." in name: + if name.startswith("vision_model.encoder.layers"): layer_idx = int(name.split(".")[3]) if layer_idx >= layer_count: continue diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index beeae14229575..4cf3b0b93dcf5 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -28,7 +28,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -45,8 +44,6 @@ from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings -logger = init_logger(__name__) - # Cannot find the following 2 numbers from hf config. 
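
The CLIP loader above now decides per weight name whether to skip post_layernorm and any encoder layer beyond num_hidden_layers_override. The check itself is plain string parsing on the dotted parameter name; a toy version with example names:

```python
# Toy version of the layer-skipping check used when a vision tower is
# truncated via num_hidden_layers_override; the parameter names are examples.
layer_count = 2  # layers actually instantiated

names = [
    "vision_model.encoder.layers.0.mlp.fc1.weight",
    "vision_model.encoder.layers.5.mlp.fc1.weight",   # beyond the override
    "vision_model.post_layernorm.weight",             # unused when truncated
]

for name in names:
    if name.startswith("vision_model.post_layernorm"):
        continue
    if name.startswith("vision_model.encoder.layers"):
        layer_idx = int(name.split(".")[3])
        if layer_idx >= layer_count:
            continue
    print("load:", name)
# Only the layer-0 weight is loaded.
```
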
_IMAGE_TOKEN_ID = 71011 _NEWLINE_TOKEN_ID = 71019 diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 96034b254e49b..4341cc38bdd28 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -32,13 +31,6 @@ from .utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) -logger = init_logger(__name__) - -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} - # Result in the max possible feature size (2x2 grid of 336x336px tiles) MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index a8b5176dc43cf..397a6cce5af2c 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -11,7 +11,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -32,8 +31,6 @@ from .utils import (group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) -logger = init_logger(__name__) - # For profile run _MAX_FRAMES_PER_VIDEO = 32 _MAX_NUM_VIDEOS = 1 diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5579205832aa8..c0fb6fef78bab 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -37,7 +37,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -59,8 +58,6 @@ from .idefics2_vision_model import Idefics2VisionTransformer -logger = init_logger(__name__) - _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", "llm.model": "llm", diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 5b332fa1a24d7..6cf7df4e6ac63 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -501,6 +501,7 @@ def __init__( num_hidden_layers_override: Optional[int] = None, ): super().__init__() + num_heads = config.num_attention_heads tp_size = get_tensor_model_parallel_world_size() self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 @@ -511,10 +512,6 @@ def __init__( num_hidden_layers_override=num_hidden_layers_override, ) - @property - def _require_post_layernorm(self) -> bool: - return self.vision_model.post_layernorm is not None - def get_input_embeddings(self) -> nn.Module: return 
self.vision_model.embeddings.patch_embedding @@ -540,12 +537,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: # post_layernorm is optional in SiglipVisionModel - if ("vision_model.post_layernorm" in name - and not self._require_post_layernorm): + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): continue # omit layers when num_hidden_layers_override is set - if "vision_model.encoder.layers." in name: + if name.startswith("vision_model.encoder.layers"): layer_idx = int(name.split(".")[3]) if layer_idx >= layer_count: continue diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b89c9dafd9cd8..32a0e895005cb 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -20,7 +20,6 @@ from vllm.inputs import INPUT_REGISTRY from vllm.inputs.data import LLMInputs from vllm.inputs.registry import InputContext -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.base_config import ( @@ -43,8 +42,6 @@ _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 -logger = init_logger(__name__) - class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] From 8ca5051b9afb6f8d2b3ae1b71d45d84e5d1c6f57 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 22 Sep 2024 06:56:20 -0600 Subject: [PATCH 074/116] [Misc] Use NamedTuple in Multi-image example (#8705) Signed-off-by: Alex-Brooks --- ...e_inference_vision_language_multi_image.py | 74 +++++++++++++------ 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 454872c628373..92ab4f42baa80 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -4,8 +4,9 @@ by the model. 
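
This patch rewrites the per-model loaders in the multi-image example to return a ModelRequestData NamedTuple instead of a positional 5-tuple, so callers use field names rather than index order. A trimmed-down sketch, with fields reduced and values dummied for illustration:

```python
from typing import List, NamedTuple, Optional


class ModelRequestData(NamedTuple):
    prompt: str
    stop_token_ids: Optional[List[int]]
    image_data: List[str]
    chat_template: Optional[str]


req = ModelRequestData(prompt="USER: hi",
                       stop_token_ids=None,
                       image_data=["img0", "img1"],
                       chat_template=None)
print(req.prompt, len(req.image_data))  # named access, no index bookkeeping
```
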
""" from argparse import Namespace -from typing import List +from typing import List, NamedTuple, Optional +from PIL.Image import Image from transformers import AutoProcessor, AutoTokenizer from vllm import LLM, SamplingParams @@ -19,7 +20,15 @@ ] -def load_qwenvl_chat(question: str, image_urls: List[str]): +class ModelRequestData(NamedTuple): + llm: LLM + prompt: str + stop_token_ids: Optional[List[str]] + image_data: List[Image] + chat_template: Optional[str] + + +def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" llm = LLM( model=model_name, @@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]): stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids, None, chat_template + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=chat_template, + ) -def load_phi3v(question: str, image_urls: List[str]): +def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, @@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]): for i, _ in enumerate(image_urls, start=1)) prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" stop_token_ids = None - return llm, prompt, stop_token_ids, None, None + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) -def load_internvl(question: str, image_urls: List[str]): +def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "OpenGVLab/InternVL2-2B" llm = LLM( @@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]): stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids, None, None + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) -def load_qwen2_vl(question, image_urls: List[str]): +def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: try: from qwen_vl_utils import process_vision_info except ModuleNotFoundError: @@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]): else: image_data, _ = process_vision_info(messages) - return llm, prompt, stop_token_ids, image_data, None + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=image_data, + chat_template=None, + ) model_example_map = { @@ -155,20 +189,17 @@ def load_qwen2_vl(question, image_urls: List[str]): def run_generate(model, question: str, image_urls: List[str]): - llm, prompt, stop_token_ids, image_data, _ = model_example_map[model]( - question, image_urls) - if image_data is None: - image_data = [fetch_image(url) for url in image_urls] + req_data = model_example_map[model](question, image_urls) sampling_params = SamplingParams(temperature=0.0, max_tokens=128, - stop_token_ids=stop_token_ids) + stop_token_ids=req_data.stop_token_ids) - outputs = llm.generate( + outputs = req_data.llm.generate( { - "prompt": prompt, + "prompt": req_data.prompt, "multi_modal_data": { - "image": image_data 
+ "image": req_data.image_data }, }, sampling_params=sampling_params) @@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]): def run_chat(model: str, question: str, image_urls: List[str]): - llm, _, stop_token_ids, _, chat_template = model_example_map[model]( - question, image_urls) + req_data = model_example_map[model](question, image_urls) sampling_params = SamplingParams(temperature=0.0, max_tokens=128, - stop_token_ids=stop_token_ids) - outputs = llm.chat( + stop_token_ids=req_data.stop_token_ids) + outputs = req_data.llm.chat( [{ "role": "user", @@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]): ], }], sampling_params=sampling_params, - chat_template=chat_template, + chat_template=req_data.chat_template, ) for o in outputs: From ca2b628b3c25b014b9951731c0331b75262a59e0 Mon Sep 17 00:00:00 2001 From: Huazhong Ji Date: Mon, 23 Sep 2024 01:44:09 +0800 Subject: [PATCH 075/116] [MISC] rename CudaMemoryProfiler to DeviceMemoryProfiler (#8703) --- vllm/utils.py | 2 +- vllm/worker/model_runner.py | 4 ++-- vllm/worker/xpu_model_runner.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 43b64263d645a..b1513b91a06c6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -757,7 +757,7 @@ def is_pin_memory_available() -> bool: return True -class CudaMemoryProfiler: +class DeviceMemoryProfiler: def __init__(self, device: Optional[torch.types.Device] = None): self.device = device diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e8c472df8b5fc..0a90f767567d6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -45,7 +45,7 @@ LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, +from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, supports_dynamo) from vllm.worker.model_runner_base import ( @@ -1012,7 +1012,7 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) - with CudaMemoryProfiler() as m: + with DeviceMemoryProfiler() as m: self.model = get_model(model_config=self.model_config, device_config=self.device_config, load_config=self.load_config, diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index f9037625d4af9..d3c763c995b34 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -21,7 +21,7 @@ MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad +from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -391,7 +391,7 @@ def __init__( self.model: nn.Module # Set after init_Model def load_model(self) -> None: - with CudaMemoryProfiler() as m: + with DeviceMemoryProfiler() as m: self.model = get_model( model_config=self.model_config, device_config=self.device_config, From 5b59532760c82a9d91f65a3e227524da2af7d4ef Mon Sep 17 00:00:00 2001 From: litianjian <45817262+litianjian@users.noreply.github.com> Date: 
Mon, 23 Sep 2024 01:51:44 +0800 Subject: [PATCH 076/116] [Model][VLM] Add LLaVA-Onevision model support (#8486) Co-authored-by: litianjian Co-authored-by: Cyrus Leung Co-authored-by: Roger Wang Co-authored-by: DarkLight1337 --- docs/source/models/supported_models.rst | 7 +- examples/offline_inference_vision_language.py | 60 +- .../vision_language/test_llava_next_video.py | 3 - .../vision_language/test_llava_onevision.py | 356 +++++++ tests/models/test_registry.py | 3 +- vllm/assets/video.py | 2 +- vllm/model_executor/models/__init__.py | 6 +- vllm/model_executor/models/clip.py | 19 + vllm/model_executor/models/llava_onevision.py | 876 ++++++++++++++++++ vllm/model_executor/models/siglip.py | 19 + 10 files changed, 1330 insertions(+), 21 deletions(-) create mode 100644 tests/models/decoder_only/vision_language/test_llava_onevision.py create mode 100644 vllm/model_executor/models/llava_onevision.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 9e0303e1dab6c..d86d0860f7f29 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -244,6 +244,11 @@ Multimodal Language Models - Video - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note) - + * - :code:`LlavaOnevisionForConditionalGeneration` + - LLaVA-Onevision + - Image\ :sup:`+` / Video + - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. (see note) + - * - :code:`MiniCPMV` - MiniCPM-V - Image\ :sup:`+` @@ -288,7 +293,7 @@ Multimodal Language Models For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 .. note:: - For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. + For :code:`LLaVA-NeXT-Video`, :code:`LLaVA-Onevision` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. This can be installed by running the following command: .. code-block:: bash diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 464eaf334e3de..c1129316a6e30 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -14,7 +14,8 @@ # LLaVA-1.5 -def run_llava(question): +def run_llava(question, modality): + assert modality == "image" prompt = f"USER: \n{question}\nASSISTANT:" @@ -24,7 +25,8 @@ def run_llava(question): # LLaVA-1.6/LLaVA-NeXT -def run_llava_next(question): +def run_llava_next(question, modality): + assert modality == "image" prompt = f"[INST] \n{question} [/INST]" llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192) @@ -34,15 +36,35 @@ def run_llava_next(question): # LlaVA-NeXT-Video # Currently only support for video input -def run_llava_next_video(question): +def run_llava_next_video(question, modality): + assert modality == "video" + prompt = f"USER: