ggml-cuda : add TQ2_0 kernels, for ternary inference on GPU #11183
base: master
Conversation
This also removes the custom TQ2_0 mmq dp4a, because re-using the one from Q8_0 avoids repeatedly unpacking the 2-bit values to 8-bit; it only has to be done once per tile.
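As a rough illustration of why reusing the `Q8_0` path works (a sketch under the assumption that the ternary values have already been unpacked to signed 8-bit when loading the tile; `dot_int8x4` is a made-up name, not a function from this PR):

```cuda
// Once the 2-bit TQ2_0 values are expanded to int8 while loading the tile
// (done once per tile), the dot product is the same as for Q8_0 data:
// plain __dp4a over packed int8x4 words.
template <int n32> // number of 32-bit words, i.e. groups of 4 int8 values
static __device__ __forceinline__ int dot_int8x4(const int * a, const int * b) {
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < n32; ++i) {
        sumi = __dp4a(a[i], b[i], sumi); // 4 int8 multiply-adds per call
    }
    return sumi;
}
```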
```c
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
// GGML_TYPE_TQ1_0,
GGML_TYPE_TQ2_0,
```
An unintended side effect of un-commenting `TQ2_0` here is that the Metal tests fail, as in https://github.com/ggerganov/llama.cpp/actions/runs/12716518343/job/35451025034?pr=11183#step:5:13921, because operations on that type are not yet implemented there and the `ggml_metal_supports_op` function isn't representative of the types supported by the Metal backend.
Some solutions are:
- Implement all relevant `TQ2_0` support for Metal. This will happen eventually; a starting point already floats around somewhere in a branch linked in ggml-quants : ternary packing for TriLMs and BitNet b1.58 #8151 (comment).
- Make `ggml_metal_supports_op` correctly return `false` when it should. This should be done for correctness.
  - An "easy" way to do this temporarily would be similar to what was done for BF16: simply return `false` when a `TQ2_0` tensor is encountered. The same should be done for the other not-yet-supported types like `TQ1_0` (a sketch follows this list).
- Avoid testing `TQ2_0` to hide the error. This doesn't fix the problem.
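To make the temporary workaround concrete, here is a minimal, hypothetical sketch of the idea (the switch it would live in is in `ggml-metal.m`, and the actual change landed in b6fc9f0; all names below other than the `GGML_TYPE_*` enums and `GGML_MAX_SRC` are made up for illustration):

```c
#include <stdbool.h>
#include <stddef.h>

#include "ggml.h"

// Hypothetical helper: report whether an op touches a tensor type that has
// no Metal kernels yet, so ggml_metal_supports_op could return false for it,
// similar to how BF16 was handled.
static bool metal_op_uses_unimplemented_type(const struct ggml_tensor * op) {
    for (int i = 0; i < GGML_MAX_SRC; ++i) {
        if (op->src[i] == NULL) {
            continue;
        }
        switch (op->src[i]->type) {
            case GGML_TYPE_TQ1_0:
            case GGML_TYPE_TQ2_0:
                return true; // ternary types: not implemented for Metal yet
            default:
                break;
        }
    }
    return false;
}
```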
Most of these solutions (apart from hiding the problem) are out of scope of this PR, which focuses on the CUDA implementation of `TQ2_0`. But I don't want this to make the Metal CI fail.
The correct fix would be to modify `ggml_metal_supports_op`, since even apart from these tests a `TQ2_0` tensor is going to result in a crash. I don't have or want any Apple hardware myself, but it should be fairly easy to just modify the switch statement and I think it can be done in this PR.
Done in b6fc9f0.
```cuda
const int64_t tid = threadIdx.x; // 0..64
const int64_t n = tid/32; // 0 or 1
const int64_t l = tid - 32*n; // 0..32
```
Suggested change:

```diff
-const int64_t tid = threadIdx.x; // 0..64
-const int64_t n = tid/32; // 0 or 1
-const int64_t l = tid - 32*n; // 0..32
+const int tid = threadIdx.x; // 0..64
+const int64_t n = tid/32; // 0 or 1
+const int64_t l = tid & 0x1F; // tid - 32*n, 0..32
```
This should be faster but since this kernel is going to be I/O bound anyways I doubt it will make a measurable difference.
Right, the index calculation shouldn't really be a bottleneck here.

Is there a particular reason why `tid` isn't an `int` everywhere in that file when it corresponds to `threadIdx.x`?
If you mean my comment, that was just me being a bit inconsistent and not looking ahead at how the values are being used, sorry. Generally speaking, the issue with `int` vs. `int64_t` is just potential overflows for very large tensors. So for kernels where the performance is not relevant anyway, it's often preferable to just use `int64_t`.
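As a toy illustration of that overflow concern (the shape below is made up, not from this PR): the flat index of the last element of a large enough matrix no longer fits in a 32-bit `int`, so computing it with `int` arithmetic would wrap around, while `int64_t` stays correct.

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical tensor shape, chosen so that a flattened index
    // exceeds INT_MAX (2^31 - 1).
    const int64_t ne00 = 262144; // columns
    const int64_t ne01 = 16384;  // rows

    const int64_t last = ne01*ne00 - 1; // flat index of the last element

    printf("flat index: %lld\n", (long long) last);
    printf("fits in a 32-bit int: %s\n", last <= INT_MAX ? "yes" : "no");
    return 0;
}
```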
ggml/src/ggml-cuda/mmq.cuh
```cuda
for (int l = 0; l < QR2_0; ++l) {
    // 0..7, 32..39
    // 8..15, 40..47
    // 16..23, 48..55
    // 24..31, 56..63
    const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
    const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
```
Suggested change:

```diff
-for (int l = 0; l < QR2_0; ++l) {
+for (int l0 = 0; l0 < QR2_0; ++l0) {
+    const int l = (l0 + kqsx/8) % QR2_0; // avoid shared memory bank conflicts
     // 0..7, 32..39
     // 8..15, 40..47
     // 16..23, 48..55
     // 24..31, 56..63
     const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
     const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
```
On NVIDIA GPUs there are 32 shared memory banks with 4 bytes each. To get the maximum memory bandwidth, each thread in a warp needs to read from/write to a different memory bank. So with this patch it should be one write to 32 banks instead of 4 writes to 8 banks. I did not actually try running or even compiling this code. The correct tool to use in this situation is NVIDIA Nsight Compute, to check whether the shared memory bank conflicts are actually fixed (it's useful to manually add `-lineinfo` to the NVCC args so the tool can associate source lines with the PTX code). If you are unfamiliar with the tool I can look at it for you (please let me know).
I do see a very (very) small increase in performance when applying this change (which also happens to remain correct, congrats).

I'll need to have a look with NVIDIA Nsight Compute, then. I'm not yet familiar with how memory bank conflicts happen, so that's a good opportunity to learn. From what I can guess from your suggested change, it seems like here it's caused by writing 16 bytes per thread and something to do with the order they are written in? (Because this line only changes the order within a thread, which somehow matters?)

This is my first time writing any CUDA kernels (which is why I've described the implementation as naïve), so thank you for mentioning the correct tools. I'll attempt to use them to check whether there's still a bank conflict here or not, and then I'll get back to you.
There are 32 threads in a warp and 32 memory banks with 4 bytes each. Each memory bank can be accessed in parallel, which results in the maximum memory bandwidth. Each memory bank is responsible for all addresses where `(address / 4) % 32` is its index. So the easiest way to get maximum memory bandwidth is to just access 128 contiguous bytes. In the original code the offsets for each loop iteration were 128 bytes for groups of 8 threads. Because of this, all writes in a warp ended up in the same 8 memory banks and you needed 4 writes instead of 1. The change I proposed simply changes the order in which the data is written, so that on each iteration each group writes to different memory banks.
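As a worked example of that mapping (illustrative only, not code from this PR): with 4-byte words, the bank of shared-memory index `i` is `i % 32`, so the sketch below contrasts the original write pattern (8 banks per iteration, 4-way conflict) with the rotated one (all 32 banks per iteration). It would be launched as `bank_demo<<<1, 32>>>(out)`, i.e. a single warp.

```cuda
// Illustrative sketch of the bank behaviour described above; the index
// formulas mirror the quoted mmq code, but this is not code from the PR.
__global__ void bank_demo(int * out) {
    __shared__ int tile[128];
    const int tid = threadIdx.x; // assume a single warp: tid in 0..31

    for (int l = 0; l < 4; ++l) {
        // Original pattern: for a fixed l, the term (tid/8)*32 is a multiple
        // of 32, so the bank is (l*8 + tid % 8) and the whole warp hits only
        // the 8 banks l*8 .. l*8+7 (a 4-way conflict on every iteration).
        const int k_conflict = (tid/8)*32 + l*8 + tid % 8;

        // Rotated pattern: offsetting l by the thread group (tid/8) makes the
        // 4 groups of 8 threads target 4 disjoint sets of 8 banks, so each
        // iteration touches all 32 banks exactly once (conflict-free).
        const int l_rot   = (l + tid/8) % 4;
        const int k_fixed = (tid/8)*32 + l_rot*8 + tid % 8;

        tile[k_conflict] = tid; // 4-way bank conflict
        __syncwarp();
        tile[k_fixed] = tid;    // no bank conflict
        __syncwarp();
    }

    out[tid] = tile[tid]; // keep the shared memory writes observable
}
```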
```cuda
    for (int i0 = 0; i0 < QR2_0; ++i0) {
        int sumi = 0;

#pragma unroll
        for (int i = 0; i < vdr; ++i) {
```
I think `i0` and `i` are a bit confusing in terms of names; I would prefer something like `i` and `j` (but this is a very minor issue).
Right. I think I tried to use a similar nomenclature to some of the other functions in this file; `Q6_K`, `Q3_K` and `Q2_K` also use `i0` and `i`. But I agree, `i` and `j` are less confusing. (Changed in fbddb26.)
Co-authored-by: Johannes Gäßler <[email protected]>
Maybe not the cleanest way, but hopefully temporary.
Just so we're on the same page: do you intend to look into Nsight Compute in the scope of this PR?
@JohannesGaessler I did not see any reduction in

On second thought, it may be best to improve the performance separately, since at least this works and it's relatively self-contained. In any case, I'll still think about this for a few days, and I'll also try the graphical interface of Nsight Compute to identify whether there are other places where I caused bank conflicts and/or whether there are other bottlenecks.
I myself use the GUI; I think it's pretty intuitive.
@compilade do you know of a way to convert this model into the GGUF format? I think these (1B, 3B, 7B, 10B) are supported over in the bitnet repository. It would be interesting to see GPU benchmarks for the largest models; perhaps I could help with that.
@BarfingLemurs Yes, I might know how to. I didn't try that model yet, though. It's likely similar to https://huggingface.co/HF1BitLLM/Llama3-8B-1.58-100B-tokens/discussions/3, but since there are now quite a few of these models packed with

(In a separate PR, of course)
Follow-up to #8151, which added the ternary types (although CPU-only at first). This implements CUDA kernels for `TQ2_0`: mmvq, tile loading for mmq and mma, and dequant-based cuBLAS.

(Although there was a similar effort in ikawrakow/ik_llama.cpp#13 by @ikawrakow, mmq wasn't handled there, but it is here.)
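For reference, a minimal sketch of the core `TQ2_0` unpacking step that the kernels build on (illustrative; it assumes the packing where each byte stores four 2-bit values with a +1 offset, and it ignores the per-block scale):

```cuda
// Illustrative: expand one 32-bit word of packed TQ2_0 data (sixteen 2-bit
// values stored with a +1 offset) into signed 8-bit values in {-1, 0, +1}.
static __device__ __forceinline__ void unpack_tq2_0_word(const int qs0, int q[4]) {
#pragma unroll
    for (int l = 0; l < 4; ++l) {
        // take the l-th 2-bit value of each of the 4 bytes, then subtract 1
        // from every byte in parallel (__vsub4 is a per-byte subtraction)
        q[l] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
    }
}
```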
Perplexity
Note that generation quality may differ slightly from CPU inference, because the CUDA `TQ2_0` kernels use `Q8_1` (32 int8 weights per scale) as the activation type, while on CPU, `Q8_K` is used (256 int8 weights per scale); simplified layouts of both block types are sketched after the list below.

The perplexities below were calculated with TriLM-3.9B on `wiki.test.raw` from `wikitext-2-raw`, using the CUDA backend, for the following configurations (the second pair of types being for `token_embd.weight` and `output.weight`):

- `F16`, with `F16` and `F16`
- `TQ2_0`, with `F16` and `F16`
- `TQ2_0`, with `Q4_K` and `Q6_K`
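For reference, simplified versions of the two activation block layouts mentioned above, roughly as they appear in ggml (field packing is simplified here; see `ggml-common.h` for the exact definitions):

```c
#include <stdint.h>

typedef uint16_t ggml_half; // fp16 storage type (simplified stand-in)

#define QK8_1 32
#define QK_K  256

// Q8_1: used as the activation type by the CUDA TQ2_0 kernels;
// one fp16 scale (plus a precomputed sum) per 32 int8 values.
typedef struct {
    ggml_half d;         // delta (scale)
    ggml_half s;         // d * sum(qs[i]), used by some dot products
    int8_t    qs[QK8_1]; // quantized values
} block_q8_1; // simplified: ggml packs d and s together as a ggml_half2

// Q8_K: used as the activation type by the CPU TQ2_0 dot product;
// one fp32 scale per 256 int8 values.
typedef struct {
    float   d;              // delta (scale)
    int8_t  qs[QK_K];       // quantized values
    int16_t bsums[QK_K/16]; // sums of groups of 16 values
} block_q8_K;
```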
Performance
It's fast. But there is still room for improvement. The implementation is relatively naïve.
Commands used for the benchmarks below

For `tg128`:

```sh
$ ./bin/llama-bench -m ../models/trilm/TriLM_3.9B_Unpacked-TQ2_0.gguf -n 128 -p 0 -r 20
```

For `pp2048`:

```sh
$ ./bin/llama-bench -m ../models/trilm/TriLM_3.9B_Unpacked-TQ2_0.gguf -b 4,8,16,32,64,128,256,512,1024,2048 -ub 2048 -n 0 -p 2048 -r 10
```

And again for each tested quant type.
Tokens per second for TriLM-3.9B comparing `TQ2_0` and various quant types on an NVIDIA GeForce RTX 3090 (using the CUDA backend), with the best of each row in bold:

| `n_batch` and `n_ubatch` | `TQ2_0` | `Q2_K` | `Q4_0` | `Q4_K_M` | `Q8_0` | `F16` |
| --- | --- | --- | --- | --- | --- | --- |
The same tests, with the same 3.9B ternary model, using an NVIDIA GeForce RTX 4090:

| `n_batch` and `n_ubatch` | `TQ2_0` | `Q2_K` | `Q4_0` | `Q4_K_M` | `Q8_0` | `F16` |
| --- | --- | --- | --- | --- | --- | --- |
There is a noticeable relative speedup compared to larger types at low batch sizes (e.g. when doing single-user text generation like in `tg128`). Of course, there is still room for improvement.

(`TQ1_0` is out of scope of this PR, but GPU support for it will also come eventually.)