CUDA: mmq CLI option, fixed mmq build issues #2453
Conversation
This is going to be hard with single-config CMake projects, if not impossible. |
About the cuBLAS issue, I think it may be better to make it a command-line option; there isn't much reason for this to be a compile-time option for now, since we still have to link to cuBLAS regardless. |
we do? |
Currently yes because the KV cache uses f16 for which there is no mul_mat_q implementation. |
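(For context, a minimal sketch of the kind of dispatch being discussed, with made-up names; this is not the actual ggml-cuda code. The point is that f16 tensors such as the KV cache have no mul_mat_q kernel and still need cuBLAS, so the cuBLAS dependency cannot simply be dropped.)

```cpp
// Illustrative sketch only -- not the actual ggml-cuda dispatch code.
// Quantized weight matrices can go through the custom mul_mat_q kernels,
// but f16 tensors (e.g. the KV cache) still fall back to cuBLAS.
enum class tensor_type { F16, F32, Q4_0, Q4_1, Q5_1, Q8_0 };

static bool has_mul_mat_q_kernel(tensor_type t) {
    // hypothetical helper: only the quantized formats have mul_mat_q kernels
    switch (t) {
        case tensor_type::Q4_0:
        case tensor_type::Q4_1:
        case tensor_type::Q5_1:
        case tensor_type::Q8_0:
            return true;
        default:
            return false; // f16/f32 tensors do not
    }
}

void mul_mat_dispatch(tensor_type src0_type, bool use_mmq) {
    if (use_mmq && has_mul_mat_q_kernel(src0_type)) {
        // launch the custom quantized matrix multiplication kernel
        // launch_mul_mat_q(...);   // placeholder
    } else {
        // convert/dequantize as needed and call a cuBLAS GEMM
        // cublasGemmEx(...);       // placeholder
    }
}
```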
Okay, so the reason I made this a compile option instead of a command line option is that I did not want to add temporary changes to the command line options. How about this: a command line option like |
We already have 2 temporary parameters for llama 2, so I don't think adding another one would be too bad. My opinion is that something like |
I changed the new kernels to be opt-in. They can be used with the CLI option |
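(As a rough illustration of what an opt-in boolean CLI flag of this kind can look like, here is a hedged sketch. The flag spellings `--mul-mat-q` / `-mmq` and the struct field are assumptions for illustration, not copied from the PR.)

```cpp
// Hypothetical sketch of an opt-in flag; names are assumptions, not the PR's code.
#include <string>

struct gpt_params_sketch {
    bool mul_mat_q = false; // off by default: the new kernels are opt-in
};

static void parse_args(int argc, char ** argv, gpt_params_sketch & params) {
    for (int i = 1; i < argc; ++i) {
        const std::string arg = argv[i];
        if (arg == "--mul-mat-q" || arg == "-mmq") {
            // request the custom quantized kernels; the value would then be
            // passed through to the CUDA backend when setting up the context
            params.mul_mat_q = true;
        }
    }
}
```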
Quick test with 3090 Ti under WSL2:
Note: this is with master merged. |
@JohannesGaessler I'm blown away. You managed to code your matmul kernel in such a way that prompt processing is now faster than cuBLAS (14.85 ms/t), at least for my q5_1 model. With yesterday's implementation at 17 layers, prompt processing was significantly slower than cuBLAS (20 ms/t in that case). Awesome job! I don't know how you did it in such a short time, but you did it. |
I think line https://github.com/ggerganov/llama.cpp/blob/0728c5a8b9569183ffca0399caac099afef87595/CMakeLists.txt#L269 should be |
Also, performance-wise I observed a consistent speedup for q4_0 and a slowdown for q4_K_M (the two formats that I usually use). |
I did it by doing nothing else for an entire day because all of my physics colleagues are on vacation.
Which GPU do you have? |
I have a laptop Nvidia RTX 2060 6 GB. 7B q4_K_M, prompt processing on 1968 tokens, 32 layers offloaded; non-MMQ takes 14.7 s (7 ms/T). |
Can you try disabling this loop unroll? I only have an RTX 3090 and several Pascal cards on which the unroll was significantly slower (by a factor of 4). So it would be useful for me if someone with an RTX 2000 GPU could actually test the unroll (but so far it looks like it's not an issue). |
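(For anyone reproducing this test: the change being asked for amounts to removing or guarding a `#pragma unroll` on the kernel's inner loop. Below is a generic sketch of that pattern, not the actual mul_mat_q kernel.)

```cpp
// Generic sketch of guarding a loop unroll for benchmarking; not the real kernel.
#define MMQ_DISABLE_UNROLL 0   // set to 1 to test without unrolling

__global__ void dot_kernel(const int * x, const int * y, int * out, int n) {
    int sum = 0;
#if !MMQ_DISABLE_UNROLL
#pragma unroll
#endif
    for (int i = 0; i < 8; ++i) {          // fixed trip count so the compiler can unroll
        const int idx = threadIdx.x * 8 + i;
        if (idx < n) {
            sum += x[idx] * y[idx];
        }
    }
    atomicAdd(out, sum);
}
```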
Disabling the loop unroll seemed to have a very slight negative performance impact (enabled: 19.1 s, 9 ms/T). However, there was another interesting side effect: removing the highlighted loop unrolling code changed the binary size significantly. Nonetheless, both approaches are still significantly inferior to the old method (non-MMQ: 15.3 s, 8 ms/T). More concerning is that the mmq changes break compatibility with full offload for non-standard (vocab != 32000) models, even when mmq is disabled (ref: #2160 (comment)). |
I know what the problem is, I'll fix it soon. |
@JohannesGaessler Alright, so I can confirm without a shadow of a doubt that the new mul-mat is faster than cuBLAS at processing regular quants, but slower for k-quants. Here is the data to back it up. First, let's start with 7B, 20 layers offloaded on an RTX 2060, and q4_0. This is with MMQ: This is without MMQ: Notice that prompt processing speed is a little faster with the new mul-mat. Now, let's switch over to q4_K_M. First, MMQ enabled: This is with regular cuBLAS: Now you will notice that with k-quants it's a different story: prompt processing is slower than cuBLAS. Regarding the loop unroll, I've been seeing similar results to LostRuins. If you need me to test anything, feel free to ask! I'm happy to help. |
This also fixes a crash when loading the 70B LLaMA 2 model. This parameter was introduced in ggml-org/llama.cpp#2453 (`0728c5a8`).
This PR aims to fix the following:
- Not all GPUs support the `__dp4a` intrinsic used by the mul_mat_q kernels. This PR adds a corresponding check.
- For the `perplexity` binary cuBLAS should always be used, as suggested by CUDA: Quantized matrix matrix multiplication #2160 (comment). I tried to add a corresponding check to cmake but adding the define at that position does nothing. Is it actually possible to do this with cmake?
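(Background on the `__dp4a` point: the intrinsic is only available on compute capability 6.1 and newer, so the check takes the shape of a compile-time architecture guard. The snippet below is a sketch of that pattern, not the exact code added in this PR.)

```cpp
// Sketch of a __dp4a availability check; the fallback is illustrative only.
// __dp4a requires compute capability >= 6.1 (sm_61).
#define MIN_CC_DP4A 610

static __device__ __forceinline__ int dp4a_compat(int a, int b, int c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= MIN_CC_DP4A
    return __dp4a(a, b, c); // hardware 8-bit dot product with 32-bit accumulate
#else
    // manual fallback: unpack the four signed 8-bit lanes and accumulate
    const char4 va = *reinterpret_cast<const char4 *>(&a);
    const char4 vb = *reinterpret_cast<const char4 *>(&b);
    return c + va.x * vb.x + va.y * vb.y + va.z * vb.z + va.w * vb.w;
#endif
}
```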