Replies: 6 comments 5 replies
-
Have you tried these measurements with the latest llamafile sources? There's a variety of improvements to thread synchronization. For example, here's a better memory barrier that's more on par with what GNU OpenMP does:

```c
void ggml_barrier(const struct ggml_compute_params * params) {
    if (params->shared->n_threads == 1)
        return;
    int n = params->shared->n_threads;
    atomic_int  * count = &params->shared->n_barrier;
    atomic_uint * phase = &params->shared->n_barrier_passed[params->ith].i;
    unsigned i = atomic_load_explicit(phase, memory_order_relaxed);
    if (atomic_fetch_add_explicit(count, 1, memory_order_acq_rel) == n - 1) {
        // last thread to arrive: reset the counter and advance every thread's phase
        atomic_store_explicit(count, 0, memory_order_relaxed);
        for (int j = 0; j < n; ++j)
            atomic_store_explicit(&params->shared->n_barrier_passed[j].i,
                                  i + 1, memory_order_relaxed);
        atomic_thread_fence(memory_order_release);
    } else {
        // everyone else spins until the last thread bumps the phase
        while (atomic_load_explicit(phase, memory_order_relaxed) == i)
            pthread_pause_np();
        atomic_thread_fence(memory_order_acquire);
    }
}
```

No-op tensor operations can also be skipped in the graph compute loop:

```c
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
    struct ggml_tensor * node = cgraph->nodes[node_n];
    if (ggml_is_noop(node->op)) // [jart]
        continue;
    // ...
```

Assuming you have this defined:

```c
static bool ggml_is_noop(enum ggml_op op) { // [jart]
    switch (op) {
        case GGML_OP_NONE:
        case GGML_OP_PERMUTE:
        case GGML_OP_RESHAPE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_VIEW:
            return true;
        default:
            return false;
    }
}
```

llama.cpp also likes to spawn a thread for every token when predicting. You can make threads spawn/join 10x faster with this:

Is this all something that'd interest you? I can easily send a PR adding it to your repo if you don't care about things like MSVC.
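The spawn/join speedup referenced above is not reproduced in this thread. As a rough illustration of the general idea only (a persistent worker pool that is handed work for each token instead of creating and joining threads every time), here is a minimal sketch; all names are hypothetical and this is not the llamafile implementation:

```c
#include <pthread.h>
#include <stdatomic.h>
#include <stdlib.h>

typedef void (*pool_task)(int ith, int nth, void * arg);

struct pool;
struct pool_worker { struct pool * pool; int ith; pthread_t tid; };

struct pool {
    struct pool_worker * workers;
    int                  nth;
    atomic_int           epoch;  // bumped once per submitted task
    atomic_int           done;   // helper threads finished with the current epoch
    atomic_int           stop;
    pool_task            task;
    void               * arg;
};

static void * pool_main(void * p) {
    struct pool_worker * w = p;
    struct pool * q = w->pool;
    int seen = 0;
    for (;;) {
        // wait for new work (a real implementation would back off or futex-wait)
        while (atomic_load_explicit(&q->epoch, memory_order_acquire) == seen)
            if (atomic_load_explicit(&q->stop, memory_order_relaxed))
                return NULL;
        seen = atomic_load_explicit(&q->epoch, memory_order_acquire);
        q->task(w->ith, q->nth, q->arg);
        atomic_fetch_add_explicit(&q->done, 1, memory_order_release);
    }
}

static struct pool * pool_start(int nth) {
    struct pool * q = calloc(1, sizeof *q);
    q->nth     = nth;
    q->workers = calloc((size_t)nth, sizeof *q->workers);
    for (int i = 1; i < nth; ++i) {          // worker 0 is the calling thread
        q->workers[i].pool = q;
        q->workers[i].ith  = i;
        pthread_create(&q->workers[i].tid, NULL, pool_main, &q->workers[i]);
    }
    return q;
}

// Run one task on all workers and wait; the caller participates as worker 0.
static void pool_run(struct pool * q, pool_task task, void * arg) {
    q->task = task;
    q->arg  = arg;
    atomic_store_explicit(&q->done, 0, memory_order_relaxed);
    atomic_fetch_add_explicit(&q->epoch, 1, memory_order_release);
    task(0, q->nth, arg);
    while (atomic_load_explicit(&q->done, memory_order_acquire) < q->nth - 1)
        ; // spin until the nth-1 helpers have finished this epoch
}
```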
-
Hey @jart, thanks for the comments!
No, I'm working with my
Ha, you had already done that! I didn't check
I don't care about MSVC, so sure. There is the MIT vs. Apache-2.0 issue, but we can sort that out.
-
I did try a few things on this branch, but nothing is really working. The branch is just exploratory, absolutely not production ready, and
On the bright side, PR #27 merges "soft-capping" with soft-max. For large prompts, this leads to a significant performance boost for Gemma-2 models. At 32k tokens with Gemma-2-2b, the performance gap between the GPU with flash attention and the Ryzen-7950X CPU is now "only" a factor of 45 (instead of the 53X in the above graph).
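For context, soft-capping in Gemma-2 squashes the attention logits with a scaled tanh before the soft-max. A minimal sketch of what a fused soft-cap + soft-max pass over one row of `K*Q` could look like (illustrative only, not the actual PR #27 code):

```c
#include <math.h>

// Apply soft-capping (cap * tanh(x / cap)) and soft-max to one row of K*Q in a
// single pass, so the capped logits are never written out separately.
static void softcap_softmax_row(float * x, int n, float cap) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        x[i] = cap * tanhf(x[i] / cap);   // soft-capping
        if (x[i] > max) max = x[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        x[i] = expf(x[i] - max);          // numerically stable soft-max
        sum += x[i];
    }
    float inv = 1.0f / sum;
    for (int i = 0; i < n; ++i) x[i] *= inv;
}
```

Doing the two together avoids a separate pass over the full `K*Q` row just for the capping, which is where the large-prompt win comes from.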
-
OK, I have progress on this branch. Extremely hacky and
This graph shows the current status. The y-axis is tokens per second on my Ryzen-7950X CPU, the x-axis is context size (logarithmic scale). Black symbols show the performance in this repository, green is mainline `llama.cpp`.
My guess is that there is still a bottleneck at 32k tokens. Based on the FA to no-FA relative performance increase up to 16k tokens, I would expect a performance gain above 30% at 32k tokens instead of the 23% we currently get.
-
And here is how the relative CPU vs GPU performance graph changes with the new CPU flash attention implementation. The FA curve is basically flat now beyond 1000 tokens, except at 32k, where I suspect a bottleneck that I have not found.
-
There has been progress since I last wrote here, with PR #172 being the latest contribution to improving CPU prompt processing speed. The following graph is for LLaMA-3.1-8B-Instruct quantized to
It is also interesting to look at the performance relative to a GPU. I'm using an RTX-4080 GPU with the same model and FA enabled. Compared to earlier plots in this thread, I have changed the plot to show the ratio of GPU to CPU prompt processing speed and have restricted the prompt length to
-
Back in the day when open source / open weight LLMs had a very limited context window, one of the most desired features among LLM enthusiasts was a larger context window. People came up with all sorts of modifications to the RoPE operation, used (LoRA) fine tuning, etc., to increase the context window beyond the maximum context used during model training. Today we have open source / open weight models that can handle much longer contexts. E.g., LLaMA-3.1 goes up to 128k tokens, which is probably more than what one can handle with consumer grade hardware for "Inference at the Edge" (and I find it kind of funny to see the many issues opened in the `llama.cpp` repository because users did not limit the maximum context length when running `llama.cpp`, and correspondingly the model would not load because the KV-cache required for 128k tokens does not fit into their <= 24 GB VRAM).

But how well is the large context length being handled?
On the GPU, `llama.cpp` has an implementation of Flash Attention (FA), which improves prompt processing speeds for long contexts quite a bit (see the graph below). But, as mentioned, one cannot take advantage of the full context offered by LLaMA-3.1 - I, for instance, with the paltry 16 GB of VRAM on the RTX-4080 that I have at my disposal, cannot go beyond 32k tokens even for 8B LLaMA-3.1. `llama.cpp` has a FA implementation for the CPU as well, so let's see how well this works; this gives the following results on my Ryzen-7950X CPU:

Oops. FA is slower than no-FA. This is mainline `llama.cpp`. What about the version in this repository, where we have much improved CPU prompt processing speed? We get this:

Oops. Even worse - FA is 26% slower. Why? Because when FA is turned on, the `KQ = K * Q` and `KQV = V * KQ` matrix multiplications are handled internally within the FA kernel, so they no longer take advantage of the optimized version provided by `iqk_mul_mat`, and performance suffers more.

So, the short answer is: no luck with the current `llama.cpp` version using long contexts on the CPU (unless of course one is very patient).

Anyhow, how well does the CPU do compared to the GPU? The following graph shows the ratio of tokens/second on the CPU to tokens/second on the GPU as a function of prompt length. The CPU is Ryzen-7950X, the GPU is RTX-4080. The black symbols/line is the ratio without GPU Flash Attention, the red circles/line is with FA enabled on the GPU (but not on the CPU).
The behavior of the curves is interesting for relatively short prompts (say, up to 32 tokens, which is the range of interest for speculative sampling or batch processing), but here we are interested in the portion beyond 500 tokens. Without FA on the GPU, the CPU does improve relative to the GPU with increasing context length, becoming only 16X slower at 32k tokens ("only" considering that we are comparing a $500 previous generation Ryzen to the second fastest consumer grade GPU currently on the market). But when FA is turned on, the performance gap keeps increasing with increasing context length, reaching about 53X slower than the GPU at 32k tokens (and hence the GPU with FA is 3.1X faster compared to no-FA at 32k tokens).
Clearly it would be useful if we could make the CPU go faster for large contexts.
Here is a quick summary of how the computation time is spent on the CPU when processing a prompt of 32k tokens (using LLaMA-3.1-8B quantized to `Q4_K_S`). For comparison, I have added in the 4th column the fraction of time spent for the various operations in the more "normal" case of processing 512 tokens.

So, basically the entire time is spent doing matrix multiplications and `SOFT_MAX` on the `K*Q` product in the self-attention part (but according to the measured wall time the operation took 495 seconds, while the total of all operations works out to 472 seconds, so there is possibly ~5% spent on thread synchronization). `SOFT_MAX`, which takes less than 1% of the processing time for 512 tokens, increases to 17.8% for a context of 32k. But why is `SOFT_MAX` taking so long? Didn't Justine Tunney just recently contribute a vectorized `expf` implementation to `llama.cpp`, which should make `SOFT_MAX` go faster? Well, the vectorized `expf` is being used here, but we also need to load from/store back to RAM 2080 GiB while computing `SOFT_MAX`. Given the 84.1 seconds taken by `SOFT_MAX`, this works out to about 25 GiB/s, which is pretty close to the 30 GiB/s the Ryzen-7950X CPU can do in the best case scenario when copying data from here to there.
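Spelled out, that bandwidth estimate is just the traffic divided by the measured time:

$$\frac{2080\ \mathrm{GiB}}{84.1\ \mathrm{s}} \approx 24.7\ \mathrm{GiB/s}$$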
What about the matrix multiplications? The next table shows the total time in µs and the fraction of the total matrix multiplication time taken by the various matrix multiplications (note: this is the sum over all layers):

So, close to 60% of the matrix multiplication time is spent on `kq = K*Q` and `kqv = V * softmax(K*Q)`. Combining 60% of 80% with 17.8% for `SOFT_MAX`, we have close to 2/3 of the total time being spent on `K*Q`, `softmax(K*Q)` and `V*softmax(K*Q)`. Interestingly enough, the `kq` and `kqv` matrix multiplications require the exact same amount of floating point operations - 142.94 TFLOP for the 32k context we are looking at. And yet, `kqv` is computed about 35% faster - why? Again, it is a matter of storing data to RAM: `kq` is 2080 GiB (no, we don't keep it all, processing is done in batches), so this works out to 16.1 GiB/s written to memory while computing `kq`. On the other hand, `kqv` is "just" 16 GiB, so the matrix multiplication function is storing results at a rate of 0.17 GiB/s - far from being throttled by memory bandwidth. We also see from the data that we get about 1.5 TFLOP/s when computing `kqv`, and about 1.1 TFLOP/s for `kq`. I happen to know that in a synthetic benchmark with just matrix multiplications and the result fitting into L2 cache, we get about 2 TFLOP/s with the `iqk_mul_mat` implementation for `fp32`.
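As a consistency check, those write rates follow from the numbers quoted above (FLOPs divided by throughput gives the time, output size divided by time gives the rate):

$$t_{kq} \approx \frac{142.94\ \mathrm{TFLOP}}{1.1\ \mathrm{TFLOP/s}} \approx 130\ \mathrm{s}, \qquad \frac{2080\ \mathrm{GiB}}{130\ \mathrm{s}} \approx 16\ \mathrm{GiB/s}$$

$$t_{kqv} \approx \frac{142.94\ \mathrm{TFLOP}}{1.5\ \mathrm{TFLOP/s}} \approx 95\ \mathrm{s}, \qquad \frac{16\ \mathrm{GiB}}{95\ \mathrm{s}} \approx 0.17\ \mathrm{GiB/s}$$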
Based on this, here are some angles of attack for improving the CPU performance for large prompts:

- Bring `kqv` speed closer to the 2 TFLOP/s we know is achievable
- Improve `kq` performance by better interleaving computation with memory writes. We are at ~16 GiB/s, and 30 GiB/s is the limit on this CPU
- Fuse `kq` and `softmax(kq)` into a single operation (see the sketch after this list). As I don't want to go implement this new operation on all back-ends, the fusing should be done on-the-fly while evaluating the computation graph on the CPU. This will eliminate writing `kq` to RAM, so it has the potential of shaving off at least 15% of the time
- Fuse `K*Q`, `softmax(K*Q)` and `V*softmax(K*Q)` into a single operation, i.e., re-discover Flash Attention :-) As the experience with the `llama.cpp` CPU implementation shows, it is not just a matter of not storing intermediate results in RAM. One still needs to go as fast as possible with the matrix multiplications to actually get a performance improvement from this.
- Work with lower precision than `fp32` (e.g., quantized data) - we get in the 2.5 to 3 TFLOP/s range with the implementation in `iqk_mul_mat`, but I need to look in more detail into the associated accuracy loss. In addition, if `V` is quantized, `softmax(K*Q)` must be quantized as well, which may be too costly unless fused into the `softmax(K*Q)` operation.
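To make the third item concrete, here is a rough sketch of what on-the-fly fusion could look like for a single query row: compute the `K*Q` logits for that row, then immediately run the soft-max over them while they are still in cache, so the full `K*Q` matrix is never written to RAM. This is illustrative only (plain C, no SIMD, names are made up), not the actual implementation:

```c
#include <math.h>

// One query row at a time: logits[j] = dot(q, k_j) * scale, then soft-max in place.
// kq_row is a small scratch buffer (n_kv floats) that stays in cache; the full
// n_ctx x n_ctx K*Q matrix is never materialized in RAM.
static void fused_kq_softmax_row(const float * q,        // [head_dim]
                                 const float * k,        // [n_kv][head_dim]
                                 float       * kq_row,   // [n_kv] scratch
                                 int n_kv, int head_dim, float scale) {
    float max = -INFINITY;
    for (int j = 0; j < n_kv; ++j) {
        float sum = 0.0f;
        for (int d = 0; d < head_dim; ++d)
            sum += q[d] * k[(size_t)j*head_dim + d];
        kq_row[j] = sum * scale;
        if (kq_row[j] > max) max = kq_row[j];
    }
    float denom = 0.0f;
    for (int j = 0; j < n_kv; ++j) {
        kq_row[j] = expf(kq_row[j] - max);   // numerically stable soft-max
        denom += kq_row[j];
    }
    float inv = 1.0f / denom;
    for (int j = 0; j < n_kv; ++j)
        kq_row[j] *= inv;
    // kq_row can now be fed straight into the V*softmax(K*Q) multiplication
}
```

A real implementation would of course work on tiles of several rows and use the optimized `iqk_mul_mat` kernels for the dot products, but the point is the same: the soft-max input never leaves cache.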