Rebalancing Metal threads workload in dot product kernel kernel_mul_mv_f16_f32_l4 #7522
Conversation
Most of the time, kernel_mul_mv_f16_f32_l4 is called to perform 4 FP ops per thread. Added kernel_mul_mv_f16_f32_l4_large, which performs 128 FP ops per thread using 32x fewer threads.
…l4_large: replaced the call to kernel_mul_mv_f16_f32_l4 with kernel_mul_mv_f16_f32_l4_large for vectors larger than 128 elements.
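As a rough sketch of what this rebalancing means for the dispatch size (illustrative numbers only, not measurements from this PR):

// illustrative only: threadgroup counts for a hypothetical weight matrix
// with ne01 = 4096 rows (real values depend on the model)
const int ne01       = 4096;       // rows of src0
const int n_tg_small = ne01;       // kernel_mul_mv_f16_f32_l4:       one threadgroup per row
const int n_tg_large = ne01 / 32;  // kernel_mul_mv_f16_f32_l4_large: 32 rows per threadgroup
// 4096 -> 128 threadgroups launched; each thread does 32x more FP work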
The following command generates garbage:

make -j && ./main -m ./models/mistral-7b-v0.2/ggml-model-fp16.gguf -p "I believe the meaning of life is" -n 64 -s 2 -ngl 99 --temp 0 -t 4

<s> I believe the meaning of life is to work▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅

Here is a possible fix:

diff --git a/ggml-metal.m b/ggml-metal.m
index 3b525071..7a758fb2 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1574,6 +1574,8 @@ static enum ggml_status ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil;
+ bool is_large = false;
+
// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F32:
@@ -1592,6 +1594,7 @@ static enum ggml_status ggml_metal_graph_compute(
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
if (ne01 > 128) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline;
+ is_large = true;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
}
@@ -1784,7 +1787,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
const int64_t ny = (ne11 + nrows - 1)/nrows;
- if (ne01 > 128) {
+ if (is_large) {
[encoder dispatchThreadgroups:MTLSizeMake(ne01/32, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
This change is not clear-cut. For example, for Mistral 7B, where the head size is 128, there is indeed a performance improvement:

make -j llama-bench && ./scripts/compare-commits.sh master pr/7522 -m models/mistral-7b-v0.2/ggml-model-fp16.gguf -t 4 -p 0 -n 0 -pg 512,128 -pg 1024,128 -pg 2048,128

However, for Gemma 2B, where the head size is 256, there is a significant regression:

make -j llama-bench && ./scripts/compare-commits.sh master pr/7522 -m models/gemma-2b/ggml-model-f16.gguf -t 4 -p 0 -n 0 -pg 512,128 -pg 1024,128 -pg 2048,128

For Phi-3, where the head size is 96, there is no difference between this PR and master:

make -j llama-bench && ./scripts/compare-commits.sh master pr/7522 -m models/phi-3-mini-128k-instruct/ggml-model-f16.gguf -t 4 -p 0 -n 0 -pg 512,128 -pg 2048,128 -pg 8192,128 -pg 32768,128

These tests are all with Flash Attention disabled. If we enable Flash Attention, this kernel is never executed and performance is universally better, thanks to the more efficient attention computation via FA.
large dot product kernel selection is now consistent
With flash attention on, kernel_flash_attn_ext_vec_f16_h128 is where ~70% of the time is spent in large-context eval. It is very far from being limited by memory throughput, and average GPU ALU utilization is 28%. Actually, the utilization is 35%, but every 5th GPU core is just sitting idle. The kernel is quite complex, so it would take me some time to understand how to improve GPU utilization, but dealing with the idle GPU cores should be easier and could give up to ~12% perf improvement. Unlike on the MacBook Pro, when benchmarking different LLMs on a Mac Studio I see consistently better performance with Flash Attention compared to this dot product fix.
I guess it's because I'm developing on a Mac Studio, so the kernel performs better in that case. It goes back to the same problem described in #6089 (comment) - I'm not sure what the proper way is to write the Metal kernels so that they perform optimally on all chips and models.
I profiled the flash attention kernel on a Mac Studio, and while it is still significantly faster than the original attention and the attention with this fix, it consistently leaves 20% of GPU cores completely idle on every Mac type I tested, and slightly underutilizes the 80% of GPU cores it is running on. So I no longer see a benefit in including this patch: it is significantly faster than the original attention, but somewhat slower than flash attention in most configs, except for a few corner cases end users would probably not care about. So I'll check how to improve the flash attention Metal kernel performance now.
Yes, we can potentially benefit a lot from optimizing the Flash Attention kernels. Also, one big limitation is that the Head Size = 256 kernels run out of registers, so as of now they are disabled. This means that models like Gemma that use HS = 256 cannot run with Flash Attention enabled.
This is platform dependent; newer GPUs no longer have this constraint (they use cache as registers and just spill to memory). I'll check it too - I found your fix 3 days ago.
Yes, it does work on M2 Ultra, maybe thanks to the new mechanism to spill into memory, but it is very slow.
The spilling mechanism becomes pretty efficient with the M3 GPU.
It differs between models, but issue #1 is that, for most models I tried, the flash attention kernel starts only 32 threadgroups, and threadgroups are statically scheduled to cores, so on Macs with more than 32 GPU cores (so most Macs) some cores are just idle. For a 40-core GPU only 32 of the 40 cores are busy, so the optimization potential is ~40/32 ≈ 1.25x. I'll try to see how to move work from threads to threadgroups to improve GPU core utilization. Issue #2 is that the kernel is long and complex, so ALU utilization on the active GPU cores ends up below 35%, while memory throughput is not the limiting factor. I suspect that with clever optimizations there is up to another ~1.6x improvement potential here; I will try to see if I can do something about it once core utilization is fixed.
This pull request is related to issue #6089. When profiling the Metal implementation for large (16k+) token size prompts, I found that most of the time is spent in the kernel_mul_mv_f16_f32_l4 Metal kernel. During this time GPU ALU utilization is 7%, because the current implementation fires as many threads as there are tokens, and each thread only performs 4 FP operations (plus a reduction), so the GPU is mostly starting and stopping threads. This applies to non-batched generation; with batching, utilization goes up.
This change spawns 32x fewer threads and has each thread perform 32x more operations. This brings GPU ALU utilization to 99% and provides a significant improvement in generation speed for large contexts.
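A minimal sketch of the per-threadgroup structure this implies, written in Metal (illustrative only: the kernel name, signature, scalar loads, single-simdgroup assumption, and 32-rows-per-threadgroup constant are mine, not the actual kernel_mul_mv_f16_f32_l4_large source):

#include <metal_stdlib>
using namespace metal;

// Sketch: each threadgroup of 32 threads (one simdgroup) reduces 32 rows
// instead of one, so 32x fewer threadgroups are launched for the same matrix.
#define NROWS_PER_TG 32

kernel void mul_mv_f16_f32_rows_sketch(
        device const half  * x,     // src0: [nrows x ne00], f16
        device const float * y,     // src1: [ne00], f32
        device       float * dst,   // dst : [nrows], f32
        constant      uint & ne00,  // row length
        constant      uint & nrows,
        uint tgpig [[threadgroup_position_in_grid]],
        uint tiisg [[thread_index_in_simdgroup]]) {
    for (uint r = 0; r < NROWS_PER_TG; ++r) {
        const uint row = tgpig*NROWS_PER_TG + r;
        if (row >= nrows) {
            break;
        }

        device const half * xr = x + row*ne00;

        // each SIMD lane accumulates a strided slice of the dot product
        float sumf = 0.0f;
        for (uint i = tiisg; i < ne00; i += 32) {
            sumf += (float) xr[i] * y[i];
        }

        const float all_sum = simd_sum(sumf);
        if (tiisg == 0) {
            dst[row] = all_sum;
        }
    }
}

With a grid of ne01/32 threadgroups, as in the dispatch change in the diff above, each threadgroup reduces 32 output rows instead of one, which is the rebalancing that lifts the measured ALU utilization.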
For a 16384 context, I measured a 1.3x improvement on M2 Max; for a 96k context, I measured a 1.8x improvement on M2 Max and a 2.4x improvement on M3 Max.
For small contexts (less than 1k) I measure the same or slightly worse performance. To avoid that, the kernel selector line

if (ne01 > 128) {

could be replaced with, e.g.,

if (ne01 > 8192) {
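For illustration, a sketch of how the selector in ggml-metal.m could carry that threshold as a named constant (the constant name and value are hypothetical; the surrounding code mirrors the diff above):

// hypothetical tuning knob: minimum row count before the _large kernel pays off
#define GGML_METAL_MUL_MV_L4_LARGE_MIN_ROWS 8192

} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
    if (ne01 > GGML_METAL_MUL_MV_L4_LARGE_MIN_ROWS) {
        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4_LARGE].pipeline;
        is_large = true;
    } else {
        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
    }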