ggml-cuda : add TQ2_0 kernels, for ternary inference on GPU #11183
base: master
Changes from 4 commits
970b5ab
fb43d5e
983aa09
f5fddb6
946796f
b6fc9f0
fbddb26
@@ -63,6 +63,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
        case GGML_TYPE_Q5_K:
            return MMQ_Q8_1_DS_LAYOUT_DS4;
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_TQ2_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
@@ -161,6 +162,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
        type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
        type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
        type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
        type == GGML_TYPE_TQ2_0   ? MMQ_DP4A_TXS_Q8_0 :
        type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
        type == GGML_TYPE_IQ2_XS  ? MMQ_DP4A_TXS_Q8_0_16 :
        type == GGML_TYPE_IQ2_S   ? MMQ_DP4A_TXS_Q8_0_16 :
@@ -195,6 +197,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
        type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
        type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
        type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
        type == GGML_TYPE_TQ2_0   ? MMQ_MMA_TILE_X_K_Q8_0 :
        type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
        type == GGML_TYPE_IQ2_XS  ? MMQ_MMA_TILE_X_K_Q3_K :
        type == GGML_TYPE_IQ2_S   ? MMQ_MMA_TILE_X_K_Q3_K :
@@ -1808,6 +1811,68 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#endif // INT8_MMA_AVAILABLE
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
#else
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

    const int kqsx = threadIdx.x % QI2_0;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) {
        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
        const int qs0 = get_int_b2(bxi->qs, kqsx);

#pragma unroll
        for (int l = 0; l < QR2_0; ++l) {
            // 0..7, 32..39
            // 8..15, 40..47
            // 16..23, 48..55
            // 24..31, 56..63
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
            const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
On NVIDIA GPUs there are 32 shared memory banks with 4 bytes each. To get the maximum memory bandwidth, each thread in a warp needs to read from/write to a different memory bank. So with this patch it should be one write to 32 banks instead of 4 writes to 8 banks. I did not actually try running or even compiling this code. The correct tool to use in this situation is NVIDIA Nsight Compute, to check whether the shared memory bank conflicts are actually fixed (useful to manually add …).

I do see a very (very) small increase in performance when applying this change (which also happens to remain correct, congrats). I'll need to have a look with NVIDIA Nsight Compute, then. I'm not yet familiar with how memory bank conflicts happen, so this is a good opportunity to learn. From what I can guess from your suggested change, it seems the conflict here is caused by writing 16 bytes per thread and by the order in which they are written (because this line only changes the order within a thread, which somehow matters?). This is my first time writing any CUDA kernels (which is why I've described the implementation as naïve), so thank you for mentioning the correct tools. I'll attempt to use that to check whether there's still a bank conflict here or not, and then I'll get back to you.

There are 32 threads in a warp and 32 memory banks with 4 bytes each. Each memory bank can be accessed in parallel, which results in the maximum memory bandwidth. Each memory bank is responsible for all addresses where (address / 4) % 32 has the same value.
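To make that rule concrete, here is a rough host-side sketch (not part of this PR, and not the kernel's actual addressing): the bank of a 4-byte shared-memory word is its word index modulo 32, so counting how many threads of a warp map to the same bank shows whether a write pattern is conflict-free.

```cpp
// Illustration only: estimate the worst-case bank conflict of a warp-wide
// access, given the 4-byte word index each of the 32 threads touches.
// Bank of word w = w % 32 (32 banks, 4 bytes each).
#include <algorithm>
#include <array>
#include <cstdio>

static int max_way_conflict(const std::array<int, 32> & word_index) {
    int per_bank[32] = {0};
    for (int t = 0; t < 32; ++t) {
        per_bank[word_index[t] % 32] += 1;
    }
    // 1 = conflict-free, N = N-way conflict (N serialized accesses)
    return *std::max_element(per_bank, per_bank + 32);
}

int main() {
    std::array<int, 32> a, b;
    for (int t = 0; t < 32; ++t) {
        a[t] = t;           // 32 consecutive words -> 32 different banks
        b[t] = (t % 8) * 4; // only 8 distinct words -> 8 banks, 4 threads each
    }
    std::printf("a: %d-way, b: %d-way\n", max_way_conflict(a), max_way_conflict(b));
    return 0;
}
```

Pattern b is roughly the "4 writes to 8 banks" situation described above, while pattern a is the conflict-free case.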
#ifdef INT8_MMA_AVAILABLE
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
#else
            x_qs[i*(2*WARP_SIZE + 1) + k] = q;
#endif // INT8_MMA_AVAILABLE
        }
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) {
        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2);

        if (need_check) {
            i = min(i, i_max);
        }

        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;

        const int k = threadIdx.x % (QI2_0/2);

#ifdef INT8_MMA_AVAILABLE
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
        x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
#endif // INT8_MMA_AVAILABLE
    }
}
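For reference, the column indices listed in the comments of the first loop can be reproduced on the host. The following standalone sketch is not part of the PR and assumes QI2_0 == 16 and QR2_0 == 4 (a 256-value TQ2_0 block: 16 ints of packed 2-bit data, 4 bit-planes per byte):

```cpp
// Enumerate the tile columns k written by load_tiles_tq2_0 for each
// (kqsx, l) pair, using the same index formula as the kernel.
#include <cstdio>

int main() {
    const int QI2_0 = 16; // assumed: ints of qs per TQ2_0 block
    const int QR2_0 = 4;  // assumed: 2-bit planes per byte
    for (int l = 0; l < QR2_0; ++l) {
        std::printf("l=%d:", l);
        for (int kqsx = 0; kqsx < QI2_0; ++kqsx) {
            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
            std::printf(" %d", k);
        }
        std::printf("\n");
    }
    return 0;
}
```

The output reproduces the 0..7/32..39, 8..15/40..47, 16..23/48..55, 24..31/56..63 grouping from the comments.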

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

@@ -2427,6 +2492,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
    static constexpr int              vdr          = VDR_TQ2_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
    static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
@@ -2916,6 +2989,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
extern DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
@@ -524,6 +524,32 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
    return d6 * sumf_d;
}

#define VDR_TQ2_0_Q8_1_MMVQ 2
#define VDR_TQ2_0_Q8_1_MMQ  8

// Can use the same for both mmvq and mmq, because there are no sub-scales in a TQ2_0 block
template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
    const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) {

    float sumf = 0.0f;

#pragma unroll
    for (int i0 = 0; i0 < QR2_0; ++i0) {
        int sumi = 0;

#pragma unroll
        for (int i = 0; i < vdr; ++i) {
Comment on lines +537 to +541:

I think …

Right. I think I tried to use a similar nomenclature as some of the other functions in this file. But I agree, …
            const int vi = (v[i] >> (2*i0)) & 0x03030303;

            sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
        }

        sumf += d8[i0] * sumi;
    }

    return d2 * sumf;
}

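As a plain-C reading of the intrinsics above (a hypothetical reference, not code from this PR): each 32-bit word of v packs sixteen 2-bit values, the shift by 2*i0 selects one bit-plane, __vsub4 maps the stored {0, 1, 2} to {-1, 0, 1}, and ggml_cuda_dp4a accumulates four int8 products per call. Here u is treated as raw int8 values instead of packed ints, purely to keep the sketch simple.

```cpp
// Hypothetical scalar reference of vec_dot_tq2_0_q8_1_impl (assumes QR2_0 == 4).
#include <cstdint>

float vec_dot_tq2_0_q8_1_ref(const int32_t * v, const int8_t * u,
                             float d2, const float * d8, int vdr) {
    const int QR2_0 = 4;
    float sumf = 0.0f;
    for (int i0 = 0; i0 < QR2_0; ++i0) {        // one 2-bit plane per Q8_1 block
        int sumi = 0;
        for (int i = 0; i < vdr; ++i) {
            const uint32_t w = (uint32_t) v[i];
            for (int b = 0; b < 4; ++b) {       // the 4 bytes handled by one dp4a
                const int q = (int) ((w >> (8*b + 2*i0)) & 0x03) - 1; // {0,1,2} -> {-1,0,1}
                sumi += q * u[(vdr*i0 + i)*4 + b];
            }
        }
        sumf += d8[i0] * sumi;                  // per-Q8_1-block scale
    }
    return d2 * sumf;                           // TQ2_0 block scale
}
```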

static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

@@ -786,6 +812,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}

static __device__ __forceinline__ float vec_dot_tq2_0_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

    const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx;

    // iqs 0..7  all need bq8_offset 0, 1, 2, 3
    // iqs 8..15 all need bq8_offset 4, 5, 6, 7
    const int bq8_offset = QR2_0 * (iqs / 8);

    int v[VDR_TQ2_0_Q8_1_MMVQ];
    int u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ];
    float d8[QR2_0];

#pragma unroll
    for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
        v[i] = get_int_b2(btq2_0->qs, iqs + i);
    }

#pragma unroll
    for (int i0 = 0; i0 < QR2_0; ++i0) {
        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i0;

        for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
            u[VDR_TQ2_0_Q8_1_MMVQ*i0 + i] = get_int_b4(bq8i->qs, (iqs % QI8_1) + i);
        }
        d8[i0] = __low2float(bq8i->ds);
    }

    return vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMVQ>(v, u, btq2_0->d, d8);
}

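To illustrate the bq8_offset comment above, this standalone snippet (not from the PR; it assumes QR2_0 == 4, QI8_1 == 8 and VDR_TQ2_0_Q8_1_MMVQ == 2) lists which Q8_1 blocks and which ints inside them are gathered for each iqs value:

```cpp
// Enumerate the Q8_1 data read by vec_dot_tq2_0_q8_1 per iqs value.
#include <cstdio>

int main() {
    const int QR2_0 = 4;  // assumed: Q8_1 blocks touched per call
    const int QI8_1 = 8;  // assumed: ints per Q8_1 block (32 int8 values)
    const int VDR   = 2;  // assumed: VDR_TQ2_0_Q8_1_MMVQ
    for (int iqs = 0; iqs < 16; iqs += VDR) {
        const int bq8_offset = QR2_0 * (iqs / 8);
        std::printf("iqs=%2d -> Q8_1 blocks %d..%d, ints %d..%d of each\n",
                    iqs, bq8_offset, bq8_offset + QR2_0 - 1,
                    iqs % QI8_1, iqs % QI8_1 + VDR - 1);
    }
    return 0;
}
```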

#define VDR_IQ2_XXS_Q8_1_MMVQ 2
#define VDR_IQ2_XXS_Q8_1_MMQ  2

@@ -3375,7 +3375,8 @@ static const ggml_type all_types[] = {
    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
    GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
    GGML_TYPE_Q6_K,
    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
    // GGML_TYPE_TQ1_0,
    GGML_TYPE_TQ2_0,
Comment on lines -3378 to +3379:

An unintended side effect of un-commenting … Some solutions are: …

Most of these solutions (apart from hiding the problem) are out of scope of this PR, which focuses on the CUDA implementation of …

The correct fix would be to modify …

Done in b6fc9f0.
    GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
@@ -3387,6 +3388,7 @@ static const ggml_type base_types[] = {
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1, // for I8MM tests
    GGML_TYPE_Q4_K,
    GGML_TYPE_TQ2_0,
    GGML_TYPE_IQ2_XXS
};

@@ -3397,7 +3399,8 @@ static const ggml_type other_types[] = {
    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
    GGML_TYPE_Q5_K,
    GGML_TYPE_Q6_K,
    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
    // GGML_TYPE_TQ1_0,
    GGML_TYPE_TQ2_0,
    GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
This should be faster, but since this kernel is going to be I/O bound anyways I doubt it will make a measurable difference.

Right, the indices calculation shouldn't really be a bottleneck here. Is there a particular reason why `tid` isn't an `int` everywhere in that file when it corresponds to `threadIdx.x`?

If you mean my comment, that was just me being a bit inconsistent and not looking ahead at how the values are being used, sorry. Generally speaking, the issue with `int` vs. `int64_t` is just potential overflows for very large tensors. So for kernels where the performance is not relevant anyways, it's a lot of the time preferable to just use `int64_t`.