From 31f229c736b8ca88cca68a2fbb65b77e829ed013 Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Fri, 14 Jul 2023 17:34:08 +0200
Subject: [PATCH] larger x tiles

---
 Makefile     |  2 +-
 ggml-cuda.cu | 20 ++++++++++++--------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/Makefile b/Makefile
index 0a4f0640dd3d4b..c53928cefa6213 100644
--- a/Makefile
+++ b/Makefile
@@ -169,7 +169,7 @@ ifdef LLAMA_CUBLAS
     LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
     OBJS      += ggml-cuda.o
     NVCC      = nvcc
-    NVCCFLAGS = --forward-unknown-to-host-compiler
+    NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
 ifdef CUDA_DOCKER_ARCH
     NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
 else
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0e17b2ac26d3c4..1b6a3fa8237df0 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1662,24 +1662,24 @@ static __global__ void mul_mat_q(
     const int tid_x = threadIdx.x;
     const int tid_y = threadIdx.y;
 
-    const int row_dst_0 = blockIdx.x*WARP_SIZE;
+    const int row_dst_0 = 2*blockIdx.x*WARP_SIZE;
     const int & row_x_0 = row_dst_0;
     const int row_dst = row_dst_0 + tid_x;
 
     const int col_dst_0 = blockIdx.y*WARP_SIZE;
     const int & col_y_0 = col_dst_0;
 
-    __shared__ int   tile_x_qs[WARP_SIZE][WARP_SIZE + 1];
-    __shared__ half  tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
+    __shared__ int   tile_x_qs[2*WARP_SIZE][WARP_SIZE + 1];
+    __shared__ half  tile_x_d[2*WARP_SIZE][WARP_SIZE/QI4_0];
     __shared__ int   tile_y_qs[WARP_SIZE][2*WARP_SIZE];
     __shared__ half2 tile_y_ds[WARP_SIZE][2*WARP_SIZE/QI8_1];
 
-    float sum[4] = {0.0f};
+    float sum[2][4] = {0.0f};
 
     for (int ib0 = 0; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
         const int ibx  = tid_x / QI4_0;
         const int iqsx = sizeof(int) * (tid_x % QI4_0);
-        for (int j = 0; j < WARP_SIZE; j += 8) {
+        for (int j = 0; j < 2*WARP_SIZE; j += 8) {
             const block_q4_0 * __restrict__ bx = &x[(row_x_0 + j + tid_y)*blocks_per_row + ib0 + ibx];
             memcpy(&tile_x_qs[j + tid_y][tid_x], &bx->qs[iqsx], sizeof(int));
             tile_x_d[j + tid_y][ibx] = bx->d;
@@ -1706,9 +1706,12 @@ static __global__ void mul_mat_q(
         for (int k = 0; k < WARP_SIZE; ++k) {
             const int iqsy = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
             for (int j = 0; j < WARP_SIZE; j += 8) {
-                sum[j/8] += vec_dot_q4_0_q8_1_impl(
+                sum[0][j/8] += vec_dot_q4_0_q8_1_impl(
                     tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
                     tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
+                sum[1][j/8] += vec_dot_q4_0_q8_1_impl(
+                    tile_x_qs[tid_x + WARP_SIZE][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
+                    tile_x_d[tid_x + WARP_SIZE][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
             }
         }
 
@@ -1727,7 +1730,8 @@ static __global__ void mul_mat_q(
             return;
         }
 
-        dst[col_dst*nrows_dst + row_dst] = sum[j/8];
+        dst[col_dst*nrows_dst + row_dst] = sum[0][j/8];
+        dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum[1][j/8];
     }
 }
 
@@ -2417,7 +2421,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }
 
 static void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
-    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
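
Note (not part of the patch, for review context): with this change each thread block computes a 2*WARP_SIZE x WARP_SIZE tile of the output instead of WARP_SIZE x WARP_SIZE, so the x tiles in shared memory double, each thread accumulates into sum[2][4], and the launch grid is halved along x in ggml_mul_mat_q4_0_q8_1_cuda. The standalone sketch below only illustrates the grid-size arithmetic; the matrix dimensions are made-up example values, not taken from the patch.

#include <cstdio>

#define WARP_SIZE 32

// Ceiling division, as used implicitly by the (a + b - 1) / b expressions in the patch.
static int ceil_div(const int a, const int b) {
    return (a + b - 1) / b;
}

int main() {
    const int nrows_x = 4096; // hypothetical example dimensions
    const int ncols_y = 512;

    // Before the patch: one block per WARP_SIZE output rows.
    const int block_num_x_old = ceil_div(nrows_x, WARP_SIZE);
    // After the patch: one block per 2*WARP_SIZE output rows, halving the grid in x.
    const int block_num_x_new = ceil_div(nrows_x, 2*WARP_SIZE);
    const int block_num_y     = ceil_div(ncols_y, WARP_SIZE);

    printf("grid x: %d -> %d, grid y: %d\n", block_num_x_old, block_num_x_new, block_num_y);
    return 0;
}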