ggml: aarch64: implement mmla kernels for q8_0_q8_0, q4_0_q8_0 and q4_1_q8_1 quantized gemm #4966
Conversation
I have tested this PR with a few Llama2 models using different prompt sizes, compared the gemm output from the mmla kernel and the default dot kernels, and confirmed they matched. Please let me know if there are any unit tests or perplexity tests I need to run for this PR. Thank you!
Interesting work - I was not familiar with these instructions. Gaining performance for CPU-based prompt processing is of significant interest. Looks like you want to process 2 rows at a time with a single kernel call. However, I feel this is implemented in a very convoluted way, with quite a lot of duplication in the matrix multiplication logic. I could be missing something though. Can you try to fit this into the existing dot kernels? Also run a perplexity test:

# get wikitext test data
./scripts/get-wikitext-2.sh
unzip wikitext-2-raw-v1.zip

# run test (takes a while)
./perplexity -m some-model-q4_0.gguf -f ./wikitext-2-raw/wiki.test.raw
@ggerganov, thanks for the feedback. Yes, your understanding is correct; I'm processing two rows and two columns at a time (the SMMLA instruction operates as 2x8 * 8x2 --> 2x2). I came up with this logic while trying to understand the current algorithm and keep the changes as isolated as possible :) Sure, I will check how best I can merge this logic with the dot kernel matmul loop.
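For readers unfamiliar with the instruction, the standalone sketch below (not the PR's kernel, just a minimal illustration) shows what a single SMMLA does via the vmmlaq_s32 intrinsic: it multiplies a 2x8 int8 matrix by the transpose of a second 2x8 int8 matrix and accumulates the resulting 2x2 tile of int32 values. Building it requires an aarch64 toolchain and target with the int8 matmul extension (something like -march=armv8.2-a+i8mm; the exact flag is an assumption).

#include <stdio.h>
#if defined(__ARM_FEATURE_MATMUL_INT8)
#include <arm_neon.h>
#endif

int main(void) {
#if defined(__ARM_FEATURE_MATMUL_INT8)
    // a: 2x8 row-major int8 matrix; b: 2x8 row-major int8 matrix, used as b^T.
    const int8_t a_data[16] = { 1, 2, 3, 4, 5, 6, 7, 8,   1, 1, 1, 1, 1, 1, 1, 1 };
    const int8_t b_data[16] = { 1, 0, 0, 0, 0, 0, 0, 0,   0, 1, 0, 0, 0, 0, 0, 0 };

    const int8x16_t a = vld1q_s8(a_data);
    const int8x16_t b = vld1q_s8(b_data);

    int32x4_t acc = vdupq_n_s32(0);   // 2x2 accumulator laid out as {c00, c01, c10, c11}
    acc = vmmlaq_s32(acc, a, b);      // acc += a(2x8) * b^T(8x2)

    int32_t c[4];
    vst1q_s32(c, acc);
    printf("c00=%d c01=%d c10=%d c11=%d\n", c[0], c[1], c[2], c[3]);   // prints 1 2 1 1
#else
    printf("__ARM_FEATURE_MATMUL_INT8 is not available on this target\n");
#endif
    return 0;
}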
Hi @ggerganov, I have extended the existing dot kernel interface as discussed above.
Some of the unit tests are being invoked only from this CI run: https://github.com/ggerganov/llama.cpp/actions/runs/7618043882/job/20748830174?pr=4966
I updated the PR for the Windows failures but haven't tested it on a Windows machine yet. It would be great if it could be tested in the CI runs; otherwise I will pick up the local Windows testing effort.
Thanks, looks like there are more places to take care of on Windows. I will check.
Looks like MSVC doesn't support VLAs either. I have fixed all the Windows failures and tested on both Windows and Linux. Next I'm looking at the error below from the swift build.
The swift CI tests use a pinned version of ggml, so they fail every time there are changes to the ggml interface.
Thank you! I see those builds are using the llama.cpp package, not building from source.
Updated the PR to fix the Windows builds and ran the unit tests on Ubuntu and Windows. It is ready for final review and CI.
Perplexity test results:
Does this perplexity value match well with what you get on master? For how many threads do you observe optimal performance with these kernels?
Yes, the perplexity matches the value on master without these changes (logs below).
I haven't collected data for different thread configs yet, but in general I see these gemm kernels scale with the number of threads, though not linearly.
The expectation is that for prompt processing the speed should always increase with the number of threads, while for text generation there should be an optimal number of threads after which performance starts to degrade (see lines 222 to 223 in 6fea843). These can be configured via the corresponding thread-count parameters. I just now realized that the new kernels are used only for prompt processing, so that's fine.
ggml-quants.h
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void ** restrict vx, const void ** restrict vy, const int nrc);
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void ** restrict vx, const void ** restrict vy, const int nrc);
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void ** restrict vx, const void ** restrict vy, const int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void ** restrict vx, const void ** restrict vy, const int nrc);
I'm not convinced this API is desirable - it requires preparing arrays of pointers, which seems quite cumbersome.
Normally, linear algebra libraries utilize an API of a pointer, number of elements and stride (in bytes or in elements). So I'm thinking that we should probably switch to something like:
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc);
Note that I'm mostly thinking out loud - not yet sure what is the best way.
It's a big change, so we have to consider the options to make this less intrusive.
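To make the pointer + stride convention concrete, here is a small self-contained toy (plain floats and invented names, not ggml's actual kernel or types). It only illustrates the calling convention - a buffer pointer, an element count, a byte stride, and an nrc count - not the 2x2 tile semantics of the real mmla kernels.

#include <stdio.h>
#include <stddef.h>

// Computes nrc dot products: row i of x (stride bx bytes) with row i of y (stride by bytes).
static void vec_dot_f32(int n, float * s, const void * vx, size_t bx,
                        const void * vy, size_t by, int nrc) {
    for (int i = 0; i < nrc; ++i) {
        const float * x = (const float *)((const char *) vx + i*bx);
        const float * y = (const float *)((const char *) vy + i*by);
        float sum = 0.0f;
        for (int k = 0; k < n; ++k) {
            sum += x[k]*y[k];
        }
        s[i] = sum;
    }
}

int main(void) {
    // Two contiguous rows of 4 elements each: the row stride is simply 4*sizeof(float).
    const float a[2][4] = { {1, 2, 3, 4}, {5, 6, 7, 8} };
    const float b[2][4] = { {1, 1, 1, 1}, {2, 2, 2, 2} };
    float s[2];

    vec_dot_f32(4, s, a, sizeof(a[0]), b, sizeof(b[0]), /*nrc=*/2);
    printf("s[0] = %f, s[1] = %f\n", s[0], s[1]);   // prints 10 and 52
    return 0;
}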
I agree. As I mentioned earlier, I tried to fit it into the existing interface itself, but changed to arrays mainly to account for the stride. If it's better to add a few more args than arrays, how about we define a tensor attribute structure and pass it across instead of adding one argument for each attribute? That way we can extend it in the future for any new functionality. For now the structure could just have the number of elements, the stride, and the format type.
Hm, adding a new tensor attribute structure would again introduce a lot of boilerplate around calling the dot functions. Adding extra arguments is better in this regard, because we already have the strides from the struct ggml_tensor.
Currently, ggml stores the strides in number of bytes. So the numbers in ggml_tensor->nb are strides in bytes. The dot functions should also accept the row strides in bytes, for consistency.
In the future, we will transition to storing the strides in number of elements: ggerganov/ggml#623. But this is not important for now.
Hi @ggerganov, I have updated the PR. Please review and let me know if it can be improved further, especially around the stride calculations. I was able to use the ggml_tensor strides (nb) for the src0 and dst tensors, but I had to derive the src1_col stride following the logic used for the offset calculations.
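For readers following along, here is a small standalone sketch of the kind of byte-stride arithmetic involved. The block layout below only mimics q8_0 (32 int8 quants plus one fp16 scale, represented here by a uint16_t stand-in); it is not ggml's actual struct, and the numbers are illustrative.

#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

#define QK8_0 32

typedef struct {
    uint16_t d;          // scale, stored as fp16 in ggml; uint16_t here as a stand-in
    int8_t   qs[QK8_0];  // 32 quantized values
} toy_block_q8_0;

int main(void) {
    const int ne = 4096;  // elements per row (e.g. the hidden dimension)

    // A row of ne elements is ne/QK8_0 blocks; the byte stride between two
    // consecutive rows is therefore the number of blocks times the block size.
    const size_t row_stride_bytes = (size_t)(ne/QK8_0)*sizeof(toy_block_q8_0);

    printf("row stride: %zu bytes (%d blocks of %zu bytes)\n",
           row_stride_bytes, ne/QK8_0, sizeof(toy_block_q8_0));
    return 0;
}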
ggml-quants.c
vst1_f32(s, vget_low_f32(sumv2));
vst1_f32(s + 16, vget_high_f32(sumv2));
I'm wondering if we should add a stride argument for s too. This 16 offset is very obscure, but on the other hand the function signature would become a bit overloaded.
It's probably better to add it.
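A possible shape of that change, sketched below as a standalone toy (the parameter name bs is an assumption, counted in float elements like the current +16 offset): with an explicit destination stride, the two rows of the 2x2 result tile are stored bs elements apart and the magic 16 disappears.

#include <stdio.h>
#include <arm_neon.h>

int main(void) {
    // Stand-in for the kernel's 2x2 result tile after scaling: {c00, c01, c10, c11}.
    const float vals[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    const float32x4_t sumv2 = vld1q_f32(vals);

    float out[2][16] = { { 0 } };   // pretend the destination has 16 floats per row
    float *s = &out[0][0];
    const size_t bs = 16;           // destination row stride in float elements (assumed)

    vst1_f32(s,      vget_low_f32 (sumv2));   // row 0 of the tile -> out[0][0..1]
    vst1_f32(s + bs, vget_high_f32(sumv2));   // row 1 of the tile -> out[1][0..1]

    printf("%g %g / %g %g\n", out[0][0], out[0][1], out[1][0], out[1][1]);   // 1 2 / 3 4
    return 0;
}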
Hi @ggerganov, I have addressed all the comments.
I think this should be good to merge. I want to take some time to do some AWS Graviton tests first and confirm the results. If anyone else gives this a try, please post some feedback as well.
@ggerganov or anyone trying this PR, please make sure you use instances from the AWS Graviton3 family, c7g/m7g/r7g (Graviton2 doesn't support the MMLA instructions).
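On Linux/aarch64, one quick way to confirm that the CPU actually exposes the int8 matrix-multiply extension is to query the hardware capabilities, as in the sketch below. This is a convenience check, not part of the PR, and it assumes the kernel headers define HWCAP2_I8MM.

#include <stdio.h>
#if defined(__aarch64__) && defined(__linux__)
#include <sys/auxv.h>
#include <asm/hwcap.h>
#endif

int main(void) {
#if defined(__aarch64__) && defined(__linux__) && defined(HWCAP2_I8MM)
    // Graviton3 reports the i8mm capability; Graviton2 does not.
    const unsigned long hwcap2 = getauxval(AT_HWCAP2);
    printf("i8mm (int8 MMLA) supported: %s\n", (hwcap2 & HWCAP2_I8MM) ? "yes" : "no");
#else
    printf("this check is for Linux/aarch64 only (or HWCAP2_I8MM is not defined)\n");
#endif
    return 0;
}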
@ggerganov I tried this PR on an AWS Graviton3 instance. I can confirm that I observed a speedup similar to the one mentioned by the author of this patch. Please find the tokens/s numbers below.
armv8.2-a and above support MMLA instructions that have higher throughput than DOT. This commit adds an mmla kernel for q8_0_q8_0 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8". On AWS Graviton3 processors this kernel resulted in up to a 1.5x improvement in prompt evaluation throughput compared to the default sdot kernel.
armv8.2-a and above support MMLA instructions that have higher throughput than DOT. This commit adds an mmla kernel for q4_0_q8_0 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8". On AWS Graviton3 processors this kernel resulted in up to a 1.5x improvement in prompt evaluation throughput compared to the default sdot kernel.
armv8.2-a and above support MMLA instructions that have higher throughput than DOT. This commit adds an mmla kernel for q4_1_q8_1 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8". On AWS Graviton3 processors this kernel resulted in up to a 1.5x improvement in prompt evaluation throughput compared to the default sdot kernel.
* ggml: aarch64: implement smmla kernel for q8_0_q8_0 quantized gemm (commit message as above)
* ggml: aarch64: implement smmla kernel for q4_0_q8_0 quantized gemm (commit message as above)
* ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm (commit message as above)
* ggml: update unit tests for the new vec_dot interface
* llama.cpp: add MATMUL_INT8 capability to system_info
armv8.2-a and above support MMLA instructions that have better throughput than DOT. This PR adds support for mmla kernels for the q8_0_q8_0, q4_0_q8_0, and q4_1_q8_1 quantized gemm routines.
The feature is enabled if the platform supports __ARM_FEATURE_MATMUL_INT8.
On AWS Graviton3 processors these kernels resulted in up to a 1.5x improvement in prompt evaluation throughput compared to the default sdot kernel.
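As a quick way to verify that a given build actually picked up these kernels, the PR adds a MATMUL_INT8 entry to the system_info string. The sketch below is an assumed usage example, not part of the PR: it assumes llama.h is on the include path and the program is linked against llama.cpp, and the exact formatting of the returned string may differ.

#include <stdio.h>
#include "llama.h"

int main(void) {
    // Prints the capability string, which should contain "MATMUL_INT8 = 1"
    // when the build enables the new kernels (exact formatting may vary).
    printf("%s\n", llama_print_system_info());
    return 0;
}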