Adding IQ6_K #14

Merged · 10 commits · Aug 9, 2024
Changes from all commits
1 change: 1 addition & 0 deletions examples/quantize/quantize.cpp
@@ -45,6 +45,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },
{ "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", },
{ "IQ5_K", LLAMA_FTYPE_MOSTLY_IQ5_K, " 5.5 bpw non-linear quantization", },
{ "IQ6_K", LLAMA_FTYPE_MOSTLY_IQ6_K, " 6.6 bpw non-linear quantization", },
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
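With this table entry in place, the new type can be requested by name when quantizing, e.g. llama-quantize model-f16.gguf model-iq6_k.gguf IQ6_K (assuming the usual quantize CLI of this tree; the paths are placeholders). The "6.6 bpw" figure follows from the block layout added in ggml-common.h below.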
6 changes: 4 additions & 2 deletions ggml/include/ggml.h
@@ -393,7 +393,8 @@ extern "C" {
GGML_TYPE_IQ3_K = 38,
GGML_TYPE_IQ4_K = 39,
GGML_TYPE_IQ5_K = 40,
GGML_TYPE_IQ2_TN = 41,
GGML_TYPE_IQ6_K = 41,
GGML_TYPE_IQ2_TN = 42,
GGML_TYPE_COUNT,
};

@@ -444,7 +445,8 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ3_K = 31, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_K = 32, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ5_K = 33, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_TN = 34, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ6_K = 34, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_TN = 35, // except 1d tensors
};

// available tensor operations:
22 changes: 22 additions & 0 deletions ggml/src/ggml-common.h
@@ -142,6 +142,9 @@ typedef sycl::half2 ggml_half2;
#define QI5_XS (QK_K / (4*QR5_XS))
#define QR5_XS 2

#define QI6_XS (QK_K / (4*QR6_XS))
#define QR6_XS 2

#define QI3_S (QK_K / (4*QR3_S))
#define QR3_S 4

@@ -493,6 +496,15 @@ typedef struct {
} block_iq5_k;
static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding");

typedef struct {
ggml_half d;
uint16_t extra;
int8_t scales[QK_K/16];
uint8_t qs[QK_K/2];
uint8_t qh[QK_K/4];
} block_iq6_k;
static_assert(sizeof(block_iq6_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + QK_K/4 + QK_K/16, "wrong iq6_k block size/padding");


#endif // GGML_COMMON_DECL
#endif // GGML_COMMON_DECL
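As a sanity check on the layout (an aside, assuming QK_K == 256 and a 16-bit ggml_half): the fields add up to 2 + 2 + 16 + 128 + 64 = 212 bytes per 256 weights, i.e. 212*8/256 = 6.625 bits per weight, which is the "6.6 bpw" advertised in quantize.cpp:

enum { QK = 256 };  // QK_K
enum {
    IQ6K_BLOCK_BYTES = 2       // d: fp16 super-block scale
                     + 2       // extra: one shift bit per 16-weight row
                     + QK/16   // scales: 16 signed 8-bit row scales
                     + QK/2    // qs: low 4 bits of each 6-bit quant
                     + QK/4    // qh: high 2 bits of each 6-bit quant
};
_Static_assert(IQ6K_BLOCK_BYTES == 212, "212*8/256 = 6.625 bpw");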
@@ -1944,6 +1956,16 @@ GGML_TABLE_BEGIN(int8_t, iq5nl_values, 64)
-124, -112, -101, -90, -81, -72, -63, -55, -48, -41, -34, -28, -22, -16, -10, -4, 1, 7, 13, 19, 25, 31, 38, 45, 53, 61, 70, 79, 89, 99, 111, 123,
GGML_TABLE_END()

GGML_TABLE_BEGIN(int8_t, iq6nl_values, 128)
-127, -121, -115, -109, -104, -98, -93, -88, -84, -79, -74, -70, -66, -62, -58, -54,
-51, -47, -44, -40, -37, -34, -31, -28, -25, -22, -19, -16, -13, -11, -8, -5,
-2, 0, 3, 6, 9, 12, 14, 17, 20, 23, 27, 30, 33, 36, 40, 44,
47, 51, 55, 59, 63, 68, 72, 77, 82, 87, 92, 98, 103, 109, 115, 121,
-126, -120, -114, -108, -103, -97, -92, -87, -83, -78, -73, -69, -65, -61, -57, -53,
-50, -46, -43, -39, -36, -33, -30, -27, -24, -21, -18, -15, -12, -10, -7, -4,
-1, 1, 4, 7, 10, 13, 15, 18, 21, 24, 28, 31, 34, 37, 41, 45,
48, 52, 56, 60, 64, 69, 73, 78, 83, 88, 93, 99, 104, 110, 116, 122,
GGML_TABLE_END()

#endif // GGML_COMMON_IMPL
#endif // GGML_COMMON_IMPL
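The 128 entries are two 64-entry copies of the same non-linear grid offset by one: iq6nl_values[q + 64] == iq6nl_values[q] + 1. The bits of a block's extra field select, per 16-weight row, whether the shifted copy is used. A scalar reference dequantizer for one super-block, transcribed from the dequantize_block_iq6_k CUDA kernel added in convert.cu below (a sketch assuming QK_K == 256; the fp16 conversion of d is elided):

#include <stdint.h>

static void dequantize_iq6_k_ref(const block_iq6_k * x, float * y) {
    const float d = (float) x->d;                  // stand-in for GGML_FP16_TO_FP32
    for (int ib64 = 0; ib64 < 4; ++ib64) {         // four 64-weight sub-blocks
        const uint8_t * qs = x->qs + 32*ib64;      // low nibbles: rows 0/2 and 1/3
        const uint8_t * qh = x->qh + 32*(ib64/2);  // high 2 bits, nibble-packed
        const int     hshift = 4*(ib64%2);
        const uint8_t extra  = x->extra >> 4*ib64; // one shift bit per 16-weight row
        for (int l = 0; l < 16; ++l) {
            const uint8_t h1 = qh[l+ 0] >> hshift, h2 = qh[l+16] >> hshift;
            const uint8_t q[4] = {
                (uint8_t)((qs[l+ 0] & 0xf) | ((h1 & 0x03) << 4)),   // row 0
                (uint8_t)((qs[l+16] & 0xf) | ((h2 & 0x03) << 4)),   // row 1
                (uint8_t)((qs[l+ 0] >>  4) | ((h1 & 0x0c) << 2)),   // row 2
                (uint8_t)((qs[l+16] >>  4) | ((h2 & 0x0c) << 2)),   // row 3
            };
            for (int row = 0; row < 4; ++row) {
                // iq6nl_values[v + 64] == iq6nl_values[v] + 1, so the extra
                // bit simply indexes the table's second half.
                const int v = iq6nl_values[q[row] + 64*((extra >> row) & 1)];
                y[64*ib64 + 16*row + l] = d * x->scales[4*ib64 + row] * v;
            }
        }
    }
}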
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda.cu
@@ -2757,6 +2757,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_TN:
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -704,6 +704,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
static constexpr int qi = QI5_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ6_K> {
static constexpr int qk = QK_K;
static constexpr int qr = QR6_XS;
static constexpr int qi = QI6_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
static constexpr int qk = QK_K;
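These trait values mirror the QI6_XS/QR6_XS macros added in ggml-common.h: with QK_K = 256, qr = QR6_XS = 2 is the two quants packed per qs byte, and qi = QI6_XS = QK_K/(4*QR6_XS) = 32 matches the 128-byte qs array reinterpreted as 32 uint32 values, presumably the granularity at which the vec-dot below walks a block.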
41 changes: 41 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -591,6 +591,37 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {

const int i = blockIdx.x;
const block_iq6_k * x = (const block_iq6_k *) vx;

const int tid = threadIdx.x;
int ib64 = tid/8; // 0...3
int il = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 64*ib64 + 2*il;
const float d = (float)x[i].d;
const float dl1 = d * x[i].scales[4*ib64+0];
const float dl2 = d * x[i].scales[4*ib64+1];
const float dl3 = d * x[i].scales[4*ib64+2];
const float dl4 = d * x[i].scales[4*ib64+3];
const uint8_t * qs = x[i].qs + 32*ib64 + 2*il;
const uint8_t * qh = x[i].qh + 32*(ib64/2) + 2*il;
const uint8_t extra = x[i].extra >> 4*(ib64%4);
for (int j = 0; j < 2; ++j) {
const uint8_t h1 = qh[j] >> 4*(ib64%2), h2 = qh[j+16] >> 4*(ib64%2);
uint8_t q1 = (qs[j+ 0] & 0xf) | ((h1 & 0x03) << 4);
uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4);
uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2);
uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2);
y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0));
y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0));
y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0));
y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0));
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {

@@ -803,6 +834,12 @@ static void dequantize_row_iq5_k_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_iq5_k<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq6_k_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq6_k<<<nb, 32, 0, stream>>>(vx, y);
}

template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@@ -877,6 +914,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq4_k_cuda;
case GGML_TYPE_IQ5_K:
return dequantize_row_iq5_k_cuda;
case GGML_TYPE_IQ6_K:
return dequantize_row_iq6_k_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F32:
@@ -938,6 +977,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq4_k_cuda;
case GGML_TYPE_IQ5_K:
return dequantize_row_iq5_k_cuda;
case GGML_TYPE_IQ6_K:
return dequantize_row_iq6_k_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
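With both converter tables extended, host code reaches the new kernel through the usual getter. A hypothetical usage sketch (dq, dy, k, and stream are illustrative names, not from this PR):

// dq: device pointer to IQ6_K blocks; dy: device buffer of k floats,
// with k a multiple of QK_K.
to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_IQ6_K);
to_fp32(dq, dy, k, stream);  // launches k/QK_K blocks of 32 threads; each
                             // thread reconstructs 8 of a block's 256 weights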
45 changes: 44 additions & 1 deletion ggml/src/ggml-cuda/iqk_mmvq.cu
@@ -251,6 +251,43 @@ __device__ __forceinline__ float vec_dot_iq5_k_q8_1(
return d5 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * ls1 + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * ls2);
}

#define VDR_IQ6_K_Q8_1_MMVQ 4
#define VDR_IQ6_K_Q8_1_MMQ 4

__device__ __forceinline__ float vec_dot_iq6_k_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {


const block_iq6_k * bq6 = (const block_iq6_k *) vbq + kbx;
const uint8_t * all_values = (const uint8_t *)iq6nl_values;

int i4 = iqs/4; // 0...7. Blocks of 16 index is 4*(i4/2) + (i4%2) + (0 and 2)
// Blocks of 32 index is 2*(i4/2) + 0 or 1

const int32_t * q8_1 = (const int *)bq8_1[2*(i4/2)+0].qs + 4*(i4%2);
const int32_t * q8_2 = (const int *)bq8_1[2*(i4/2)+1].qs + 4*(i4%2);
const uint32_t * q4 = (const uint32_t *)bq6->qs + 8*(i4/2) + 4*(i4%2);
const uint32_t * qh = (const uint32_t *)bq6->qh + 8*(i4/4) + 4*(i4%2);
const uint16_t extra = bq6->extra >> (4*(i4/2) + (i4%2));
const uint8_t * values1 = all_values + 64*(extra & 1);
const uint8_t * values2 = all_values + 16*(extra & 4);
uint32_t aux32[2];
const uint8_t * a8 = (const uint8_t *)aux32;
int v1, v2;
int sumi1 = 0, sumi2 = 0;
for (int j = 0; j < 4; ++j) {
uint32_t h = qh[j] >> 4*((i4/2)%2);
aux32[0] = ((q4[j] >> 0) & 0x0f0f0f0f) | ((h << 4) & 0x30303030);
aux32[1] = ((q4[j] >> 4) & 0x0f0f0f0f) | ((h << 2) & 0x30303030);
v1 = int_from_table(a8+0, values1);
v2 = int_from_table(a8+4, values2);
sumi1 = ggml_cuda_dp4a(v1, q8_1[j], sumi1);
sumi2 = ggml_cuda_dp4a(v2, q8_2[j], sumi2);
}
const float d6 = __half2float(bq6->d);
return d6 * (__low2float(bq8_1[2*(i4/2)+0].ds) * sumi1 * bq6->scales[4*(i4/2)+(i4%2)] + __low2float(bq8_1[2*(i4/2)+1].ds) * sumi2 * bq6->scales[4*(i4/2)+(i4%2)+2]);
}
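For orientation (an annotation, not code from this PR): each call handles two 16-weight rows of one 64-weight sub-block; int_from_table, defined elsewhere in this file, packs four table lookups into a 32-bit word so that every ggml_cuda_dp4a call accumulates four int8 products against the Q8_1 activations in one instruction. One such step is equivalent to this scalar sketch:

#include <stdint.h>

// Scalar equivalent of ggml_cuda_dp4a(v, q, acc): a 4-way dot product of
// signed bytes with accumulation.
static int dp4a_ref(int v, int q, int acc) {
    const int8_t * a = (const int8_t *) &v;
    const int8_t * b = (const int8_t *) &q;
    return acc + a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3];
}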

static const __device__ uint32_t iq2k_table[512] = {
0xe1e1e1e1, 0xe1e1e1f3, 0xe1e1e101, 0xe1e1e111, 0xe1e1f3e1, 0xe1e1f3f3, 0xe1e1f301, 0xe1e1f311,
0xe1e101e1, 0xe1e101f3, 0xe1e10101, 0xe1e10111, 0xe1e111e1, 0xe1e111f3, 0xe1e11101, 0xe1e11111,
@@ -534,10 +571,16 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ5_K, VDR_IQ5_K_Q8_1_MMVQ, vec_dot_iq5_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

void mul_mat_vec_iq6_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ6_K, VDR_IQ6_K_Q8_1_MMVQ, vec_dot_iq6_k_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

void mul_mat_vec_iq2_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ2_TN, VDR_IQ2_TN_Q8_1_MMVQ, vec_dot_iq2_tn_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/iqk_mmvq.cuh
@@ -16,6 +16,10 @@ void mul_mat_vec_iq5_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

void mul_mat_vec_iq6_k_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

void mul_mat_vec_iq2_tn_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/mmvq.cu
@@ -447,6 +447,9 @@ void ggml_cuda_op_mul_mat_vec_q(
case GGML_TYPE_IQ5_K:
mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ6_K:
mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;