From 66067a620031ebe7174a3c10c6e55113b0b11655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Sun, 24 Dec 2023 21:01:01 +0800 Subject: [PATCH 1/5] Add initial support --- csrc/ops.h | 6 +- csrc/quantization/gptq/matrix_view.cuh | 72 ++ csrc/quantization/gptq/q_gemm.cu | 1027 ++++++++++++++++- csrc/quantization/gptq/qdq_2.cuh | 87 ++ csrc/quantization/gptq/qdq_4.cuh | 100 +- csrc/quantization/gptq/qdq_8.cuh | 40 + .../layers/quantization/gptq.py | 10 +- 7 files changed, 1179 insertions(+), 163 deletions(-) create mode 100644 csrc/quantization/gptq/qdq_2.cuh create mode 100644 csrc/quantization/gptq/qdq_8.cuh diff --git a/csrc/ops.h b/csrc/ops.h index 9340a60da1417..e0a3080982059 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -84,8 +84,10 @@ torch::Tensor gptq_gemm( torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama); + bool use_exllama, + int bit); void gptq_shuffle( torch::Tensor q_weight, - torch::Tensor q_perm); + torch::Tensor q_perm, + int bit); diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index 1fdf019b29028..e0a8c274e5fc6 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -146,6 +146,78 @@ public: __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } }; +class MatrixView_q2_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } +}; + +class MatrixView_q8_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } +}; + } // namespace gptq } // namespace vllm #endif diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index eb0d75f1293c4..f964d60278bef 
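Note (not part of the patch): the new MatrixView_q2_row and MatrixView_q8_row views above follow the same packed-row convention as the existing 4-bit view, with each 32-bit word of a qzeros row holding 32/bit consecutive column values. A minimal host-side C++ sketch of that indexing, for illustration only:

#include <cassert>
#include <cstdint>

// 2-bit packing: 16 values per uint32, column c occupies bits [2c, 2c+1].
inline int unpack_q2(const uint32_t* row, int column) {
    return (row[column / 16] >> ((column % 16) * 2)) & 0x03;
}

// 8-bit packing: 4 values per uint32, column c occupies bits [8c, 8c+7].
inline int unpack_q8(const uint32_t* row, int column) {
    return (row[column / 4] >> ((column % 4) * 8)) & 0xff;
}

int main() {
    uint32_t packed2 = 0;
    for (uint32_t c = 0; c < 16; ++c) packed2 |= (c % 4) << (c * 2);
    assert(unpack_q2(&packed2, 5) == 1);     // column 5 stores 5 % 4 == 1
    uint32_t packed8 = 0x04030201u;          // columns 0..3 hold 1, 2, 3, 4
    assert(unpack_q8(&packed8, 2) == 3);
    return 0;
}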
100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -13,7 +13,9 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq #include "compat.cuh" #include "matrix_view.cuh" +#include "qdq_2.cuh" #include "qdq_4.cuh" +#include "qdq_8.cuh" namespace vllm { namespace gptq { @@ -75,6 +77,106 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) return __half2float(__low2half(result)) + __half2float(__high2half(result)); } +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); +} + +__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) +{ + // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 + + float result = {}; + #pragma unroll + for (int i = 0; i < 4; i++) + { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); +} + +__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const 
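Note (not part of the patch): the dot22_8/dot22_16/dot22_32 helpers above, in their half2, float and half variants, all compute the same quantity at different widths: a dot product between 8, 16 or 32 dequantized weights and the matching activations, scaled by the per-group scale and accumulated onto the running result. Only the accumulator type differs; dot22_8_h accumulates in fp32 because unscaled 8-bit weights span -128..127 and could overflow a half accumulator. A scalar float reference of the intended math, as a sketch for clarity:

#include <cstddef>

// Reference semantics of dot22_N(dq, a_ptr, g_result, qs) with N = 8, 16, 32:
// return g_result + qs * dot(dq[0..N-1], a[0..N-1]).
float dot22_ref(const float* dq, const float* a, std::size_t n,
                float g_result, float qs) {
    float acc = 0.0f;
    for (std::size_t i = 0; i < n; ++i) acc += dq[i] * a[i];
    return g_result + qs * acc;
}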
half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + +__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); +} + + typedef void (*fp_gemm_half_q_half_gptq_kernel) ( const half*, @@ -89,8 +191,9 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel) const int* ); + template -__global__ void gemm_half_q_half_gptq_kernel +__global__ void gemm_half_q_half_gptq_4bit_kernel ( const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, @@ -231,32 +334,297 @@ __global__ void gemm_half_q_half_gptq_kernel } } +template +__global__ void gemm_half_q_half_gptq_2bit_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + 
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; + } + + k += 16; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +template +__global__ void gemm_half_q_half_gptq_8bit_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm +) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + int4 load_int4[2]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, 
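Note (not part of the patch): every dequant call in these kernels passes zeros[i] + 1 because GPTQ checkpoints store the zero point minus one, so the effective dequantization for any bit width is w = (q - (z_stored + 1)) * scale. A scalar reference of that formula, as an illustrative sketch rather than the kernel code:

// q        : quantized value in [0, 2^bit - 1]
// z_stored : zero point as stored in qzeros (true zero point minus one)
// scale    : per-group, per-column scale (float here for clarity)
float gptq_dequant_ref(int q, int z_stored, float scale) {
    return static_cast<float>(q - (z_stored + 1)) * scale;
}
// Example: 2-bit, stored zero 1 (true zero 2), q = 3 -> (3 - 2) * scale.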
dq[1], size_n, zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); + + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; + } + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( + bool first_block, const int m_count, const int bit) +{ + #define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ + } #if BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(1); #endif #if BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(2); #endif #if BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(3); #endif #if BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(4); #endif #if BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(5); #endif #if BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(6); #endif #if BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(7); #endif #if BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_gptq_kernel; + SELECT_KERNEL(8); #endif return NULL; } @@ -274,7 +642,8 @@ void gemm_half_q_half_cuda_part int size_n, int size_k, int m_count, - int groups + int groups, + int bit ) { dim3 blockDim, gridDim; @@ -298,12 +667,114 @@ void gemm_half_q_half_cuda_part size_n, size_k, groups, - b_q_perm + b_q_perm, ); } -__global__ void reconstruct_exllama_kernel +__global__ void reconstruct_exllama_8bit_kernel +( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 8); + + const 
uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 4; p++) + { + int4 load_int4[2]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); + + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_4bit_kernel ( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, @@ -373,28 +844,130 @@ __global__ void reconstruct_exllama_kernel nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), 
__low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +__global__ void reconstruct_exllama_2bit_kernel +( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); } - for (int p = 0; p < 4; p++) + for (int p = 0; p < 2; p++) { - half2 dq[4][4]; const int4* b_ptr4 = (int4*) b_ptr; int4 load_int4 = *b_ptr4; - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); b_ptr += size_n; //half* dqh = (half*)dq; if (b_q_perm) { - for (int j = 0; j < 4; j++) + for (int j = 0; j < 8; j++) { for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); @@ -403,8 +976,9 @@ __global__ void reconstruct_exllama_kernel } else { - for (int j = 0; j < 4; j++) + for (int j = 0; j < 8; j++) { + printf("Debug %d %d %d\n", offset_k, lk, n); for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); @@ -415,7 +989,6 @@ __global__ void reconstruct_exllama_kernel } } - void reconstruct_exllama ( const uint32_t* b_q_weight, @@ -425,7 +998,8 @@ void reconstruct_exllama half* out, int height, int width, - int groups + int groups, + int bit ) { dim3 blockDim, gridDim; @@ -434,6 +1008,14 @@ void reconstruct_exllama gridDim.y = 
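Note (not part of the patch): the reconstruct_exllama_*bit kernels above produce a dense fp16 copy of the weight for the cuBLAS fallback, reading the packed tensor sequentially and scattering rows back through b_q_perm so the output lands in the original layout. A host-side sketch of that per-element behaviour, assuming the quantized values are already in permuted (sequential) order and groups are contiguous blocks of groupsize rows; all names here are illustrative:

#include <vector>

// q      : k*n quantized values, rows in sequential (permuted) order
// zeros  : groups*n stored zero points, scales : groups*n scales
// perm   : sequential row -> original row
void reconstruct_ref(const std::vector<int>& q, const std::vector<int>& zeros,
                     const std::vector<float>& scales, const std::vector<int>& perm,
                     int k, int n, int groupsize, std::vector<float>& out) {
    for (int row = 0; row < k; ++row) {
        int g = row / groupsize;
        for (int col = 0; col < n; ++col) {
            float w = static_cast<float>(q[row * n + col] - (zeros[g * n + col] + 1))
                      * scales[g * n + col];
            out[perm[row] * n + col] = w;   // scatter back to the original row
        }
    }
}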
DIVIDE(height, BLOCK_KN_SIZE); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } + if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + reconstruct_exllama_kernel<<>> ( b_q_weight, @@ -448,7 +1030,7 @@ void reconstruct_exllama } -__global__ void gemm_half_q_half_alt_kernel( +__global__ void gemm_half_q_half_alt_4bit_kernel( const half2* __restrict__ vec, const uint32_t* __restrict__ mat, half* __restrict__ mul, @@ -545,6 +1127,195 @@ __global__ void gemm_half_q_half_alt_kernel( } } +__global__ void gemm_half_q_half_alt_2bit_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width +) +{ + int zero_width = width / 16; + int vec_height = height * 8; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 16; + int h_end = min(BLOCK_KN_SIZE / 16, height - h) * 8; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + __shared__ half2 deq2[16][16]; + int val = threadIdx.x / 16; + int off = threadIdx.x % 16; + for (; val < 16; val += BLOCK_KN_SIZE / 16) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0x3), __int2half_rn(val >> 2) + ); + } + + if (blockIdx.z == 0) + { + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 16; + int k = 0; + int z_w = w / 16; + int z_mod = (w % 16) * 2 + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[8]; + half2 zeros_tmp[8]; + for (int tmp_k = 0; tmp_k < 8; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0x03) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0x03) - 1)) + ); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xf][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 12) & 0xf][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xf][off], scales_tmp[4], zeros_tmp[4]), blockvec[m][k + 4], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 20) & 0xf][off], scales_tmp[5], zeros_tmp[5]), blockvec[m][k + 
5], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xf][off], scales_tmp[6], zeros_tmp[6]), blockvec[m][k + 6], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scales_tmp[7], zeros_tmp[7]), blockvec[m][k + 7], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 8; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + +__global__ void gemm_half_q_half_alt_8bit_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width +) +{ + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + + if (blockIdx.z == 0) + { + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)) + ); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} void gemm_half_q_half_alt ( @@ -556,7 +1327,8 @@ void gemm_half_q_half_alt half* c, int size_m, int size_n, - int size_k + int size_k, + int bit ) { dim3 blockDim, gridDim; @@ -567,7 +1339,14 @@ void gemm_half_q_half_alt gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); - gemm_half_q_half_alt_kernel<<>> + auto kernel = 
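Note (not part of the patch): the gemm_half_q_half_alt_* kernels above keep the original GPTQ behaviour of looking the group up per row through g_idx, which is what lets act-order (desc_act) checkpoints work on that path, whereas the exllama kernels step through uniformly sized groups and rely on the activation permutation instead. A sketch of the two row-to-group mappings:

// Uniform groups (exllama path): size_k / groups consecutive rows per group.
int group_of_row_uniform(int row, int size_k, int groups) {
    return row / (size_k / groups);
}

// Explicit table (alt / reconstruct_gptq path): g_idx names the group per row,
// so rows belonging to the same group need not be contiguous.
int group_of_row_gidx(const int* g_idx, int row) {
    return g_idx[row];
}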
gemm_half_q_half_alt_4bit_kernel; + if (bit == 2) { + kernel = gemm_half_q_half_alt_2bit_kernel; + } else if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + kernel<<>> ( (const half2*) a, b_q_weight, @@ -576,12 +1355,12 @@ void gemm_half_q_half_alt b_gptq_qzeros, b_g_idx, size_m, - size_k / 8, + size_k / 32 * bit, size_n ); } - +template __global__ void reconstruct_gptq_kernel ( const uint32_t* __restrict__ w, @@ -597,26 +1376,25 @@ __global__ void reconstruct_gptq_kernel // Start of block int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 8; + int row = blockIdx.y * 32 / bit; if (column >= width) return; // Views - MatrixView_q4_column w_(w, height, width); MatrixView_half_rw out_(out, height, width); MatrixView_half w_scales_(w_scales, group, width); - MatrixView_q4_row w_zeros_(w_zeros, group, width); + T w_zeros_(w_zeros, group, width); - uint32_t w_read = w_.item_uint32_t(row, column); + uint32_t w_read = w[blockIdx.y * width + column]; half* out_ptr = out_.item_ptr(row, column); #pragma unroll - for (int s = 0; s < 32; s += 4) + for (int s = 0; s < 32; s += bit) { - int group = g_idx[row + s / 4]; + int group = g_idx[row + s / bit]; half w_scale = w_scales_.item(group, column); uint32_t w_zero = w_zeros_.item(group, column) + 1; - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); *out_ptr = w_item; out_ptr += out_.width; } } @@ -631,15 +1409,24 @@ void reconstruct_gptq half* out, int height, int width, - int groups + int groups, + int bit ) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; - gridDim.y = DIVIDE(height, 8); + gridDim.y = DIVIDE(height, 32 / bit); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - reconstruct_gptq_kernel<<>> + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } + + kernel<<>> ( b_q_weight, b_gptq_scales, @@ -667,19 +1454,20 @@ void gemm_half_q_half_cuda int size_n, int size_k, int groups, - bool use_exllama + bool use_exllama, + int bit ) { if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) { // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups); + size_k, size_n, groups, bit); } else { reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups); + temp_dq, size_k, size_n, groups, bit); } const half alpha = __float2half(1.0f); @@ -703,7 +1491,7 @@ void gemm_half_q_half_cuda { gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, - groups); + groups, bit); } if (last_chunk_size) @@ -711,18 +1499,17 @@ void gemm_half_q_half_cuda gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, - groups); + groups, bit); } } else { gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k); + c, size_m, size_n, size_k, bit); } } - -__global__ void shuffle_kernel +__global__ void shuffle_4bit_kernel ( uint32_t* __restrict__ b_q_weight, const int size_k, @@ -736,13 +1523,39 @@ __global__ void shuffle_kernel while (k < size_k) { 
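Note (not part of the patch): the bit width now drives all of the size arithmetic above. A k x n weight packs into k * bit / 32 rows of uint32, qzeros packs n columns into n * bit / 32 words per group row, and the reconstruction scratch buffer therefore needs b_q_weight.size(0) * 32 / bit fp16 rows. A small sketch of that arithmetic for the widths added in this patch (assumes 32 is divisible by bit; the interleaved 3-bit layout added later in this series is handled separately):

#include <cstdio>

struct PackedShapes {
    int qweight_rows;   // uint32 rows of the packed weight   (k * bit / 32)
    int qzeros_words;   // uint32 words per group row         (n * bit / 32)
    int dq_rows;        // fp16 rows recovered on reconstruction (rows * 32 / bit)
};

PackedShapes packed_shapes(int k, int n, int bit) {
    int rows = k * bit / 32;
    return { rows, n * bit / 32, rows * 32 / bit };
}

int main() {
    for (int bit : {2, 4, 8}) {
        PackedShapes s = packed_shapes(4096, 4096, bit);
        std::printf("bit=%d  qweight rows=%d  qzeros words/row=%d  dq rows=%d\n",
                    bit, s.qweight_rows, s.qzeros_words, s.dq_rows);
    }
    return 0;
}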
shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } } +__global__ void shuffle_8bit_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } +} + +__global__ void shuffle_2bit_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +} -__global__ void make_sequential_kernel +__global__ void make_sequential_4bit_kernel ( const uint32_t* __restrict__ w, uint32_t* __restrict__ w_new, const int* __restrict__ q_perm, - const int w_height, const int w_width ) { @@ -774,36 +1587,114 @@ __global__ void make_sequential_kernel w_new2[w_new2_row * w2_stride + w2_column] = dst; } +__global__ void make_sequential_2bit_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 16; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +__global__ void make_sequential_8bit_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0xff; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + void shuffle_exllama_weight ( uint32_t* q_weight, int* q_perm, int height, - int width + int width, + int bit ) { if (q_perm) { uint32_t* new_qweight = NULL; - cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); dim3 blockDim, gridDim; blockDim.x = THREADS_X; blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = height / 8; + gridDim.y = height / 32 * bit; - make_sequential_kernel<<>> + auto kernel = make_sequential_4bit_kernel; + if (bit == 2) { + kernel = make_sequential_2bit_kernel; + } else if (bit == 8) { + kernel = make_sequential_8bit_kernel; 
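Note (not part of the patch): the make_sequential_*bit kernels above repack the quantized weight so that packed row k of the new tensor holds the values of original row q_perm[k], letting the GEMM kernels later read k sequentially; only the per-word slot count (32/bit) differs between the variants. A host-side reference of the repacking, as a sketch assuming q_perm[new_row] gives the original row index:

#include <cstdint>
#include <vector>

std::vector<uint32_t> make_sequential_ref(const std::vector<uint32_t>& w,
                                          const std::vector<int>& q_perm,
                                          int width, int bit) {
    const int per_word = 32 / bit;
    std::vector<uint32_t> out(w.size(), 0);
    for (int new_row = 0; new_row < static_cast<int>(q_perm.size()); ++new_row) {
        int src_row = q_perm[new_row];
        for (int col = 0; col < width; ++col) {
            // Extract the value of the source row, reinsert it at the new slot.
            uint32_t q = (w[(src_row / per_word) * width + col]
                          >> ((src_row % per_word) * bit)) & ((1u << bit) - 1);
            out[(new_row / per_word) * width + col]
                |= q << ((new_row % per_word) * bit);
        }
    }
    return out;
}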
+ } + kernel<<>> ( q_weight, new_qweight, q_perm, - height / 8, width ); // Replace qweights - cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); // Cleanup cudaDeviceSynchronize(); cudaFree(new_qweight); @@ -813,6 +1704,12 @@ void shuffle_exllama_weight blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); gridDim.y = 1; + auto shuffle_kernel = shuffle_4bit_kernel; + if (bit == 2) { + shuffle_kernel = shuffle_2bit_kernel; + } else if (bit == 8) { + shuffle_kernel = shuffle_8bit_kernel; + } shuffle_kernel<<>>(q_weight, height, width); } @@ -826,13 +1723,14 @@ torch::Tensor gptq_gemm torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama + bool use_exllama, + int bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); - at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); vllm::gptq::gemm_half_q_half_cuda ( @@ -848,7 +1746,8 @@ torch::Tensor gptq_gemm c.size(1), // n a.size(1), // k b_gptq_qzeros.size(0), // group number - use_exllama + use_exllama, + bit ); return c; } @@ -856,7 +1755,8 @@ torch::Tensor gptq_gemm void gptq_shuffle ( torch::Tensor q_weight, - torch::Tensor q_perm + torch::Tensor q_perm, + int bit ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); @@ -864,6 +1764,7 @@ void gptq_shuffle (uint32_t*) q_weight.data_ptr(), q_perm.device().is_meta() ? 
NULL : (int*) q_perm.data_ptr(), q_weight.size(0) * 8, - q_weight.size(1) + q_weight.size(1), + bit ); } diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/quantization/gptq/qdq_2.cuh new file mode 100644 index 0000000000000..fd78670ea1f10 --- /dev/null +++ b/csrc/quantization/gptq/qdq_2.cuh @@ -0,0 +1,87 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_2_cuh +#define _qdq_2_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +// Permutation: +// +// ffddbb99 77553311 eeccaa88 66442200 + +__forceinline__ __device__ void shuffle_2bit_16 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_2bit_16 +( + const uint32_t q_0, + half2 (&dq)[8], + int stride, + const uint32_t zero +) +{ + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); + + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); +} + +} // namespace gptq +} // namespace vllm + +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh index cfc4635a22c1d..881f353f6564d 100644 --- a/csrc/quantization/gptq/qdq_4.cuh +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8 ( const uint32_t q_0, half2 (&dq)[4], - int stride + int stride, + const uint32_t zero ) { const uint32_t c0 = 0x64006400; const half y16_ = __float2half_rn(1.0f / 16.0f); const half2 y16 = __halves2half2(y16_, y16_); - const half z1_ = __float2half_rn(-1024.0f - 8.0f); - const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); - const half2 z1 = __halves2half2(z1_, z1_); - const half2 z16 = __halves2half2(z16_, z16_); + const half_uint16 
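Note (not part of the patch): dequant_2bit_16 above relies on the same exponent trick as the existing 4-bit path. OR-ing a small field q into the mantissa of the fp16 constant 1024.0 (0x6400) produces exactly 1024 + q, because the ULP at 1024 is 1, and adding the precomputed -(1024 + zero) constant (0xe400 | zero) then leaves q - zero without any integer-to-float conversion. A small host-side check of that identity, as a sketch; the helper below only decodes normal fp16 values:

#include <cassert>
#include <cstdint>
#include <cmath>

// Decode a normal (non-denormal) IEEE fp16 bit pattern to float.
float half_bits_to_float(uint16_t h) {
    int sign = (h >> 15) & 1;
    int exp  = (h >> 10) & 0x1f;
    int man  = h & 0x3ff;
    float v = std::ldexp(1.0f + man / 1024.0f, exp - 15);
    return sign ? -v : v;
}

int main() {
    for (int q = 0; q < 4; ++q)
        assert(half_bits_to_float(static_cast<uint16_t>(0x6400 | q)) == 1024.0f + q);
    // 0xe400 | zero is -(1024 + zero), used directly as the additive constant.
    assert(half_bits_to_float(0xe400 | 2) == -1026.0f);
    return 0;
}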
z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); uint32_t qa = q_0; half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 @@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq } // namespace gptq } // namespace vllm -#else - -namespace vllm { -namespace gptq { -__forceinline__ __device__ void shuffle_4bit_8 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_4bit_8 -( - const uint32_t q_0, - half2 (&dq)[4], - int stride -) -{ - half dqh[8]; - for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); - - for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); -} - -__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale -( - const uint32_t zero, - const half scale, - half2 (&z1)[2], - half2 (&y1)[2] -) -{ - half z = __int2half_rn(-((int)zero)); - z = __hmul(z, scale); - z1[0] = __half2half2(z); - y1[0] = __half2half2(scale); -} - -__forceinline__ __device__ void dequant_4bit_8_prep_zero -( - const uint32_t zero, - half2(&z1)[2], - half2(&y1)[2] -) -{ - half z = __int2half_rn(-((int)zero)); - z1[0] = __half2half2(z); -} - -__forceinline__ __device__ void dequant_4bit_8_gptq -( - const uint32_t q_0, - half2 (&dq)[4], - half2 (&z1)[2], - half2 (&y1)[2], - int stride, - bool scaled -) -{ - half2 dqh2[8]; - - uint32_t qa = q_0; - for (int i = 0; i < 4; i++) - { - half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; - half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; - dqh2[i] = __halves2half2(d0, d1); - } - - if (scaled) - { - dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); - dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); - dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); - dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); - } - else - { - dq[0] = __hadd2(dqh2[0], z1[0]); - dq[1] = __hadd2(dqh2[1], z1[0]); - dq[2] = __hadd2(dqh2[2], z1[0]); - dq[3] = __hadd2(dqh2[3], z1[0]); - } -} - -} // namespace gptq -} // namespace vllm - #endif diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/quantization/gptq/qdq_8.cuh new file mode 100644 index 0000000000000..b062ac3d64334 --- /dev/null +++ b/csrc/quantization/gptq/qdq_8.cuh @@ -0,0 +1,40 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_8_cuh +#define _qdq_8_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +__forceinline__ __device__ void shuffle_8bit_4 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_8bit_8 +( + const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], + int stride, + const uint32_t zero +) +{ + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +} // namespace gptq +} // namespace vllm + +#endif \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 8fe96e7ddb98d..a87e66bcf3006 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -29,9 +29,9 @@ def __init__( self.desc_act = desc_act self.pack_factor = 32 // self.weight_bits # exllama kernel v1 only supports 4 bit - if self.weight_bits != 4: + if self.weight_bits not in [2, 4, 8]: raise ValueError( - "Currently, only 4-bit weight 
quantization is supported for " + "Currently, only 2/4/8-bit weight quantization is supported for " f"GPTQ, but got {self.weight_bits} bits.") def __repr__(self) -> str: @@ -205,11 +205,13 @@ def apply_weights(self, else: weights["g_idx"] = torch.empty((1, 1), device="meta") weights["exllama_state"] = ExllamaState.READY - ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) + ops.gptq_shuffle(weights["qweight"], weights["g_idx"], + self.quant_config.weight_bits) output = ops.gptq_gemm(reshaped_x, weights["qweight"], weights["qzeros"], weights["scales"], weights["g_idx"], - weights["exllama_state"] == ExllamaState.READY) + weights["exllama_state"] == ExllamaState.READY, + self.quant_config.weight_bits) if bias is not None: output = output + bias return output.reshape(out_shape) From 439c67608b74429b78b2a903bbef15dd768e7db4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 25 Dec 2023 10:32:41 +0800 Subject: [PATCH 2/5] Fix minor bug --- csrc/quantization/gptq/q_gemm.cu | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index f964d60278bef..8b09163bace14 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -654,7 +654,7 @@ void gemm_half_q_half_cuda_part gridDim.y = DIVIDE(size_m, m_count); gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); kernel<<>> ( @@ -667,7 +667,7 @@ void gemm_half_q_half_cuda_part size_n, size_k, groups, - b_q_perm, + b_q_perm ); } @@ -978,7 +978,6 @@ __global__ void reconstruct_exllama_2bit_kernel { for (int j = 0; j < 8; j++) { - printf("Debug %d %d %d\n", offset_k, lk, n); for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); @@ -1177,7 +1176,7 @@ __global__ void gemm_half_q_half_alt_2bit_kernel( int g_h = h * 16; int k = 0; int z_w = w / 16; - int z_mod = (w % 16) * 2 + int z_mod = (w % 16) * 2; half2 res2; half res[BLOCK_M_SIZE_MAX] = {}; @@ -1763,7 +1762,7 @@ void gptq_shuffle vllm::gptq::shuffle_exllama_weight( (uint32_t*) q_weight.data_ptr(), q_perm.device().is_meta() ? 
NULL : (int*) q_perm.data_ptr(), - q_weight.size(0) * 8, + q_weight.size(0) * 32 / bit, q_weight.size(1), bit ); From 70f4c3094d4210f5f37854dfaaf7eab58dccae7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 25 Dec 2023 23:04:06 +0800 Subject: [PATCH 3/5] Add 3-bit quant --- csrc/quantization/gptq/matrix_view.cuh | 51 +++ csrc/quantization/gptq/q_gemm.cu | 410 +++++++++++++++++- csrc/quantization/gptq/qdq_3.cuh | 141 ++++++ .../layers/quantization/gptq.py | 10 +- 4 files changed, 603 insertions(+), 9 deletions(-) create mode 100644 csrc/quantization/gptq/qdq_3.cuh diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index e0a8c274e5fc6..eda3436eb5375 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -182,6 +182,57 @@ public: } }; +class MatrixView_q3_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; + } + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); + } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } +}; + class MatrixView_q8_row { public: diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 8b09163bace14..e160024ebe564 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -14,6 +14,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq #include "compat.cuh" #include "matrix_view.cuh" #include "qdq_2.cuh" +#include "qdq_3.cuh" #include "qdq_4.cuh" #include "qdq_8.cuh" @@ -24,6 +25,7 @@ namespace gptq { #define BLOCK_M_SIZE_MAX 8 #define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) #define MAX_Q_GEMM_ROWS 50 +#define MAX_Q_GEMM_ROWS_8BIT 24 #define MAX_ALT_GEMM_ROWS 8 #define THREADS_X 32 #define THREADS_Y 32 @@ -465,6 +467,137 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel } } +template +__global__ void gemm_half_q_half_gptq_3bit_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ 
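Note (not part of the patch): unlike the 2/4/8-bit layouts, 3-bit values do not divide a 32-bit word evenly, so MatrixView_q3_row above spreads every 32 columns across three consecutive uint32 words, and columns 10 and 21 straddle a word boundary, hence the special cases. A host-side sketch of the same extraction, for illustration only:

#include <cstdint>

int unpack_q3(const uint32_t* row_words, int column) {
    const uint32_t* w = row_words + (column / 32) * 3;   // 3 words per 32 columns
    int i = column % 32;
    if (i == 10) return (w[0] >> 30) | ((w[1] << 2) & 0x4);  // straddles words 0/1
    if (i == 21) return (w[1] >> 31) | ((w[2] << 1) & 0x6);  // straddles words 1/2
    if (i < 10)  return (w[0] >> (i * 3)) & 0x07;
    if (i < 21)  return (w[1] >> (i * 3 - 32)) & 0x07;
    return (w[2] >> (i * 3 - 64)) & 0x07;
}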
b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + } + + #pragma unroll + for (int j = 0; j < 1; j++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + template __global__ void gemm_half_q_half_gptq_8bit_kernel ( @@ -599,6 +732,7 @@ 
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( #define SELECT_KERNEL(M_COUNT) \ if (m_count == M_COUNT) { \ if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ } @@ -886,6 +1020,108 @@ __global__ void reconstruct_exllama_4bit_kernel } } +__global__ void reconstruct_exllama_3bit_kernel +( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / 32* 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) + { + int4 load_int4[3]; + load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; + load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); + + if (b_q_perm) + { + for (int j = 0; j < 16; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 16; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + __global__ void reconstruct_exllama_2bit_kernel ( const uint32_t* __restrict__ b_q_weight, @@ -1010,8 +1246,9 @@ void reconstruct_exllama auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; if (bit == 2) { 
reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; - } - if (bit == 8) { + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; } @@ -1398,6 +1635,56 @@ __global__ void reconstruct_gptq_kernel } } +__global__ void reconstruct_gptq_3bit_kernel +( + const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, + const int width, + const int group, + half* __restrict__ out +) +{ + // Start of block + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int i = 0; i < 32; i += 1) + { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); + } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + } +} void reconstruct_gptq ( @@ -1423,6 +1710,9 @@ void reconstruct_gptq kernel = reconstruct_gptq_kernel; } else if (bit == 8) { kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); } kernel<<>> @@ -1457,7 +1747,14 @@ void gemm_half_q_half_cuda int bit ) { - if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) { + bool use_reconstruct; + if (use_exllama) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now. 
+ use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, @@ -1550,6 +1847,20 @@ __global__ void shuffle_2bit_kernel while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } } +__global__ void shuffle_3bit_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } +} + __global__ void make_sequential_4bit_kernel ( const uint32_t* __restrict__ w, @@ -1622,6 +1933,92 @@ __global__ void make_sequential_2bit_kernel w_new2[w_new2_row * w2_stride + w2_column] = dst; } +__global__ void make_sequential_3bit_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width +) +{ + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) return; + int w_new_row = blockIdx.y * 3; + int q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + + #pragma unroll + for (int i = 0; i < 32; i++) + { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10){ + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21){ + src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10){ + if (i != 21){ + z_bit = i; + if (z_bit > 21){ + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10){ + z_bit *= 3; + z_bit -= 32; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); + } + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; +} + __global__ void make_sequential_8bit_kernel ( const uint32_t* __restrict__ w, @@ -1645,7 +2042,7 @@ __global__ void make_sequential_8bit_kernel int source_row = q_perm[q_perm_idx++]; int w2_row = source_row >> 2; - int w2_subrow = source_row & 0xff; + int w2_subrow = source_row & 0x03; int w2_row_shift = w2_subrow << 3; int wnew2_row_shift = i << 3; @@ -1682,6 +2079,9 @@ void shuffle_exllama_weight auto kernel = make_sequential_4bit_kernel; if (bit == 2) { kernel = make_sequential_2bit_kernel; + } else if (bit == 3) { + kernel = make_sequential_3bit_kernel; + gridDim.y = height / 32; } else if (bit == 8) { kernel = make_sequential_8bit_kernel; } @@ -1706,6 +2106,8 @@ void shuffle_exllama_weight auto shuffle_kernel = shuffle_4bit_kernel; if (bit == 2) { shuffle_kernel = shuffle_2bit_kernel; + } else if (bit == 3) { + shuffle_kernel = shuffle_3bit_kernel; } 
else if (bit == 8) { shuffle_kernel = shuffle_8bit_kernel; } diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/quantization/gptq/qdq_3.cuh new file mode 100644 index 0000000000000..e7bdef93cdc95 --- /dev/null +++ b/csrc/quantization/gptq/qdq_3.cuh @@ -0,0 +1,141 @@ +#ifndef _qdq_3_cuh +#define _qdq_3_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { +// Permutation: +// +// v9997775 55333111 u8886664 44222000 (u, v lsb) +// vjjjhhhf ffdddbbb uiiiggge eecccaaa +// vtttrrrp ppnnnlll usssqqqo oommmkkk + +__forceinline__ __device__ void shuffle_3bit_32 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } + for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32 +( + const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], + int stride, + const uint32_t zero +) +{ + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // 
half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[ 0] = __hadd2( q0.as_half2, z1); + dq[ 1] = __hfma2( q1.as_half2, y8, z8); + dq[ 2] = __hadd2( q2.as_half2, z1); + dq[ 3] = __hfma2( q3.as_half2, y8, z8); + dq[ 4] = __hfma2( q4.as_half2, y64, z64); + dq[ 5] = __hadd2( q5.as_half2, z1); + dq[ 6] = __hfma2( q6.as_half2, y8, z8); + dq[ 7] = __hadd2( q7.as_half2, z1); + dq[ 8] = __hfma2( q8.as_half2, y8, z8); + dq[ 9] = __hfma2( q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); +} + +} // namespace gptq +} // namespace vllm + +#endif \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index a87e66bcf3006..119218cc2850c 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -1,6 +1,7 @@ import enum from enum import Enum from typing import Any, Dict, List, Optional +from fractions import Fraction import torch from torch.nn.parameter import Parameter @@ -27,11 +28,10 @@ def __init__( self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act - self.pack_factor = 32 // self.weight_bits - # exllama kernel v1 only supports 4 bit - if self.weight_bits not in [2, 4, 8]: + self.pack_factor = Fraction(32, self.weight_bits) + if self.weight_bits not in [2, 3, 4, 8]: raise ValueError( - "Currently, only 2/4/8-bit weight quantization is supported for " + "Currently, only 2/3/4/8-bit weight quantization is supported for " f"GPTQ, but got {self.weight_bits} bits.") def __repr__(self) -> str: @@ -101,7 +101,7 @@ def create_weights( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") - if output_size_per_partition % self.quant_config.pack_factor != 0: + if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. 
This can be caused by too large " From aa2df9b92692322628d1275788cd3ceda8d6ff2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Wed, 3 Jan 2024 22:43:28 +0800 Subject: [PATCH 4/5] Fix CUDA Graph --- csrc/quantization/gptq/q_gemm.cu | 122 ++---------------- .../squeezellm/quant_cuda_kernel.cu | 3 +- vllm/config.py | 6 - 3 files changed, 15 insertions(+), 116 deletions(-) diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index e160024ebe564..9a7a63b04ad74 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -790,7 +790,8 @@ void gemm_half_q_half_cuda_part fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); - kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>> ( a, b_q_weight, @@ -1252,7 +1253,8 @@ void reconstruct_exllama reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; } - reconstruct_exllama_kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_exllama_kernel<<>> ( b_q_weight, b_q_perm, @@ -1363,106 +1365,6 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( } } -__global__ void gemm_half_q_half_alt_2bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 16; - int vec_height = height * 8; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 16; - int h_end = min(BLOCK_KN_SIZE / 16, height - h) * 8; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } - } - - __shared__ half2 deq2[16][16]; - int val = threadIdx.x / 16; - int off = threadIdx.x % 16; - for (; val < 16; val += BLOCK_KN_SIZE / 16) { - deq2[val][off] = __halves2half2( - __int2half_rn(val & 0x3), __int2half_rn(val >> 2) - ); - } - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); - } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 16; - int k = 0; - int z_w = w / 16; - int z_mod = (w % 16) * 2; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[8]; - half2 zeros_tmp[8]; - for (int tmp_k = 0; tmp_k < 8; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0x03) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0x03) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { -#ifndef USE_ROCM - res2 = {}; -#else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); -#endif - res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - res2 = 
__hfma2(__hfma2(deq2[(tmp >> 4) & 0xf][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xf][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 12) & 0xf][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xf][off], scales_tmp[4], zeros_tmp[4]), blockvec[m][k + 4], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 20) & 0xf][off], scales_tmp[5], zeros_tmp[5]), blockvec[m][k + 5], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xf][off], scales_tmp[6], zeros_tmp[6]), blockvec[m][k + 6], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 28) & 0xf][off], scales_tmp[7], zeros_tmp[7]), blockvec[m][k + 7], res2); -#ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); -#else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); -#endif - } - i += width; - k += 8; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); - } -} __global__ void gemm_half_q_half_alt_8bit_kernel( const half2* __restrict__ vec, @@ -1576,13 +1478,12 @@ void gemm_half_q_half_alt gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); auto kernel = gemm_half_q_half_alt_4bit_kernel; - if (bit == 2) { - kernel = gemm_half_q_half_alt_2bit_kernel; - } else if (bit == 8) { + if (bit == 8) { kernel = gemm_half_q_half_alt_8bit_kernel; } - kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>> ( (const half2*) a, b_q_weight, @@ -1715,7 +1616,8 @@ void reconstruct_gptq gridDim.y = DIVIDE(height, 32); } - kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>> ( b_q_weight, b_gptq_scales, @@ -2085,7 +1987,8 @@ void shuffle_exllama_weight } else if (bit == 8) { kernel = make_sequential_8bit_kernel; } - kernel<<>> + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>> ( q_weight, new_qweight, @@ -2111,7 +2014,8 @@ void shuffle_exllama_weight } else if (bit == 8) { shuffle_kernel = shuffle_8bit_kernel; } - shuffle_kernel<<>>(q_weight, height, width); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + shuffle_kernel<<>>(q_weight, height, width); } } // namespace gptq diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index b17ced6fce79b..6b6492f1cdeda 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -201,7 +201,8 @@ void squeezellm_gemm( ); dim3 threads(BLOCKWIDTH); const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); - vllm::squeezellm::NUQ4MatMulKernel<<>>( + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::squeezellm::NUQ4MatMulKernel<<>>( #ifndef USE_ROCM (half2*) vec.data(), #else diff --git a/vllm/config.py b/vllm/config.py index ff9a1308a5c88..f1efcc66e9097 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -181,12 +181,6 @@ def _verify_cuda_graph(self) -> None: self.max_context_len_to_capture = self.max_model_len self.max_context_len_to_capture = min(self.max_context_len_to_capture, self.max_model_len) - if (self.quantization in ["gptq", "squeezellm"] - and not self.enforce_eager): - # Related issue: https://github.com/vllm-project/vllm/issues/2147 - logger.warning(f"{self.quantization} does not support CUDA graph " - "yet. 
Disabling CUDA graph.") - self.enforce_eager = True def verify_with_parallel_config( self, From 0a94bdda2645716859fa756721b2ed1ef991f361 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Wed, 3 Jan 2024 22:49:13 +0800 Subject: [PATCH 5/5] Fix style --- csrc/quantization/gptq/qdq_2.cuh | 2 +- csrc/quantization/gptq/qdq_3.cuh | 2 +- csrc/quantization/gptq/qdq_8.cuh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/quantization/gptq/qdq_2.cuh index fd78670ea1f10..295872a91de37 100644 --- a/csrc/quantization/gptq/qdq_2.cuh +++ b/csrc/quantization/gptq/qdq_2.cuh @@ -84,4 +84,4 @@ __forceinline__ __device__ void dequant_2bit_16 } // namespace gptq } // namespace vllm -#endif \ No newline at end of file +#endif diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/quantization/gptq/qdq_3.cuh index e7bdef93cdc95..3e7ecde752ba3 100644 --- a/csrc/quantization/gptq/qdq_3.cuh +++ b/csrc/quantization/gptq/qdq_3.cuh @@ -138,4 +138,4 @@ __forceinline__ __device__ void dequant_3bit_32 } // namespace gptq } // namespace vllm -#endif \ No newline at end of file +#endif diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/quantization/gptq/qdq_8.cuh index b062ac3d64334..0c7ad7876140b 100644 --- a/csrc/quantization/gptq/qdq_8.cuh +++ b/csrc/quantization/gptq/qdq_8.cuh @@ -37,4 +37,4 @@ __forceinline__ __device__ void dequant_8bit_8 } // namespace gptq } // namespace vllm -#endif \ No newline at end of file +#endif
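For reference, the packing decoded by MatrixView_q3_row::item() and reconstruct_gptq_3bit_kernel stores 32 three-bit values in three 32-bit words along the packed dimension. Elements 10 and 21 are the only ones that cross a word boundary (10 * 3 = 30 and 21 * 3 = 63), which is why the views, reconstruct_gptq_3bit_kernel and make_sequential_3bit_kernel all special-case those two indices. The host-side sketch below round-trips that layout; pack3 and unpack3 are illustrative helpers written for this note, not functions taken from the kernels.

    // Host-side sketch of the pre-shuffle 3-bit layout (illustrative only).
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Pack 32 three-bit values into three 32-bit words:
    // values 0..9 sit in word 0, value 10 straddles words 0/1 (2 bits + 1 bit),
    // values 11..20 sit in word 1, value 21 straddles words 1/2 (1 bit + 2 bits),
    // values 22..31 sit in word 2.
    static void pack3(const int (&v)[32], uint32_t (&w)[3])
    {
        w[0] = w[1] = w[2] = 0;
        for (int i = 0; i < 32; i++)
        {
            uint32_t q = (uint32_t)v[i] & 0x7;
            if (i < 10)       w[0] |= q << (i * 3);
            else if (i == 10) { w[0] |= (q & 0x3) << 30; w[1] |= (q >> 2) & 0x1; }
            else if (i < 21)  w[1] |= q << (i * 3 - 32);
            else if (i == 21) { w[1] |= (q & 0x1) << 31; w[2] |= (q >> 1) & 0x3; }
            else              w[2] |= q << (i * 3 - 64);
        }
    }

    // Mirrors MatrixView_q3_row::item() for a single 32-element stripe.
    static int unpack3(const uint32_t (&w)[3], int i)
    {
        if (i == 10) return (w[0] >> 30) | ((w[1] << 2) & 0x4);
        if (i == 21) return (w[1] >> 31) | ((w[2] << 1) & 0x6);
        if (i < 10)  return (w[0] >> (i * 3)) & 0x7;
        if (i < 21)  return (w[1] >> (i * 3 - 32)) & 0x7;
        return (w[2] >> (i * 3 - 64)) & 0x7;
    }

    int main()
    {
        int v[32];
        for (int i = 0; i < 32; i++) v[i] = i % 8;   // arbitrary 3-bit test values
        uint32_t w[3];
        pack3(v, w);
        for (int i = 0; i < 32; i++) assert(unpack3(w, i) == v[i]);
        printf("3-bit pack/unpack round-trip OK\n");
        return 0;
    }

The same irregular stride is why gptq.py switches pack_factor to Fraction(32, self.weight_bits): for 3-bit weights the factor is 32/3, so the packed-shape arithmetic stays exact and the output-size alignment check uses pack_factor.numerator rather than assuming an integer number of values per 32-bit word.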