TL;DR
This PR adds a flash attention (FA) implementation optimized for the Zen4 architecture as part of the quest to improve CPU inference for long contexts (#25, #26).
Limitations
The implementation requires only a subset of the AVX512 instruction set (`AVX512F` and `AVX512DQ`) compared to what Zen4 provides, but I didn't want to have too many variants, so I decided to enable it for Zen4 only.
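For illustration only, a compile-time gate on just the required instruction-set extensions could look like the sketch below. `__AVX512F__` and `__AVX512DQ__` are the standard compiler-defined feature macros; the `FA_ZEN4_ENABLED` name is hypothetical and not taken from this PR.

```cpp
// Minimal sketch: enable the new FA kernel only when the compiler targets
// both AVX512F and AVX512DQ (e.g. with -march=znver4). The macro name is
// hypothetical; the actual gating in this PR may differ.
#if defined(__AVX512F__) && defined(__AVX512DQ__)
#    define FA_ZEN4_ENABLED 1
#else
#    define FA_ZEN4_ENABLED 0
#endif
```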
Performance comparisons

The following graph compares the prompt processing (PP) performance of mainline `llama.cpp` (build: a47667cf - 3650) without (green symbols) and with (blue symbols) FA to the PP performance in this repository for `Q4_K_S`-quantized LLaMA-3.1-8B running on a Ryzen-7950X CPU.
We observe that the original FA implementation results in a significant performance degradation, both in mainline `llama.cpp` and here, and the effect is much stronger for the version here. This is due to the `K*Q` and `V*softmax(K*Q)` matrix multiplications being much faster in this repository thanks to `iqk_mul_mat`, so the performance hit is larger when they are replaced with the original `llama.cpp` FA CPU kernel. The new FA implementation improves performance, and the improvement increases with context length, reaching about 24% at 32k tokens.
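To make the cost argument concrete, here is a minimal scalar reference of the per-query attention computation (my own simplified fp32, row-major, unmasked sketch; the real kernels operate on quantized, vectorized data). Steps (1) and (3) are the `K*Q` and `V*softmax(K*Q)` products whose cost grows linearly with the number of cached tokens, which is why a slow FA kernel hurts more the faster these matrix multiplications otherwise are.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Naive single-head attention: O = V * softmax(scale * K*Q), one query row at a time.
// Illustrative layout and names only; not the kernel added in this PR.
void naive_attention(const float* Q, const float* K, const float* V, float* O,
                     int n_q, int n_kv, int d) {
    const float scale = 1.0f / std::sqrt(float(d));
    std::vector<float> s(n_kv);                       // one row of softmax(K*Q)
    for (int iq = 0; iq < n_q; ++iq) {
        // (1) K*Q: n_kv dot products of length d over the whole KV cache
        float smax = -INFINITY;
        for (int ik = 0; ik < n_kv; ++ik) {
            float sum = 0.0f;
            for (int j = 0; j < d; ++j) sum += K[ik*d + j] * Q[iq*d + j];
            s[ik] = sum * scale;
            smax  = std::max(smax, s[ik]);
        }
        // (2) softmax over the n_kv scores
        float norm = 0.0f;
        for (int ik = 0; ik < n_kv; ++ik) { s[ik] = std::exp(s[ik] - smax); norm += s[ik]; }
        // (3) V*softmax(K*Q): a second full pass over the cache
        for (int j = 0; j < d; ++j) O[iq*d + j] = 0.0f;
        for (int ik = 0; ik < n_kv; ++ik) {
            const float w = s[ik] / norm;
            for (int j = 0; j < d; ++j) O[iq*d + j] += w * V[ik*d + j];
        }
    }
}
```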
The next graph shows results for `Q4_K_S`-quantized Gemma-2-2b. Symbol colors are the same as above.

In this case the original FA kernel improves performance in mainline `llama.cpp`. The difference in behavior compared to LLaMA-3.1-8B is easily explained by the fact that the Gemma-2 series of models use "soft-capping" in their attention layers, where `softcap(x) = c * tanh(x/c)` (`c` is a model-defined constant). This is implemented as 3 different operations in `llama.cpp`. When FA is enabled, these 3 operations, along with `softmax`, are fused into a single kernel, and this results in an improvement of mainline `llama.cpp` performance even for short contexts. But when the original FA kernel is used in our version, where soft-capping is already handled by a dedicated fused operation, we get a massive drop in performance, just like in the LLaMA-3.1-8B case above. The new implementation in this PR is much better and performance improves again, reaching 11% at 8k tokens, which is the maximum training context length of Gemma-2-2b.
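As a scalar sketch of the two evaluation strategies discussed above (function names and the plain-loop form are mine; the real kernels work on vectorized, tiled data), the unfused path streams the `K*Q` scores through memory once per operation, while the fused path applies soft-capping inside the softmax pass:

```cpp
#include <algorithm>
#include <cmath>

// Unfused path: softcap(x) = c * tanh(x/c) as three separate elementwise passes
// (scale by 1/c, tanh, scale by c), followed by softmax.
void softcap_then_softmax(float* s, int n, float c) {
    for (int i = 0; i < n; ++i) s[i] /= c;              // op 1: scale by 1/c
    for (int i = 0; i < n; ++i) s[i] = std::tanh(s[i]); // op 2: tanh
    for (int i = 0; i < n; ++i) s[i] *= c;              // op 3: scale by c
    const float smax = *std::max_element(s, s + n);
    float norm = 0.0f;
    for (int i = 0; i < n; ++i) { s[i] = std::exp(s[i] - smax); norm += s[i]; }
    for (int i = 0; i < n; ++i) s[i] /= norm;
}

// Fused path: soft-capping folded into the softmax pass, touching the scores once.
void fused_softcap_softmax(float* s, int n, float c) {
    float smax = -INFINITY;
    for (int i = 0; i < n; ++i) { s[i] = c * std::tanh(s[i] / c); smax = std::max(smax, s[i]); }
    float norm = 0.0f;
    for (int i = 0; i < n; ++i) { s[i] = std::exp(s[i] - smax); norm += s[i]; }
    for (int i = 0; i < n; ++i) s[i] /= norm;
}
```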