Q4_0_R4 on CUDA #127

Draft · wants to merge 3 commits into base: main
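This draft adds CUDA support for the Q4_0_R4 interleaved quantization type: dequantization kernels in convert.cu, a dedicated matrix-vector multiplication kernel in iqk_mmvq.cu, and the dispatch plumbing in mmvq.cu and ggml-cuda.cu.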
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda.cu
@@ -2853,6 +2853,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_Q4_0_R4:
return true;
default:
return false;
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -557,6 +557,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qi = QI3_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0_R4> {
static constexpr int qk = QK4_0;
static constexpr int qr = QR4_0;
static constexpr int qi = QI4_0;
};

//////////////////////

struct ggml_cuda_device_info {
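For reference, the Q4_0 constants these traits reuse come from ggml's common header; as defined in upstream ggml they are:

// Q4_0 block constants from ggml-common.h (upstream ggml, shown for context):
#define QK4_0 32                    // weights per Q4_0 block
#define QR4_0 2                     // two 4-bit quants are packed per byte
#define QI4_0 (QK4_0 / (4*QR4_0))   // = 4: 32-bit ints of quant data per block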
40 changes: 40 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -74,6 +74,35 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
#endif // __CUDA_ARCH__ >= CC_PASCAL
}

template<typename dst_t>
static __global__ void dequantize_block_q4_0_r4(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row) {

const int64_t ii = blockIdx.x;
    int row4 = (256*ii)/(4*n_per_row);        // group of 4 interleaved rows this block works on
    const int64_t i = ii - row4*n_per_row/64; // 256-value chunk index within that group

// assume 32 threads
const int tid = threadIdx.x;
int is = tid/16; // 0 or 1: 1st or 2nd block of 128
int j = tid%16; // 0...15: index inside the block of 128
int l = j/4; // 0....3: index inside a q4_0 block
int k = j%4; // 0....3: row index in the group of 4 rows
int ll = 16*(l%2) + 4*(l/2);

dst_t * y = yy + (4*row4 + k)*n_per_row + 32*(2*i+is) + ll;

    // each block_iq4_nl_r4 packs one 32-weight Q4_0 block from each of 4 interleaved rows;
    // one thread block covers 256 values, i.e. two such interleaved blocks (selected by is)
    const block_iq4_nl_r4 * x = (const block_iq4_nl_r4 *)vx + 2*ii + is;
    const float d = __half2float(x->d[k]); // scale of row k
    const float dm = -8*d;                 // Q4_0 offset: value = d*q - 8*d

    const uint8_t * q = x->qs + 16*l + 4*k;

for (int n = 0; n < 4; ++n) {
y[n+0] = d * (q[n] & 0xF) + dm;
y[n+8] = d * (q[n] >> 4) + dm;
}
}
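The index math above assumes the 4-row interleaved layout of block_iq4_nl_r4. A sketch of that layout, matching my reading of ggml-common.h in this repository (field sizes are an assumption worth double-checking):

// Assumed layout of one interleaved block (block_iq4_nl_r4 in ggml-common.h):
// one 32-weight Q4_0 block from each of 4 consecutive rows, packed together.
struct block_q4_0_r4_sketch {
    ggml_half d[4];  // per-row scales (x->d[k] above)
    uint8_t  qs[64]; // 4 rows x 32 weights x 4 bits, interleaved 4 bytes per row
};

With 32 threads and 8 output values per thread, one thread block dequantizes 256 weights, i.e. two such 128-value interleaved blocks, which is what the 2*ii + is addressing selects.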

template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

@@ -818,6 +847,13 @@ static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t n
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_row_q4_0_r4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = (k + 255) / 256;
dequantize_block_q4_0_r4<<<nb, 32, 0, stream>>>(vx, y, n_per_row);
}
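The launch math implies n_per_row is a multiple of QK4_0 and nrows a multiple of 4; a hypothetical guard (not part of the PR) would be:

// Hypothetical preconditions of the r4 dequantize path (not enforced in the PR):
static void check_q4_0_r4_dims(int64_t nrows, int64_t n_per_row) {
    GGML_ASSERT(n_per_row % QK4_0 == 0); // whole Q4_0 blocks per row
    GGML_ASSERT(nrows % 4 == 0);         // rows are interleaved in groups of 4
}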

template<typename dst_t>
static void dequantize_row_q6_0_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1073,6 +1109,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_0_R4:
return dequantize_row_q4_0_r4_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
@@ -1147,6 +1185,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_0_R4:
return dequantize_row_q4_0_r4_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
198 changes: 186 additions & 12 deletions ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -14,7 +14,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
// constexpr int vdr = get_vdr_mmvq(type);

namespace {
-template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
+template <int qk, int qi, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
@@ -23,9 +23,6 @@ __global__ void iqk_mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {

-    constexpr int qk = ggml_cuda_type_traits<type>::qk;
-    constexpr int qi = ggml_cuda_type_traits<type>::qi;
-
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
@@ -137,37 +134,208 @@ void iqk_mul_mat_vec_q_cuda(

const int64_t row_size = ggml_row_size(type, ncols_x);

+    constexpr int qk = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi = ggml_cuda_type_traits<type>::qi;

    switch (ncols_y) {
        case 1:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
        case 2:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
        case 3:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
        case 4:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
        case 5:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
        case 6:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
        case 7:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
        case 8:
-           iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
+           iqk_mul_mat_vec_q<qk, qi, vdr, vec_dot_q_cuda, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
            break;
default:
GGML_ASSERT(false);
break;
}
}

// Q4_0_R4 reuses the 4-row interleaved block layout of IQ4_NL_R4
// (same size and field layout; only the nibble -> value mapping differs)
using block_q4_0_r4 = block_iq4_nl_r4;

__device__ __forceinline__ void vec_dot_q4_0_r4_q8_1_x(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ y, const int & kbx, const int & l, float * __restrict__ result) {

    const block_q4_0_r4 * x = (const block_q4_0_r4 *)vbq + kbx;
    const half2 * d4h = (const half2 *)x->d; // the 4 per-row scales, read as two half2
    float2 d4[2];
    const float * d = (const float *)d4;     // view the same 4 scales as float[4]
    d4[0] = __half22float2(d4h[0]);
    d4[1] = __half22float2(d4h[1]);
    const float2 d8 = __half22float2(y->ds); // (d8, d8*sum(q8)) of the q8_1 block

    const int * q8 = (const int *)y->qs + 4*(l%2) + l/2; // the two q8 ints this slice pairs with (q8[0], q8[2])
    const int * q4 = (const int *)x->qs + 4*l;           // 4 ints = 8 nibbles from each of the 4 rows

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    for (int k = 0; k < 4; ++k) {
        int v1 = q4[k] & 0x0f0f0f0f;        // 4 low nibbles: weights of row k
        int v2 = (q4[k] >> 4) & 0x0f0f0f0f; // 4 high nibbles of the same row
        int dot = __dp4a(v1, q8[0], __dp4a(v2, q8[2], 0));
        // the -2.f*d8.y term charges this slice one quarter of the Q4_0 offset
        // contribution -8*d8.y; summed over the four l-slices it is exact
        result[k] += d[k]*(d8.x*dot - 2.f*d8.y);
    }
#else
NO_DEVICE_CODE;
#endif
}
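To unpack the arithmetic: a q8_1 block stores ds = (d8, d8*sum(q8)), and a Q4_0 weight dequantizes to d*(q - 8). Each l-slice handles 8 of the 32 products per row and pays one quarter of the offset term, so only the sum over the four slices is exact. A scalar sketch of the per-row quantity those four slices accumulate (hypothetical helper; the nibble placement is simplified relative to the real interleaved byte order):

// Scalar reference for one row of the 4-row group against one q8_1 block.
static float q4_0_row_dot_ref(float d4, const uint8_t * nib /* 16 bytes */,
                              float d8, const int8_t * q8 /* 32 values */) {
    int dot = 0, sum_q8 = 0;
    for (int j = 0; j < 16; ++j) {
        dot += (nib[j] & 0xF) * q8[j];      // low nibbles (placement simplified)
        dot += (nib[j] >> 4)  * q8[j + 16]; // high nibbles
    }
    for (int j = 0; j < 32; ++j) sum_q8 += q8[j];
    return d4 * (d8 * dot - 8.0f * d8 * sum_q8); // -8 is the Q4_0 offset
}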


template <int ncols_y>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__global__ void iqk_mul_mat_vec_q4_0_r4(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))

const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row0 = 4*blockIdx.x;
const int blocks_per_row_x = ncols_x / 32;
const int blocks_per_col_y = nrows_y / 32;
constexpr int blocks_per_iter = nwarps*WARP_SIZE/4;

// partial sum for each thread
float tmp[ncols_y][4] = {0.0f};

const block_q8_1 * y = (const block_q8_1 *) vy;

for (int kbx = tid/4; kbx < blocks_per_row_x; kbx += blocks_per_iter) {

#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
vec_dot_q4_0_r4_q8_1_x((const void *)((const char *)vx + row0*row_size),
&y[j*blocks_per_col_y + kbx], kbx, tid%4, tmp[j]);
}
}

__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][4][WARP_SIZE];
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < 4; ++i) {
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
}
}
}
__syncthreads();
if (threadIdx.y > 0) {
return;
}

// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < 4; ++i) {
#pragma unroll
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
}
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
}

if (threadIdx.x < 4 && (row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
}
}
}
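The reduction relies on ggml-cuda's warp_reduce_sum helper from common.cuh, which, paraphrasing upstream, is a butterfly shuffle reduction:

// warp_reduce_sum, paraphrased from ggml-cuda/common.cuh:
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}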

static void iqk_mul_mat_vec_q4_0_r4_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

GGML_ASSERT(ncols_x % 32 == 0);
GGML_ASSERT(nrows_x % 4 == 0);

int id = ggml_cuda_get_device();

int nwarps = 1;

if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
case 2:
case 3:
case 4:
nwarps = 4;
break;
case 5:
case 6:
case 7:
case 8:
nwarps = 2;
break;
default:
GGML_ASSERT(false);
break;
}
}
const int64_t nblocks = nrows_x/4;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);

const int64_t row_size = ggml_row_size(GGML_TYPE_Q4_0_R4, ncols_x);

switch (ncols_y) {
case 1:
iqk_mul_mat_vec_q4_0_r4<1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 2:
iqk_mul_mat_vec_q4_0_r4<2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 3:
iqk_mul_mat_vec_q4_0_r4<3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 4:
iqk_mul_mat_vec_q4_0_r4<4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 5:
iqk_mul_mat_vec_q4_0_r4<5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 6:
iqk_mul_mat_vec_q4_0_r4<6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 7:
iqk_mul_mat_vec_q4_0_r4<7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
case 8:
iqk_mul_mat_vec_q4_0_r4<8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
break;
default:
GGML_ASSERT(false);
break;
}
}
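Grid mapping: one thread block per group of 4 interleaved rows (nblocks = nrows_x/4), each producing 4 rows of dst; the 4-vs-2 nwarps split by ncols_y mirrors the heuristic of the generic iqk_mul_mat_vec_q launcher above.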


__device__ __forceinline__ void get_int_from_table_16_shift(const uint32_t & q4, uint16_t shift, const uint8_t * all_values,
int & val1, int & val2) {

@@ -728,6 +896,12 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(

} // namespace

void mul_mat_vec_q4_0_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
iqk_mul_mat_vec_q4_0_r4_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

void mul_mat_vec_iq2_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -39,3 +39,7 @@ void mul_mat_vec_iq1_bn_q8_1_cuda(
void mul_mat_vec_iq2_bn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

void mul_mat_vec_q4_0_r4_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
@@ -455,6 +455,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q4_0_R4:
mul_mat_vec_q4_0_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
default:
GGML_ABORT("fatal error");
break;