CUDA: mmq CLI option, fixed mmq build issues #2453

Merged

Conversation

JohannesGaessler
Collaborator

This PR aims to fix the following:

  • The new quantized matrix-matrix multiplication kernels decide whether or not to unroll a loop based on compute capability, but with CMake this check is not functional by default. This PR adds PTX code for compute capability 7.0.
  • For GPUs with compute capability < 6.1 the new quantized matrix-matrix multiplication kernels cannot be used due to the lack of the __dp4a intrinsic. This PR adds a corresponding check (see the sketch after this list).
  • For the perplexity binary, cuBLAS should always be used, as suggested by CUDA: Quantized matrix matrix multiplication #2160 (comment). I tried to add a corresponding check to CMake, but adding the define at that position does nothing. Is it actually possible to do this with CMake?
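
For illustration, here is a minimal sketch of the kind of compute-capability guard described in the second point. It is not the PR's actual code: MIN_CC_DP4A and dot8_acc are assumed names, and the fallback path exists only to make the example self-contained.

```cpp
#include <cstdint>

// Sketch only: use __dp4a (available from compute capability 6.1) when present,
// otherwise fall back to a manual int8 dot product. Names are illustrative.
#define MIN_CC_DP4A 610

static __device__ __forceinline__ int dot8_acc(const int a, const int b, const int acc) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= MIN_CC_DP4A
    return __dp4a(a, b, acc); // 4-way int8 dot product in a single instruction
#else
    // Pre-6.1 GPUs: unpack the packed int8 values and accumulate manually.
    const int8_t * va = (const int8_t *) &a;
    const int8_t * vb = (const int8_t *) &b;
    return acc + va[0]*vb[0] + va[1]*vb[1] + va[2]*vb[2] + va[3]*vb[3];
#endif
}
```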

@Green-Sky
Collaborator

Is it actually possible to do this with CMake?

This is going to be hard with single-config CMake projects, if not impossible.
For CI, you could run CMake + build twice: once without the perplexity binary and once with only it.

@slaren
Member

slaren commented Jul 30, 2023

About the cuBLAS issue, I think it may be better to make it a command-line option; there isn't much reason for this to be a compile-time option for now, since we still have to link to cuBLAS regardless.

@Green-Sky
Collaborator

since we still have to link to cuBLAS regardless

we do?

@JohannesGaessler
Collaborator Author

Currently yes, because the KV cache uses f16, for which there is no mul_mat_q implementation.

@JohannesGaessler
Collaborator Author

Okay, so the reason I made this a compile option instead of a command-line option is that I did not want to add temporary changes to the command-line options. How about this: a command-line option like --mul-mat-algorithm that takes values like blas, mul_mat_q (open for better names), or auto (decide based on current performance level). If at some point we remove cuBLAS as a mandatory dependency and the user compiles without cuBLAS, the behavior of auto is adjusted and explicitly requesting cuBLAS results in an error.
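
(For context, a sketch of what such an option could look like. This option was never merged in this form; every name below is hypothetical.)

```cpp
// Illustrative only: parsing a hypothetical --mul-mat-algorithm value.
#include <stdexcept>
#include <string>

enum class mul_mat_algo { AUTO, BLAS, MUL_MAT_Q };

static mul_mat_algo parse_mul_mat_algo(const std::string & s) {
    if (s == "auto")      return mul_mat_algo::AUTO;      // decide based on current performance level
    if (s == "blas")      return mul_mat_algo::BLAS;      // always use cuBLAS
    if (s == "mul_mat_q") return mul_mat_algo::MUL_MAT_Q; // always use the quantized mmq kernels
    throw std::invalid_argument("unknown --mul-mat-algorithm value: " + s);
}
```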

@slaren
Member

slaren commented Jul 30, 2023

We already have 2 temporary parameters for Llama 2, so I don't think adding another one would be too bad. My opinion is that something like --mul-mat-algorithm shouldn't be a parameter in the long term either; this is not something that the user should have to decide. So if it is going to be a temporary parameter regardless, it may as well be a simple cuBLAS yes/no.

@JohannesGaessler JohannesGaessler marked this pull request as ready for review July 31, 2023 09:57
@JohannesGaessler
Collaborator Author

I changed the new kernels to be opt-in. They can be used with the CLI option --mul-mat-q.
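
(A hedged sketch of how such an opt-in flag could steer kernel selection, tying together the constraints mentioned in this thread; this is not llama.cpp's actual dispatch code and the names are made up.)

```cpp
// Sketch only: decide between the mmq kernels and cuBLAS for one matrix multiplication.
static bool use_mul_mat_q(bool mul_mat_q_requested, int compute_capability, bool src0_is_quantized) {
    if (!mul_mat_q_requested)     return false; // opt-in: the default path stays on cuBLAS
    if (compute_capability < 610) return false; // __dp4a is unavailable below compute capability 6.1
    return src0_is_quantized;                   // f16 (e.g. the KV cache) still goes through cuBLAS
}
```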

@JohannesGaessler JohannesGaessler force-pushed the cuda-mmq-build-fixes branch 3 times, most recently from 8447bc8 to 3d0e531 on July 31, 2023 10:59
@JohannesGaessler JohannesGaessler mentioned this pull request Jul 31, 2023
@JohannesGaessler JohannesGaessler changed the title CUDA: fixed mmq build issues CUDA: mmq CLI option, fixed mmq build issues Jul 31, 2023
@slaren
Member

slaren commented Jul 31, 2023

Quick test with 3090 Ti under WSL2:

Model                           cuBLAS pp t/s   mmq pp t/s
llama-2-7b.ggmlv3.q4_0.bin      1473.09         1544.40
llama-2-7b.ggmlv3.q5_K_M.bin    1461.38         875.14

Note: this is with master merged.

@JohannesGaessler JohannesGaessler merged commit 0728c5a into ggml-org:master Jul 31, 2023
@Dampfinchen

Dampfinchen commented Jul 31, 2023

@JohannesGaessler I'm blown away.

(screenshot: 13B on an RTX 2060, 16 layers offloaded)

You managed to code your matmul kernel in such a way that prompt processing is now faster than cuBLAS (14.85 ms/t), at least for my q5_1 model.

With yesterday's implementation at 17 layers, prompt processing was significantly slower than cuBLAS (20 ms/t in that case).

(screenshot: after 17 layers)

Awesome job! I don't know how you did it in such a short time, but you did it.

@LostRuins
Collaborator

I think line https://github.com/ggerganov/llama.cpp/blob/0728c5a8b9569183ffca0399caac099afef87595/CMakeLists.txt#L269 should be GGML_CUDA_F16 as GGML_CUDA_DMMV_F16 does not seem to exist in the code.

@LostRuins
Collaborator

Also, performance-wise I did observe a consistent speedup for q4_0 and a slowdown for q4_K_M (the two formats that I usually use).

@JohannesGaessler
Collaborator Author

I don't know how you did it in such a short time, but you did it.

I did it by doing nothing else for an entire day because all of my physics colleagues are on vacation.

Also, performance-wise I did observe a consistent speedup for q4_0 and a slowdown for q4_K_M (the two formats that I usually use).

Which GPU do you have?

@Dampfinchen

I've made a similar observation. While the speed for the regular quants is great, the slowdown for k-quants is definitely very noticeable. Here is a 7B model with 20 GPU layers, comparing q4_0 and q4_K_M on an RTX 2060 laptop:

(screenshots: 7B q4_0 with 20 layers, and 7B q4_K_M)

@LostRuins
Collaborator

Which GPU do you have?

I have a laptop Nvidia RTX 2060 6GB.

7B q4_K_M, Prompt processing on 1968 tokens, 32 layers offloaded.

non-mmq takes 14.7s (7ms/T)
mmq takes 18.5s (9ms/T)

@JohannesGaessler
Collaborator Author

Can you try disabling this loop unroll? I only have an RTX 3090 and several Pascal cards on which the unroll was significantly slower (by a factor of 4). So it would be useful for me if someone with an RTX 2000 GPU could actually test the unroll (but so far it looks like it's not an issue).
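
(For reference, a minimal illustration of gating a loop unroll on compute capability, in the spirit of the unroll under discussion; the toy kernel, names, and 7.0 threshold are assumptions for this sketch, not the actual mmq code.)

```cpp
#define UNROLL_MIN_CC 700 // assumed threshold: only unroll on compute capability >= 7.0

// Toy kernel: scale every element of a row-major, ncols-wide matrix.
__global__ void scale_rows(const float * x, float * y, const float scale, const int ncols) {
    const int row  = blockIdx.x;
    const int col0 = threadIdx.x;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= UNROLL_MIN_CC
#pragma unroll 4
#endif
    for (int col = col0; col < ncols; col += blockDim.x) {
        y[row*ncols + col] = scale * x[row*ncols + col];
    }
}
```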

@LostRuins
Collaborator

LostRuins commented Aug 1, 2023

Disabling the loop unroll seemed to have a very slight negative performance impact:

Enabled: 19.1s (9ms/T)
Disabled: 20.1s (10ms/T)

However, there was another interesting side effect: the binary size changed significantly. Removing the highlighted loop-unrolling code (see screenshot) reduced the output binary file size by nearly 15%!

Nonetheless, both approaches are still significantly inferior to the old method (non-MMQ):

Non-MMQ: 15.3s (8ms/T)

More concerning is that the mmq changes break compatibility with full offload for non-standard (vocab != 32000) models, even when mmq is disabled (ref: #2160 (comment)).

@JohannesGaessler
Collaborator Author

More concerning is that the mmq changes break compatibility with full offload for non-standard (vocab != 32000) models, even when mmq is disabled

I know what the problem is, I'll fix it soon.

@Dampfinchen

@JohannesGaessler Alright, so I can confirm without a shadow of a doubt that the new mul-mat is faster than cuBLAS at processing regular quants, but slower for k-quants. Here is the data to back it up:

First, let's start with 7B q4_0, 20 layers offloaded on an RTX 2060.

This is with MMQ:

(screenshot: q4_0 with MMQ)

This is without MMQ:

(screenshot: q4_0 without MMQ)

Notice prompt processing speed is a little faster with the new mul mat.

Now, let's switch over to q4_K_M.

First, MMQ enabled:

(screenshot: q4_K_M with MMQ)

This is with regular cublas:

(screenshot: q4_K_M with cuBLAS)

Now you will notice that with k-quants it's a different story: prompt processing is slower than cuBLAS.

Regarding the loop unroll, I've seen similar results to LostRuins.

If you need me to test anything, feel free to ask! I'm happy to help.

bretello added a commit to bretello/llama-cpp-python that referenced this pull request Aug 3, 2023
This also fixes a crash when loading the 70b llama2 model.

This parameter was introduced in ggml-org/llama.cpp#2453 (`0728c5a8`)