From cf0a50519b2aebe3984db72dac43e5cee28a8eb5 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sun, 16 Jul 2023 11:16:15 +0200 Subject: [PATCH] q8_0 works --- ggml-cuda.cu | 78 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 19 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c75b5f36b74e4e..b22db4929844fd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1636,6 +1636,37 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( return vec_dot_q8_0_q8_1_impl(vi, ui, bq8_0->d, bq8_1->ds); } +static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) { + + __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)]; + __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI8_0)]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_d; +} + +static __device__ __forceinline__ void load_tiles_q8_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) { + + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + + const block_q8_0 * bx = ((block_q8_0 *) vx) + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bx->qs, kqsx); + x_dm[i * (WARP_SIZE / QI8_0) + kbx].x = bx->d; +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + return vec_dot_q8_0_q8_1_impl( + x_ql[i * (WARP_SIZE + 1) + k], y_qs[j*WARP_SIZE + k], + x_dm[i * (WARP_SIZE/QI8_0) + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]); +} + static __device__ __forceinline__ float vec_dot_q2_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { 
@@ -1849,7 +1880,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -template <int qk, int qi, typename block_q_t, allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, vec_dot_q_mul_mat_cuda_t vec_dot> +template <int qk, int qr, int qi, typename block_q_t, allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, vec_dot_q_mul_mat_cuda_t vec_dot> static __global__ void mul_mat_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, @@ -1880,8 +1911,8 @@ static __global__ void mul_mat_q( allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); - __shared__ int tile_y_qs[(WARP_SIZE) * (2*WARP_SIZE)]; - __shared__ half2 tile_y_ds[(WARP_SIZE) * (2*WARP_SIZE/QI8_1)]; + __shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)]; + __shared__ half2 tile_y_ds[(WARP_SIZE) * (qr*WARP_SIZE/QI8_1)]; float sum[2][4] = {0.0f}; @@ -1892,22 +1923,20 @@ static __global__ void mul_mat_q( i + tid_y, tid_x, blocks_per_row); } - const int iby0 = tid_x / QI8_1; - const int iby1 = iby0 + WARP_SIZE / QI8_1; const int iqsy = sizeof(int) * (tid_x % QI8_1); - for (int i = 0; i < WARP_SIZE; i += 8) { - const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir*WARP_SIZE + tid_x; + const int kby = kqs / QI8_1; - const block_q8_1 * __restrict__ by0 = &y[col_y_eff*blocks_per_row + ib0 + iby0]; + for (int i = 0; i < WARP_SIZE; i += 8) { + const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses - tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x] = *((int *) &by0->qs[iqsy]); - tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby0] = by0->ds; + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_row + ib0 + kby]; - const block_q8_1 * __restrict__ by1 = &y[col_y_eff*blocks_per_row + ib0 + iby1]; - - tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x + WARP_SIZE] = *((int *) &by1->qs[iqsy]); - tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby1] = by1->ds; + tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = *((int *) &by0->qs[iqsy]); + tile_y_ds[(tid_y + i) * (qr*WARP_SIZE/QI8_1) + kby] = by0->ds; + } } __syncthreads(); @@ -2633,7 +2662,7 @@ static 
void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); - mul_mat_q<QK4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, vec_dot_q4_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); + mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, vec_dot_q4_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); } static void ggml_mul_mat_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ @@ -2641,7 +2670,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(const void * vx, const void * vy, float const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); - mul_mat_q<QK4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); + mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); } static void ggml_mul_mat_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ @@ -2649,7 +2678,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(const void * vx, const void * vy, float const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); - mul_mat_q<QK5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); + mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); } static void ggml_mul_mat_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ @@ -2657,7 +2686,15 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(const void * vx, const void * vy, float const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; const dim3 block_nums(block_num_x, block_num_y, 1); const dim3 block_dims(WARP_SIZE, 
WARP_SIZE/4, 1); - mul_mat_q<QK5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, vec_dot_q5_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); + mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, vec_dot_q5_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); +} + +static void ggml_mul_mat_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){ + const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1); + mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0, vec_dot_q8_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst); } static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { @@ -3123,6 +3160,9 @@ inline void ggml_cuda_op_mul_mat_q( case GGML_TYPE_Q5_1: ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); break; + case GGML_TYPE_Q8_0: + ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main); + break; default: GGML_ASSERT(false); break; @@ -3873,7 +3913,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); } else { if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || - src0->type == GGML_TYPE_Q5_1) { + src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); } else { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);