CPU Flash Attention improvements #172

Merged: ikawrakow merged 10 commits into main on Jan 15, 2025
Conversation

ikawrakow (Owner)

This PR

  • Improves FA CPU performance for long contexts
  • Fixes the K-cache quantized to Q8_0 when not using FA. This was broken because online Q8_0 quantization packed the quants into blocks of 128 (block_q8_0_x4), so K*Q produced garbage when a Q8_0-quantized K-cache was used without FA (see the layout sketch below).
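
To make the packing issue concrete, here is a minimal layout sketch. block_q8_0 matches the well-known ggml definition; the block_q8_0_x4 definition is an assumption for illustration rather than a copy of the actual code, but it captures the point: four 32-quant blocks (128 quants) are packed together with their four scales, so code expecting plain block_q8_0 misreads scales as quants.

```c
// Sketch of the two layouts (block_q8_0_x4 is an assumed definition).
#include <stdint.h>

typedef uint16_t ggml_half;   // fp16 storage type used by ggml

#define QK8_0 32

// Plain Q8_0: one fp16 scale followed by 32 int8 quants.
typedef struct {
    ggml_half d;              // block scale
    int8_t    qs[QK8_0];      // 32 quantized values
} block_q8_0;

// Assumed x4 packing: 4 scales up front, then 4*32 = 128 quants.
// Interpreting data stored in one layout through the other mixes scales and
// quants, which is how the K*Q product turned into garbage without FA.
typedef struct {
    ggml_half d[4];           // four block scales
    int8_t    qs[4 * QK8_0];  // 128 quantized values
} block_q8_0_x4;
```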

FA performance improvements are for AVX2/Zen4. The following table compares prompt processing (PP) performance between the main branch and this PR with FA enabled, using bf16 or Q8_0 for the KV cache. The model is LLaMA-3.1-8B quantized to IQ4_XS and run-time-repacked to IQ4_XS_R4; the CPU is a Ryzen 7950X. When the quoted uncertainty in the table is zero, I ran just a single repetition in llama-bench (it takes quite a while to process 16k or even 32k tokens).

| type_k | type_v | fa | rtr | test    | t/s (main)    | t/s (PR)      | Speedup |
|--------|--------|----|-----|---------|---------------|---------------|---------|
| bf16   | bf16   | 1  | 1   | pp128   | 275.27 ± 1.63 | 278.40 ± 1.60 | 1.011   |
| bf16   | bf16   | 1  | 1   | pp256   | 276.16 ± 3.46 | 283.51 ± 1.22 | 1.027   |
| bf16   | bf16   | 1  | 1   | pp512   | 274.71 ± 0.51 | 276.83 ± 0.36 | 1.008   |
| bf16   | bf16   | 1  | 1   | pp1024  | 265.81 ± 1.65 | 270.05 ± 0.41 | 1.016   |
| bf16   | bf16   | 1  | 1   | pp2048  | 256.95 ± 0.39 | 260.11 ± 0.14 | 1.012   |
| bf16   | bf16   | 1  | 1   | pp4096  | 237.97 ± 0.37 | 242.29 ± 0.75 | 1.018   |
| bf16   | bf16   | 1  | 1   | pp8192  | 206.34 ± 1.25 | 213.98 ± 0.35 | 1.037   |
| bf16   | bf16   | 1  | 1   | pp16384 | 156.40 ± 0.00 | 173.44 ± 0.00 | 1.109   |
| bf16   | bf16   | 1  | 1   | pp32768 | 82.97 ± 0.00  | 122.47 ± 0.00 | 1.476   |
| q8_0   | q8_0   | 1  | 1   | pp128   | 273.44 ± 1.04 | 279.27 ± 1.43 | 1.021   |
| q8_0   | q8_0   | 1  | 1   | pp256   | 278.57 ± 1.03 | 283.00 ± 0.63 | 1.016   |
| q8_0   | q8_0   | 1  | 1   | pp512   | 271.56 ± 0.05 | 275.97 ± 0.79 | 1.016   |
| q8_0   | q8_0   | 1  | 1   | pp1024  | 264.31 ± 0.89 | 269.35 ± 0.33 | 1.019   |
| q8_0   | q8_0   | 1  | 1   | pp2048  | 253.70 ± 0.24 | 258.22 ± 0.36 | 1.018   |
| q8_0   | q8_0   | 1  | 1   | pp4096  | 232.07 ± 0.88 | 236.83 ± 1.38 | 1.021   |
| q8_0   | q8_0   | 1  | 1   | pp8192  | 199.90 ± 1.37 | 204.74 ± 0.34 | 1.024   |
| q8_0   | q8_0   | 1  | 1   | pp16384 | 153.62 ± 0.00 | 164.50 ± 0.00 | 1.071   |
| q8_0   | q8_0   | 1  | 1   | pp32768 | 103.48 ± 0.00 | 113.35 ± 0.00 | 1.095   |

Comparing the two KV-cache types: -ctk bf16 -ctv bf16 is slightly faster than -ctk q8_0 -ctv q8_0 on Zen4 for not too long context lengths (say, <= 4096). It is a ~2-3% sort of thing, and sadly the advantage kind of goes away when we go beyond 8k tokens.

With this PR we now hit 122 t/s for LLaMA-3.1-8B (quantized as iq4_xs and run-time-repacked) with a context of 32768. IIRC, the previous best for such a large context was ~90 t/s. There is a non-negligible improvement at 16384 and 8192 as well: 173.4 and 214 t/s. For the q8_0 KV cache and a context of 32768 we are now at 113 t/s.

Also simplified the quantized K*Q multiplication (see the repacking sketch after the list):
1. We add new types GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4, and use
   those to quantize activations for quants that use Q8_0 or Q8_1
   as their vec_dot type.
2. We revert the changes to quantize_row_q8_0 and quantize_row_q8_1.
3. We use GGML_TYPE_Q8_0_X4 and GGML_TYPE_Q8_1_X4 as the vec_dot type.
4. We change the FA implementation to use GGML_TYPE_Q8_0 rather than
   GGML_TYPE_Q8_0_X4 as the K and V type.
5. We change the expected type to GGML_TYPE_Q8_0_X4/GGML_TYPE_Q8_1_X4
   in iqk_mul_mat.
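
As a rough sketch of what step 1 amounts to (using the same assumed layouts as above; the helper name repack_q8_0_x4 is made up for illustration), packing four plain Q8_0 blocks of an activation row into one x4 block could look like this:

```c
// Sketch only: gather four consecutive plain Q8_0 blocks into one packed
// block, the layout assumed here for GGML_TYPE_Q8_0_X4 activations.
#include <stdint.h>
#include <string.h>

typedef uint16_t ggml_half;
#define QK8_0 32
typedef struct { ggml_half d;    int8_t qs[QK8_0];     } block_q8_0;
typedef struct { ggml_half d[4]; int8_t qs[4 * QK8_0]; } block_q8_0_x4;

static void repack_q8_0_x4(const block_q8_0 *src, block_q8_0_x4 *dst) {
    for (int i = 0; i < 4; ++i) {
        dst->d[i] = src[i].d;                          // collect the four scales
        memcpy(dst->qs + i * QK8_0, src[i].qs, QK8_0); // then the 4*32 quants
    }
}
```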

Also added an optimization in ggml_compute_forward_mul_mat for the case
ne12*ne13 > 1 (K*Q and V*softmax(K*Q)): the nthread threads are split into
GCD(ne12*ne13, nthread) groups of nthread/GCD(ne12*ne13, nthread) threads each,
and each group processes ne12*ne13/GCD(ne12*ne13, nthread) of the per-head
matrix multiplications, so several heads are worked on simultaneously. This
results in a non-negligible performance gain for large contexts; a sketch of
the partitioning follows.
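
A minimal sketch of that partitioning, assuming the goal is just to split the threads into GCD(ne12*ne13, nthread) groups that each walk a disjoint subset of heads (the function names here are illustrative, not the actual ggml code):

```c
// Illustrative only: group nthread threads so that several of the ne12*ne13
// independent per-head matrix multiplications are processed concurrently.
#include <stdio.h>

static int gcd(int a, int b) { while (b) { int t = a % b; a = b; b = t; } return a; }

// For thread ith of nthread, compute its group and its rank inside the group.
// Group g handles heads g, g + ngroups, g + 2*ngroups, ... i.e. ne12*ne13/ngroups
// heads, each of them worked on by nthread/ngroups threads.
static void split_heads(int ith, int nthread, int ne12, int ne13,
                        int *ngroups, int *group, int *rank_in_group) {
    const int nheads = ne12 * ne13;
    const int g = gcd(nheads, nthread);
    const int threads_per_group = nthread / g;
    *ngroups       = g;
    *group         = ith / threads_per_group;
    *rank_in_group = ith % threads_per_group;
}

int main(void) {
    // Example: 16 threads and 8 heads -> 8 groups of 2 threads, 1 head per group.
    for (int ith = 0; ith < 16; ++ith) {
        int ngroups, group, rank;
        split_heads(ith, 16, 8, 1, &ngroups, &group, &rank);
        printf("thread %2d -> group %d of %d, rank %d\n", ith, group, ngroups, rank);
    }
    return 0;
}
```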

Question: why is it not allowed to use a quantized V-cache when not using FA?

Answer: it is again the issue with _mm256_maddubs_epi16 overflowing that I keep forgetting (a small illustration follows).
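
For context (this is not code from the PR, just a minimal AVX2 demo): _mm256_maddubs_epi16 multiplies unsigned by signed bytes and adds adjacent pairs into 16-bit integers with saturation, so two large byte products silently clamp at 32767 instead of producing the correct sum.

```c
// Minimal illustration of the _mm256_maddubs_epi16 saturation issue.
// Build with: cc -mavx2 maddubs_demo.c
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    __m256i a = _mm256_set1_epi8((char)255); // treated as unsigned: 255
    __m256i b = _mm256_set1_epi8(127);       // treated as signed: 127
    // Each adjacent pair would sum to 2*255*127 = 64770, which does not fit in
    // int16_t, so the instruction saturates the result to 32767.
    __m256i s = _mm256_maddubs_epi16(a, b);
    int16_t out[16];
    _mm256_storeu_si256((__m256i *)out, s);
    printf("got %d, exact sum would be %d\n", out[0], 2 * 255 * 127);
    return 0;
}
```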
ikawrakow merged commit 0b74397 into main on Jan 15, 2025