Binary KQ mask #28 (Draft)
ikawrakow wants to merge 13 commits into main from ik/kq_mask
Conversation
Here we get a small speedup: Gemma-2-2b with a 32k context is ~4% faster on Zen4. On Zen4 we can use _mm512_mask_mul_ps(-infinity, mask, s_after, tanh(x*s_before)) to scale and apply the mask in a single op that has the same latency and throughput as _mm512_mul_ps. Combined with the reduced memory loads compared to a mask stored as fp32 (or fp16), this gives some performance improvement for very large masks (contexts). It will be much trickier on the other platforms, which do not have masked instructions.
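A minimal sketch (not the PR's actual code) of the fused scale + soft-cap + mask step described above, assuming the binary mask is laid out as one bit per KQ column packed into uint16_t words; the vectorized tanh is stubbed with a scalar fallback just to keep the example self-contained:

```c
#include <immintrin.h>
#include <math.h>
#include <stdint.h>

// Scalar fallback so the sketch compiles on its own; real code would use a
// vectorized tanh approximation.
static inline __m512 v_tanh_ps(__m512 x) {
    float tmp[16];
    _mm512_storeu_ps(tmp, x);
    for (int i = 0; i < 16; ++i) tmp[i] = tanhf(tmp[i]);
    return _mm512_loadu_ps(tmp);
}

// x[] is one row of KQ logits, mask_bits holds one bit per position (1 = keep).
void soft_cap_mask_row(float * x, const uint16_t * mask_bits, int n,
                       float s_before, float s_after) {
    const __m512 neg_inf  = _mm512_set1_ps(-INFINITY);
    const __m512 v_before = _mm512_set1_ps(s_before);
    const __m512 v_after  = _mm512_set1_ps(s_after);
    for (int i = 0; i < n/16; ++i) {
        __mmask16 m = mask_bits[i];
        __m512    t = v_tanh_ps(_mm512_mul_ps(_mm512_loadu_ps(x + 16*i), v_before));
        // Single op: s_after*tanh(s_before*x) where the bit is set, -INFINITY elsewhere.
        _mm512_storeu_ps(x + 16*i, _mm512_mask_mul_ps(neg_inf, m, v_after, t));
    }
}
```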
Relatively painless to implement for soft_max and soft_cap_max. We gain 11.5% for LLaMA-8B and ~14% for Gemma-2-2b at 32k tokens. The KQ mask is prepared on the CPU and copied to the GPU, so my guess is that most of the gain comes from the 32X reduction in the amount of data being copied to the GPU. TODO: flash attention
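As a rough illustration of where the 32X comes from, the host-side packing could look something like the sketch below; the function name and bit layout are hypothetical, not the PR's actual code:

```c
#include <stdint.h>

// Pack an f32 KQ mask (values are only 0.0f or -INFINITY) into one bit per
// position: 1 = visible (0.0f), 0 = masked out (-INFINITY). 4 bytes/position
// shrink to 1/8 byte/position, i.e. 32x less data to copy to the GPU.
void pack_kq_mask(const float * mask_f32, uint32_t * mask_bits, int n) {
    for (int i = 0; i < n; i += 32) {
        uint32_t w = 0;
        for (int j = 0; j < 32 && i + j < n; ++j) {
            if (mask_f32[i + j] == 0.0f) w |= 1u << j;
        }
        mask_bits[i/32] = w;
    }
}
```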
For now just soft_cap_max. On Gemma2-9b I'm observing a ~2% speedup for a context of 16k tokens.
I need to redo this with better templates.
It is a pain to implement the binary-mask-to-32-bit-value conversion on NEON and AVX2, so I decided to make the binary mask optional. There is also a commented-out (and not working) attempt for NEON in this commit.
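For illustration only (this is not the commented-out attempt in the commit), one way the bit-to-f32 expansion could look on AVX2, which shows why it takes several instructions where AVX-512 needs none; the bit layout is an assumption:

```c
#include <immintrin.h>
#include <math.h>
#include <stdint.h>

// Expand 8 mask bits into eight f32 lanes: 0.0f where the bit is set,
// -INFINITY where it is not.
static inline __m256 expand_mask_bits_avx2(uint8_t bits) {
    const __m256i bit_sel = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
    const __m256i b       = _mm256_set1_epi32(bits);
    // A lane whose bit is set compares equal to its selector -> all-ones lane.
    const __m256i keep    = _mm256_cmpeq_epi32(_mm256_and_si256(b, bit_sel), bit_sel);
    // All-ones lanes (bit set) select 0.0f, the rest select -INFINITY.
    return _mm256_blendv_ps(_mm256_set1_ps(-INFINITY), _mm256_setzero_ps(),
                            _mm256_castsi256_ps(keep));
}
```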
This PR is another attempt to improve performance for large contexts, see #25.

Basically, when we want to process a very long context, the KQ mask, which is stored as f32 (or f16, if using flash attention), becomes quite significant in size. If running on the GPU, the cost of copying the KQ mask to the GPU (the mask is created on the host CPU) becomes non-negligible. If running on a CPU with limited memory bandwidth (basically all x86 or x86_64), the KQ mask may not fit in the cache, or, if it does fit, it reduces the cache available for other data by a significant amount, which has a measurable impact on the performance of the SOFT_MAX (or the new fused SOFT_CAP_MAX) operation. Hence, it is desirable to reduce the size of the KQ mask.

If not using ALiBi (basically almost always these days), the KQ mask stores only two values: 0 and -INFINITY. It can therefore be represented as a binary mask, reducing its size by a factor of 32.

This PR adds an option to use a binary KQ mask. It is off by default, as not all platforms are implemented, but it can be turned on with -bkq or --binary-kq on the command line. It has no effect if flash attention is used (the KQ mask remains f16 as before). If it is turned on but not supported by the back-end (non-AVX512 CPUs), the program will assert and terminate.

I see 3-5% performance gains on CUDA and a Ryzen-7950X CPU for a context of 32k tokens, and about 2-3% on Metal for a context of 16k. So, nothing earth-shattering, and hence I'm not quite convinced to merge it.