ggml-cuda : add TQ2_0 kernels, for ternary inference on GPU #11183

Open · wants to merge 7 commits into master
Conversation

compilade (Collaborator)

A follow-up to #8151, which added the ternary types (although CPU-only at first): this PR implements CUDA kernels for TQ2_0 (mmvq, tile loading for mmq and mma, and dequant-based cuBLAS).

(There was a similar effort in ikawrakow/ik_llama.cpp#13 by @ikawrakow, but mmq was not handled there; here, it is.)
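For context, here is a rough, paraphrased sketch of what a TQ2_0 block looks like and how a single value is decoded. The authoritative definitions are in ggml-common.h and ggml-quants.c; the struct and function names below are illustrative only.

```cpp
// Rough sketch of the TQ2_0 layout (paraphrased, names are illustrative).
// One block covers 256 weights: 64 bytes of packed 2-bit codes plus one fp16 scale.
#include <stdint.h>

typedef struct {
    uint8_t  qs[256/4]; // four ternary values per byte, stored as 2-bit codes 0/1/2
    uint16_t d;         // block scale (fp16 bits; kept as raw bits to stay self-contained)
} block_tq2_0_sketch;

// Decode one 2-bit code: map {0,1,2} -> {-1,0,+1}, then apply the block scale.
// 'scale' is the fp16 value of d already converted to float by the caller.
static inline float tq2_0_decode_one(const block_tq2_0_sketch * b,
                                     int byte_idx, int plane, float scale) {
    const int q = (b->qs[byte_idx] >> (2*plane)) & 3; // 0, 1 or 2
    return scale * (float) (q - 1);                   // -scale, 0 or +scale
}
```

The CUDA kernels in this PR essentially have to reproduce this decoding (and the exact ordering of the 2-bit planes inside qs) in the mmvq, mmq and dequantization paths.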

Perplexity

Note that generation quality may differ slightly from CPU inference because the CUDA TQ2_0 kernels use Q8_1 (32 int8 values per scale) as the activation type, while on CPU, Q8_K is used (256 int8 values per scale).
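For reference, a simplified view of the two activation formats (field layouts paraphrased; the exact structs are in ggml-common.h):

```cpp
// Simplified, paraphrased view of the two int8 activation formats; the exact
// definitions live in ggml-common.h. The relevant difference is the group size
// sharing one scale: 32 values for Q8_1 vs 256 values for Q8_K.
#include <stdint.h>

typedef struct {
    uint16_t d;             // scale (fp16 bits)
    uint16_t s;             // d * sum(qs), precomputed for some dot-product kernels (fp16 bits)
    int8_t   qs[32];        // 32 quantized activations share this scale
} block_q8_1_sketch;        // activation type used by the CUDA TQ2_0 kernels

typedef struct {
    float    d;             // one fp32 scale for the whole super-block
    int8_t   qs[256];       // 256 quantized activations share this scale
    int16_t  bsums[256/16]; // sums of groups of 16 quants, used by k-quant dot products
} block_q8_K_sketch;        // activation type used by the CPU TQ2_0 path
```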

The perplexities below were calculated with TriLM-3.9B on wiki.test.raw from wikitext-2-raw, when using the CUDA backend.

| Quant | Token embeddings and output tensor types | Perplexity |
| ----- | ---------------------------------------- | ------------------- |
| F16   | F16 and F16                              | 11.1508 +/- 0.07852 |
| TQ2_0 | F16 and F16                              | 11.1517 +/- 0.07853 |
| TQ2_0 | Q4_K and Q6_K                            | 11.1539 +/- 0.07852 |

Performance

It's fast. But there is still room for improvement. The implementation is relatively naïve.

Commands used for the benchmarks below:

For tg128:

```sh
$ ./bin/llama-bench -m ../models/trilm/TriLM_3.9B_Unpacked-TQ2_0.gguf -n 128 -p 0 -r 20
```

For pp2048:

```sh
$ ./bin/llama-bench -m ../models/trilm/TriLM_3.9B_Unpacked-TQ2_0.gguf -b 4,8,16,32,64,128,256,512,1024,2048 -ub 2048 -n 0 -p 2048 -r 10
```

And again for each tested quant type.


Tokens per second for TriLM-3.9B, comparing TQ2_0 and various quant types on an NVIDIA GeForce RTX 3090 (using the CUDA backend):

(best of each row is in bold)

| test | n_batch and n_ubatch | TQ2_0 | Q2_K | Q4_0 | Q4_K_M | Q8_0 | F16 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| tg128  | 1    | **262.27 ± 1.49**   | 201.56 ± 1.00   | 215.75 ± 0.97     | 204.23 ± 0.50   | 146.94 ± 0.25   | 95.14 ± 0.26 |
| pp2048 | 4    | **606.14 ± 0.24**   | 459.33 ± 1.88   | 592.91 ± 0.28     | 476.91 ± 2.67   | 476.89 ± 0.77   | 298.54 ± 0.12 |
| pp2048 | 8    | 830.37 ± 1.49       | 662.66 ± 1.09   | **854.81 ± 2.38** | 637.81 ± 2.19   | 684.99 ± 0.18   | 599.84 ± 0.26 |
| pp2048 | 16   | **1923.61 ± 0.84**  | 1678.69 ± 0.92  | 1624.60 ± 0.23    | 1630.84 ± 0.83  | 1356.20 ± 0.73  | 1142.86 ± 0.26 |
| pp2048 | 32   | **3260.83 ± 0.83**  | 2708.09 ± 0.21  | 2758.63 ± 0.23    | 2825.91 ± 0.33  | 2512.09 ± 1.33  | 2187.86 ± 0.42 |
| pp2048 | 64   | **4653.85 ± 2.89**  | 3637.53 ± 0.89  | 4146.47 ± 4.02    | 4026.11 ± 0.81  | 3874.62 ± 5.72  | 3874.89 ± 7.65 |
| pp2048 | 128  | 5542.26 ± 10.78     | 3697.21 ± 0.73  | 5192.41 ± 13.34   | 4804.31 ± 0.94  | 5012.78 ± 13.91 | **5693.27 ± 11.47** |
| pp2048 | 256  | 6594.14 ± 8.64      | 4676.00 ± 23.15 | 6185.45 ± 17.64   | 5843.54 ± 6.90  | 6156.16 ± 16.44 | **6733.23 ± 19.90** |
| pp2048 | 512  | **6964.92 ± 25.13** | 5165.61 ± 25.91 | 6514.26 ± 15.43   | 6172.70 ± 8.62  | 6513.97 ± 8.07  | 6913.39 ± 21.24 |
| pp2048 | 1024 | **6988.18 ± 23.12** | 5382.19 ± 32.51 | 6534.47 ± 25.53   | 6246.18 ± 24.18 | 6558.76 ± 18.24 | 6783.25 ± 15.27 |
| pp2048 | 2048 | **6558.82 ± 10.95** | 5218.12 ± 15.81 | 6143.25 ± 9.76    | 5940.68 ± 15.27 | 6204.75 ± 14.12 | 6524.19 ± 7.07 |

The same tests, with the same 3.9B ternary model, using an NVIDIA GeForce RTX 4090:

| test | n_batch and n_ubatch | TQ2_0 | Q2_K | Q4_0 | Q4_K_M | Q8_0 | F16 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| tg128  | 1    | **330.38 ± 0.46**    | 285.66 ± 0.71   | 241.07 ± 0.47   | 231.95 ± 0.52   | 162.51 ± 0.25   | 103.10 ± 0.06 |
| pp2048 | 4    | **969.87 ± 7.71**    | 780.09 ± 3.80   | 819.69 ± 3.12   | 748.16 ± 4.38   | 588.14 ± 1.97   | 349.37 ± 0.44 |
| pp2048 | 8    | **1397.01 ± 4.09**   | 1232.43 ± 4.01  | 1335.88 ± 7.21  | 1180.74 ± 2.29  | 1017.97 ± 2.36  | 686.39 ± 1.86 |
| pp2048 | 16   | **2975.74 ± 15.60**  | 2643.16 ± 6.86  | 2364.93 ± 3.57  | 2353.31 ± 11.33 | 1865.25 ± 7.49  | 1172.29 ± 5.04 |
| pp2048 | 32   | **5204.42 ± 0.97**   | 4430.29 ± 2.91  | 4196.54 ± 22.09 | 4324.71 ± 22.95 | 3470.22 ± 0.71  | 2628.78 ± 8.86 |
| pp2048 | 64   | **8312.00 ± 41.84**  | 6819.06 ± 36.22 | 7111.68 ± 28.82 | 6978.54 ± 13.87 | 5935.29 ± 1.36  | 4865.42 ± 0.83 |
| pp2048 | 128  | **10958.72 ± 77.77** | 7526.91 ± 27.85 | 10054.20 ± 1.54 | 9341.22 ± 44.48 | 8879.21 ± 1.23  | 7541.59 ± 64.49 |
| pp2048 | 256  | **14145.32 ± 65.05** | 9294.30 ± 57.37 | 13194.65 ± 2.25 | 12198.99 ± 88.07 | 12612.84 ± 71.58 | 11319.90 ± 5.33 |
| pp2048 | 512  | **15346.04 ± 19.07** | 10761.67 ± 45.84 | 14350.62 ± 80.40 | 13610.42 ± 50.48 | 14215.50 ± 16.16 | 13252.30 ± 15.23 |
| pp2048 | 1024 | **14236.63 ± 4.37**  | 11092.28 ± 8.33 | 13210.68 ± 69.18 | 12785.47 ± 21.77 | 13277.57 ± 65.59 | 12670.69 ± 61.22 |
| pp2048 | 2048 | **11890.03 ± 79.01** | 9992.28 ± 23.62 | 11208.72 ± 51.92 | 11006.58 ± 83.24 | 11375.74 ± 39.91 | 10722.65 ± 45.49 |

There is a noticeable relative speedup compared to larger types at low batch sizes (e.g. when doing single-user text generation like in tg128). Of course, there is still room for improvement.

(TQ1_0 is out of scope of this PR, but GPU support for it will also come eventually)

This also removes the custom TQ2_0 mmq dp4a, because re-using the one from Q8_0 avoids repeatedly unpacking the 2-bit values to 8-bit; instead, the unpacking is done only once per tile.
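For illustration only (a sketch, not the PR's actual code; the function name and arguments are made up), the reuse boils down to this: once a tile holds packed int8 values, the per-pair dot product is the same __dp4a loop that the Q8_0 path already uses.

```cpp
// Toy sketch: after unpacking the 2-bit values to packed int8 once per tile,
// each pair of 32-bit words (4 int8 each) is reduced with __dp4a, exactly as
// in the Q8_0 mmq path. Compile with nvcc for sm_61 or newer.
__device__ int dot_packed_int8_sketch(const int * x_qs, const int * y_qs, int n) {
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < n; ++i) {
        sumi = __dp4a(x_qs[i], y_qs[i], sumi); // 4 int8*int8 products + accumulate
    }
    return sumi;
}
```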
Comment on lines -3378 to +3379:

```diff
-    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    // GGML_TYPE_TQ1_0,
+    GGML_TYPE_TQ2_0,
```
compilade (Collaborator, Author) commented on Jan 10, 2025:

An unintended side effect of un-commenting TQ2_0 here is that the Metal tests fail, as in https://github.com/ggerganov/llama.cpp/actions/runs/12716518343/job/35451025034?pr=11183#step:5:13921, because operations on that type are not yet implemented there and the ggml_metal_supports_op function isn't representative of the types actually supported by the Metal backend.

Some solutions are:

- Implement all relevant TQ2_0 support for Metal.
- Make ggml_metal_supports_op correctly return false when it should.
  - This should be done for correctness.
  - An "easy" way to do this temporarily would be similar to what was done for BF16: simply return false when a TQ2_0 tensor is encountered (a rough sketch follows below). The same should be done for the other not-yet-supported types like TQ1_0.
- Avoid testing TQ2_0, to hide the error.
  - This doesn't fix the problem.

Most of these solutions (apart from hiding the problem) are out of scope of this PR, which focuses on the CUDA implementation of TQ2_0. But I don't want this to make the Metal CI fail.
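For illustration, the temporary BF16-style workaround mentioned in the list above could look roughly like this (a sketch only; the real ggml_metal_supports_op in ggml-metal.m takes additional context arguments and has a larger per-op switch):

```cpp
// Sketch of the "return false for not-yet-supported quant types" approach.
// Hypothetical shape; only the early-out over the source types is the point.
static bool ggml_metal_supports_op_sketch(const struct ggml_tensor * op) {
    for (int i = 0; i < GGML_MAX_SRC; ++i) {
        if (op->src[i] == NULL) {
            continue;
        }
        switch (op->src[i]->type) {
            case GGML_TYPE_TQ1_0:
            case GGML_TYPE_TQ2_0:
                return false; // no Metal kernels for these types yet
            default:
                break;
        }
    }
    // ... the existing per-operation switch statement would follow here ...
    return true;
}
```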

Collaborator:

The correct fix would be to modify ggml_metal_supports_op since even apart from these tests a TQ2_0 tensor is going to result in a crash. I don't have or want any Apple hardware myself but it should be fairly easy to just modify the switch statement and I think it can be done in this PR.

Collaborator (Author):

Done in b6fc9f0.

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jan 10, 2025
Comment on lines +286 to +288:

```cpp
const int64_t tid = threadIdx.x; // 0..64
const int64_t n = tid/32; // 0 or 1
const int64_t l = tid - 32*n; // 0..32
```
Collaborator:

Suggested change:

```diff
-const int64_t tid = threadIdx.x; // 0..64
-const int64_t n = tid/32; // 0 or 1
-const int64_t l = tid - 32*n; // 0..32
+const int tid = threadIdx.x; // 0..64
+const int64_t n = tid/32; // 0 or 1
+const int64_t l = tid & 0x1F; // tid - 32*n, 0..32
```

This should be faster, but since this kernel is going to be I/O-bound anyway, I doubt it will make a measurable difference.

Collaborator (Author):

Right, the index calculation shouldn't really be a bottleneck here.

Is there a particular reason why tid isn't an int everywhere in that file when it corresponds to threadIdx.x?

Collaborator:

If you mean my comment, that was just me being a bit inconsistent and not looking ahead at how the values are being used, sorry. Generally speaking, the issue with int vs. int64_t is just potential overflows for very large tensors. So for kernels where performance is not relevant anyway, it's often preferable to just use int64_t.
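A tiny standalone illustration of the overflow concern (the tensor shape here is hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // A tensor with more than 2^31 elements: flattening the index in 32-bit
    // arithmetic would overflow INT_MAX (undefined behaviour), while the
    // 64-bit computation stays correct.
    const int row = 70000, ncols = 40000;
    const int64_t idx = (int64_t) row * ncols;   // 2,800,000,000 (correct)
    // const int bad = row * ncols;              // would overflow a 32-bit int
    printf("flattened index = %lld\n", (long long) idx);
    return 0;
}
```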

ggml/src/ggml-cuda/convert.cu: outdated review comment (resolved).
Comment on lines 1840 to 1846:

```cpp
for (int l = 0; l < QR2_0; ++l) {
    // 0..7, 32..39
    // 8..15, 40..47
    // 16..23, 48..55
    // 24..31, 56..63
    const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
    const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
```
Collaborator:

Suggested change:

```diff
-for (int l = 0; l < QR2_0; ++l) {
-    // 0..7, 32..39
-    // 8..15, 40..47
-    // 16..23, 48..55
-    // 24..31, 56..63
-    const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
-    const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
+for (int l0 = 0; l0 < QR2_0; ++l0) {
+    const int l = (l0 + kqsx/8) % QR2_0; // avoid shared memory bank conflicts
+    // 0..7, 32..39
+    // 8..15, 40..47
+    // 16..23, 48..55
+    // 24..31, 56..63
+    const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+    const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
```

On NVIDIA GPUs there are 32 shared memory banks with 4 bytes each. To get the maximum memory bandwidth, each thread in a warp needs to read from/write to a different memory bank. So with this patch it should be one write to 32 banks instead of 4 writes to 8 banks. I did not actually try running or even compiling this code. The correct tool to use in this situation is NVIDIA NSight Compute: check whether the shared memory bank conflicts are actually fixed (it's useful to manually add -lineinfo to the NVCC args so the tool can associate source lines with the PTX code). If you are unfamiliar with the tool I can look at it for you (please let me know).

Collaborator (Author):

I do see a very (very) small increase in performance when applying this change (which also happens to remain correct, congrats).

I'll need to have a look with NVIDIA NSight Compute, then. I'm not yet familiar with how memory bank conflicts happen, so that's a good opportunity to learn. From what I can guess with your suggested change, it seems like here it's caused by writing 16 bytes per thread and something to do with the order they are written? (because this line only changes the order within a thread, which somehow matters?)

This is my first time writing any CUDA kernels (which is why I've described the implementation as naïve), so thank you for mentioning the correct tools. I'll attempt to use that to check if there's still a bank conflict here or not, and then I'll get back to you.

Collaborator:

There are 32 threads in a warp and 32 memory banks with 4 bytes each. Each memory bank can be accessed in parallel which results in the maximum memory bandwidth. Each memory bank is responsible for all addresses where (address / 4) % 32 is its index. So the easiest way to get maximum memory bandwidth is to just access 128 contiguous bytes. In the original code the offsets for each loop iteration were 128 bytes for groups of 8 threads. Because of this all writes in a warp ended up in the same 8 memory banks and you needed 4 writes instead of 1. The change I proposed simply changes the order in which the data is written so that on each iteration each group writes to different memory banks.
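As a small standalone illustration of that explanation (a host-side toy, not the actual mmq tile code, and assuming each index k addresses a 4-byte element so that its bank is simply k % 32), counting the distinct banks touched per iteration under both orderings gives 8 versus 32:

```cpp
#include <cstdio>

// Toy model: 32 threads of a warp (kqsx = 0..31) each write one 4-byte element
// per iteration; the element index k determines the shared memory bank as k % 32.
int main() {
    for (int l0 = 0; l0 < 4; ++l0) {             // 4 == QR2_0
        bool orig[32] = {false}, rot[32] = {false};
        for (int kqsx = 0; kqsx < 32; ++kqsx) {
            const int k_orig = (kqsx/8)*32 + l0*8 + kqsx % 8;   // original order
            const int l      = (l0 + kqsx/8) % 4;               // rotated order
            const int k_rot  = (kqsx/8)*32 + l*8  + kqsx % 8;
            orig[k_orig % 32] = true;
            rot [k_rot  % 32] = true;
        }
        int n_orig = 0, n_rot = 0;
        for (int b = 0; b < 32; ++b) { n_orig += orig[b]; n_rot += rot[b]; }
        printf("iteration %d: distinct banks touched: original = %d, rotated = %d\n",
               l0, n_orig, n_rot);  // prints 8 vs 32
    }
    return 0;
}
```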

Comment on lines +537 to +541:

```cpp
for (int i0 = 0; i0 < QR2_0; ++i0) {
    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
```
Collaborator:

I think i0 and i are a bit confusing names; I would prefer something like i and j (but this is a very minor issue).

compilade (Collaborator, Author) commented on Jan 12, 2025:

Right. I think I tried to use a similar nomenclature as some of the other functions in this file. Q6_K, Q3_K and Q2_K also use i0 and i.

But I agree, i and j are less confusing. (changed in fbddb26)

JohannesGaessler (Collaborator):

Just so we're on the same page: do you intend to look into NSight Compute in the scope of this PR?

compilade (Collaborator, Author):

> Just so we're on the same page: do you intend to look into NSight Compute in the scope of this PR?

@JohannesGaessler Yes, that was my intention, although this PR is pretty much done otherwise.

It's going to take me some time to learn how to use NSight Compute, though, especially since I don't have a good local GPU yet; I'm still trying out remote GPUs before settling on something better.

I did not see any reduction in l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum between your proposed change to the indexing in load_tiles_tq2_0 and the version before, although I might have measured the wrong things, because the benchmark speeds did seem to differ slightly (and still do).

On second thought, it may be best to improve the performance separately, since at least this works and it's relatively self-contained.

In any case, I'll still think about this for a few days and I'll also try the graphical interface of NSight Compute to identify if there are other places where I caused bank conflicts and/or if there are other bottlenecks.

JohannesGaessler (Collaborator):

I myself use the GUI, I think it's pretty intuitive.

BarfingLemurs (Contributor):

@compilade do you know of a way to convert this model into the gguf format?
https://huggingface.co/tiiuae/Falcon3-10B-Instruct-1.58bit

I think these: (1B, 3B, 7B, 10B) are supported over in the bitnet repository.

It would be interesting to see GPU benchmarks for the largest models, perhaps I could help with that.

P.S. I found this tidbit recently: ridgerchu/matmulfreellm#33
Perhaps this work could benefit this model when they release it: https://github.com/Chenglin-Yang/1.58bit.flux/issues

compilade (Collaborator, Author) commented on Jan 22, 2025:

@BarfingLemurs Yes, I might know how to. I haven't tried that model yet, though.

It's likely similar to https://huggingface.co/HF1BitLLM/Llama3-8B-1.58-100B-tokens/discussions/3, but since there are now quite a few of these models packed with "quant_method": "bitnet", I think I should make the convert script handle that more cleanly (and also make it easier to transparently support other pre-quantizations) so that it's less ad hoc.

(In a separate PR, of course)
