I reviewed the Discussions and have a new bug or useful enhancement to share.
Feature Description
This is more of a question: is anyone looking (or has anyone looked) at the following long-context performance issue in Metal? I did not find anything in the repo history, but I may have missed it.
When profiling long contexts (starting at about 25K tokens), I found that block processing latency becomes dominated by kernel_mul_mv_f16_f32_l4 (small width, e.g. 128; large height, e.g. 32768 for a context length slightly below 32K; input vector length 32768). This kernel takes ~80% of the total time, and its runtime is dominated by very low utilization of the GPU execution units, caused by launching 32768 threads that each do a very small chunk of work. There is no memory pressure.
So the dispatch code
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
and the kernel could be optimized so that they do not launch a number of threads equal to the vector length, but instead chunk the work differently.
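A minimal CPU sketch in Python of what "chunking the work differently" could mean, assuming the simplest alternative: each threadgroup processes several rows (a hypothetical `rows_per_group` parameter) instead of one, so the grid shrinks from `ne01` threadgroups to `ceil(ne01 / rows_per_group)`. The 32-lane stride and the plain sum standing in for a SIMD-group reduction are assumptions about the kernel's structure, not a copy of it.

```python
import math

# Hypothetical CPU emulation of two ways to partition a mat-vec product
# y = A @ x across Metal-style threadgroups. This is NOT the llama.cpp kernel;
# A is stored row-major with ne01 rows of ne00 elements, x has ne00 elements.
LANES = 32  # assumed SIMD-group width


def threadgroup_body(A, x, y, ne00, ne01, group_id, rows_per_group):
    """Work done by one threadgroup: rows_per_group consecutive rows of A."""
    for i in range(rows_per_group):
        row = group_id * rows_per_group + i
        if row >= ne01:  # the last threadgroup may be only partially filled
            return
        # Each lane strides across the row (lane, lane + 32, lane + 64, ...).
        lane_sums = [0.0] * LANES
        for lane in range(LANES):
            for k in range(lane, ne00, LANES):
                lane_sums[lane] += A[row * ne00 + k] * x[k]
        y[row] = sum(lane_sums)  # stands in for a SIMD-group reduction


def dispatch(A, x, ne00, ne01, rows_per_group):
    """Run ceil(ne01 / rows_per_group) threadgroups; return (y, n_groups)."""
    n_groups = math.ceil(ne01 / rows_per_group)
    y = [0.0] * ne01
    for g in range(n_groups):
        threadgroup_body(A, x, y, ne00, ne01, g, rows_per_group)
    return y, n_groups


if __name__ == "__main__":
    ne00, ne01 = 128, 1000  # small width, "large" height (scaled down for Python)
    A = [((i % 7) - 3) * 0.25 for i in range(ne00 * ne01)]
    x = [(k % 5) * 0.5 for k in range(ne00)]
    y1, g1 = dispatch(A, x, ne00, ne01, rows_per_group=1)
    y8, g8 = dispatch(A, x, ne00, ne01, rows_per_group=8)
    print(g1, g8, y1 == y8)  # 1000 125 True
```

For the shape in this report (ne01 = 32768) and, say, 8 rows per threadgroup, the grid would shrink from 32768 to 4096 threadgroups, each doing 8x more work. Whether that actually raises execution-unit utilization on an M3 Max is something only profiling can confirm.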
By the way, in this matrix-vector f16_f32_l4 kernel, the lines
for (int r1 = 0; r1 < nrows; ++r1) {
    device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
    ..
}
are redundant, as the kernel is always called with nrows == 1, so I replaced them with just
device const float4 * y4 = (device const float4 *) (src1 + im*nb12);
and nothing changed (except that the code became a tiny bit cleaner).
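The simplification is only equivalent when the loop body runs exactly once. A small Python sketch of the address arithmetic shows why: with r1 = 0 the term r1*nb11 vanishes and the offset collapses to im*nb12, while any further iteration reads a different src1 vector. It assumes a contiguous f32 src1 with ggml-style byte strides and assumes nrows equals the number of src1 vectors ne11; neither assumption comes from this report.

```python
# Byte offset at which the kernel's r1 loop reads each src1 vector, assuming a
# contiguous f32 src1 with ggml-style byte strides (nb11 = ne10 * 4,
# nb12 = ne11 * nb11). Mirrors `src1 + r1*nb11 + im*nb12` from the kernel.
F32_SIZE = 4  # bytes per float


def src1_offsets(ne10, ne11, im):
    """Offsets for r1 = 0 .. nrows-1, assuming nrows == ne11 (vectors per batch)."""
    nb11 = ne10 * F32_SIZE  # bytes between consecutive src1 vectors
    nb12 = ne11 * nb11      # bytes between consecutive src1 matrices (index im)
    nrows = ne11
    return [r1 * nb11 + im * nb12 for r1 in range(nrows)]


if __name__ == "__main__":
    # Single vector (batch size 1): only r1 = 0, so the offset is just im*nb12.
    print(src1_offsets(128, 1, 3))  # [1536]
    # Four vectors: each r1 reads a different src1 row, so the loop is needed.
    print(src1_offsets(128, 4, 1))  # [2048, 2560, 3072, 3584]
```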
Motivation
The significant performance drop for long-context prompts in Metal is caused by inefficient thread scheduling; once that is fixed, I expect throughput to degrade less as the context grows.
For example, this is what I measured for one of the common models running on an M3 Max:
context:323, t/s: 7.2
context:2248, t/s: 6.3
context:5908, t/s: 5.1
context:10314, t/s: 4.4
context:15112, t/s: 3.65
context:20556, t/s: 3.1
context:24588, t/s: 3
Possible Implementation
Change kernel_mul_mv_f16_f32_l4, or possibly add a kernel_mul_mv_f16_f32_l4_long_vector variant with different threadgroup and threadsPerThreadgroup counts.
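A hedged Python sketch of the host-side half of this proposal: pick the existing kernel for short rows and the proposed long-vector variant above a cutoff, then compute the grid. The threshold, the rows-per-threadgroup value, and nth0 = 32 / nth1 = 1 for the existing call are all assumptions for illustration; the real values would have to come from profiling.

```python
import math
from typing import NamedTuple, Tuple

# Hypothetical host-side choice between the existing dispatch and a
# "long vector" variant. Threshold and rows-per-group are assumptions.
SIMD_WIDTH = 32             # assumed nth0 for the existing kernel (nth1 = 1)
LONG_ROWS_THRESHOLD = 8192  # hypothetical cutoff; would need tuning by profiling
ROWS_PER_GROUP = 8          # hypothetical rows handled per threadgroup


class Dispatch(NamedTuple):
    kernel: str
    grid: Tuple[int, int, int]   # threadgroups per grid, like MTLSizeMake(...)
    group: Tuple[int, int, int]  # threads per threadgroup


def choose_dispatch(ne01, ny, ne12, ne13):
    if ne01 < LONG_ROWS_THRESHOLD:
        # Same shape as the existing call: one threadgroup per src0 row.
        return Dispatch("kernel_mul_mv_f16_f32_l4",
                        (ne01, ny, ne12 * ne13), (SIMD_WIDTH, 1, 1))
    # Long-vector variant: each threadgroup loops over ROWS_PER_GROUP rows.
    n_groups = math.ceil(ne01 / ROWS_PER_GROUP)
    return Dispatch("kernel_mul_mv_f16_f32_l4_long_vector",
                    (n_groups, ny, ne12 * ne13), (SIMD_WIDTH, 1, 1))


if __name__ == "__main__":
    # Example shapes only (ne12 = 32 is an arbitrary illustration).
    print(choose_dispatch(2248, 1, 32, 1).grid)   # (2248, 1, 32)
    print(choose_dispatch(32768, 1, 32, 1).grid)  # (4096, 1, 32)
```

Keeping threadsPerThreadgroup unchanged and only shrinking the grid is one option; another would be to put several SIMD groups in each threadgroup (one per row). Which layout schedules better is an empirical question.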
Thank you for looking into this! Any help with optimizing the Metal kernels would be appreciated. I don't even know how to use the profiling tools myself, so there may be a lot of room for optimization.
for (int r1 = 0; r1 < nrows; ++r1) {
nrows is 1 only for batch size == 1. It can be larger than 1 when we allow larger batch sizes to use mat-vec kernels: