From 50732c06981ccbce208e46e45ddc290f37ec8775 Mon Sep 17 00:00:00 2001 From: Luca Date: Wed, 26 Jun 2024 15:32:52 +0200 Subject: [PATCH] Add 128 byte read and 128 threads per block --- ggml-cuda/dmmv.cu | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 174489e0665d3..f49a771d60241 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -217,10 +217,21 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, float4 s = {0.f, 0.f, 0.f, 0.f}; float smin = 0; + + float4 y11 = *reinterpret_cast(y1+0); + float4 y12 = *reinterpret_cast(y1+32); + float4 y21 = *reinterpret_cast(y2+0); + float4 y22 = *reinterpret_cast(y2+32); + + const float* p11 = &y11.x; + const float* p12 = &y12.x; + const float* p21 = &y21.x; + const float* p22 = &y22.x; + for (int l = 0; l < 4; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; - s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + s.x += p11[l] * q4[l+0]; s.y += p12[l] * q4[l+ 4]; + s.z += p21[l] * q4[l+8]; s.w += p22[l] * q4[l+12]; + smin += p11[l] * sc[2] + p12[l] * sc[3] + p21[l] * sc[6] + p22[l] * sc[7]; } tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; #else @@ -563,12 +574,15 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); } +#define BLOCK_DIM_X 32 +#define BLOCK_DIM_Y 4 static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; + const int ny = 2*BLOCK_DIM_Y / K_QUANTS_PER_ITERATION; + constexpr int grid_scale = BLOCK_DIM_X/32; const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); + const dim3 block_nums((block_num_y+grid_scale-1)/grid_scale, 1, 1); + const dim3 block_dims(BLOCK_DIM_X, ny, 1); dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); }