Zen4 Flash Attention #32

Merged: 4 commits, Sep 1, 2024
Conversation

@ikawrakow ikawrakow (Owner) commented Sep 1, 2024

TL;DR

This PR adds a flash attention (FA) implementation optimized for the Zen4 architecture as part of the quest to improve CPU inference for long contexts (#25, #26).

Limitations

  • It is Zen4-only for now. Strictly speaking, the implementation requires only a small subset of the AVX512 specification (just AVX512F and AVX512DQ) compared to what Zen4 provides, but I didn't want to have too many variants, so I decided to enable it for Zen4 only (a possible compile-time gate is sketched after this list).
  • It is not implemented for ALiBi or unmasked attention. It would be trivial to add these, but I didn't want to clutter the implementation with branches that are mostly irrelevant.
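
As a point of reference, here is a minimal sketch of what such a compile-time gate could look like using only the standard compiler feature macros; the actual build flags and macro names used in this repository may differ:

```cpp
// Hypothetical compile-time gate: enable the new kernel only when the AVX512
// subsets it actually needs (AVX512F + AVX512DQ) are available. The macro
// IQK_FA_AVX512 is an illustrative name, not the one used in this repository.
#if defined(__AVX512F__) && defined(__AVX512DQ__)
#define IQK_FA_AVX512 1
#else
#define IQK_FA_AVX512 0
#endif
```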

Performance comparisons

The following graph compares the prompt processing (PP) performance of mainline llama.cpp (build a47667cf - 3650) without FA (green symbols) and with FA (blue symbols) to the PP performance of this repository for Q4_K_S-quantized LLaMA-3.1-8B running on a Ryzen-7950X CPU, where for this repository

  • Black symbols are without FA
  • Brown symbols are with FA inherited from llama.cpp
  • Magenta symbols are with the new FA implementation in this PR

[Figure: PP performance vs. context length for LLaMA-3.1-8B on the Ryzen-7950X]

We observe that the original FA implementation results in a significant performance degradation in mainline llama.cpp and an even larger one here. The stronger effect here is due to the K*Q and V*softmax(K*Q) matrix multiplications being much faster in this repository thanks to iqk_mul_mat, so the performance hit is larger when they are replaced with the original llama.cpp FA CPU kernel. The new FA implementation improves performance, and the improvement increases with context length, reaching about 24% at 32k tokens.
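
To make the comparison easier to follow, here is a minimal scalar sketch (not the actual Zen4 kernel in this PR, and without SIMD, blocking, or masking) of the streaming softmax a fused FA kernel performs for a single query row, so that softmax(K*Q) is never materialized as a full matrix:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Scalar flash-attention sketch for one query row: keys/values are streamed
// and a running max M, running sum S, and rescaled accumulator acc are
// maintained, so the softmax(K*Q) row is never stored. Names, layout and
// signature are illustrative only.
void flash_attn_row(const float * q,    // [head_dim]        query row
                    const float * k,    // [n_kv * head_dim] keys (row-major)
                    const float * v,    // [n_kv * head_dim] values (row-major)
                    float       * out,  // [head_dim]        attention output
                    int head_dim, int n_kv, float scale) {
    float M = -INFINITY, S = 0.0f;
    std::vector<float> acc(head_dim, 0.0f);
    for (int j = 0; j < n_kv; ++j) {
        // attention score for key j: scale * (q . k_j)
        float s = 0.0f;
        for (int d = 0; d < head_dim; ++d) s += q[d] * k[j*head_dim + d];
        s *= scale;
        // online softmax update: rescale previous partial results if the max grows
        const float M_new = std::max(M, s);
        const float c = std::exp(M - M_new);  // rescale factor for acc and S
        const float p = std::exp(s - M_new);  // weight of the current key
        for (int d = 0; d < head_dim; ++d) acc[d] = c*acc[d] + p*v[j*head_dim + d];
        S = c*S + p;
        M = M_new;
    }
    // final normalization by the softmax denominator
    for (int d = 0; d < head_dim; ++d) out[d] = acc[d] / S;
}
```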

The next graph shows results for Q4_K_S-quantized Gemma-2-2b. Symbol colors are the same as above.

[Figure: PP performance vs. context length for Gemma-2-2b]

In this case the original FA kernel improves performance in mainline llama.cpp. The difference in behavior compared to LLaMA-3.1-8B is easily explained by the fact that the Gemma-2 series of models use "soft-capping" in their attention layers, where softcap(x) = c * tanh(x/c) (c is a model-defined constant). This is implemented as 3 separate operations in llama.cpp. When FA is enabled, these 3 operations, along with softmax, are fused into a single kernel, and this improves mainline llama.cpp performance even for short contexts. But when the original FA kernel is used in our version, where "soft-capping" is already handled by a dedicated fused operation, we get a massive drop in performance just like in the LLaMA-3.1-8B case above. The new implementation in this PR is much better and performance improves again, reaching 11% at 8k tokens, which is the maximum training context length of Gemma-2-2b.
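
For illustration, a sketch of how soft-capping folds into a fused attention kernel; the function name and call site are hypothetical and only show where softcap(x) = c * tanh(x/c) would be applied relative to the streaming softmax from the earlier sketch:

```cpp
#include <cmath>

// softcap(x) = c * tanh(x/c), applied to attention scores in Gemma-2 style
// models. In a fused kernel this replaces three separate tensor operations
// (divide, tanh, multiply) with a single call inside the attention loop.
static inline float softcap(float x, float c) {
    return c * std::tanh(x / c);
}

// Inside the key loop of the attention sketch above, the score would be capped
// before the online-softmax update (attn_logit_softcap is the model constant c):
//     float s = scale * dot(q, k_j);
//     s = softcap(s, attn_logit_softcap);
```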

@ikawrakow ikawrakow merged commit dc023bc into main Sep 1, 2024
@ikawrakow ikawrakow mentioned this pull request Sep 3, 2024