Support for 256 head dim #67

Closed
Sanger2000 opened this issue Nov 1, 2022 · 10 comments

@Sanger2000

Sanger2000 commented Nov 1, 2022

Really love this repo, I've been using it to finetune CodeGen models with >2k context windows.

It's way faster than Hugging Face (3x) and slightly faster than Megatron for the 350M and 2.7B parameter CodeGen models, but it doesn't work for the 6.1B and 16B parameter models, as they have a head dimension of 256.

[Screenshot: Screen Shot 2022-11-01 at 5 32 47 PM]

I would imagine CodeGen finetuning will be a solid use case for flash attention, since coding models really benefit from long context windows. And CodeGen is basically SOTA for coding (competitive with Codex).

Is this something that is even possible with flash attention?

@tridao
Member

tridao commented Nov 1, 2022

Thanks for the kind words! I agree that code generation is a great use case!

We haven't been able to get a speedup for headdim=256. To reduce memory reads/writes, we load blocks of Q, K, and V from GPU memory into SRAM, and the SRAM size is the main constraint (e.g. 163 KB per streaming multiprocessor on A100). As the head dimension gets large, we can't fit a block into SRAM without making the block size very small, which makes the whole computation slower.

For now we support head dimension up to 128.
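
To make the constraint concrete, here's a rough back-of-the-envelope sketch of how much SRAM the Q/K/V tiles alone would need for a given block size and head dimension, assuming fp16 inputs and 128x128 tiles; this is only an illustration, since the real kernel also keeps softmax statistics, output accumulators, and padding, so actual usage is higher:

```python
# Rough SRAM estimate for the Q/K/V tiles in a FlashAttention-style kernel.
# Illustrative only: the real kernel also stores softmax statistics, output
# accumulators, and padding, so actual shared-memory usage is higher.
def tile_sram_bytes(block_q: int, block_kv: int, head_dim: int, elem_bytes: int = 2) -> int:
    q_tile = block_q * head_dim         # one block of Q
    kv_tiles = 2 * block_kv * head_dim  # one block each of K and V
    return (q_tile + kv_tiles) * elem_bytes

SRAM_PER_SM = 163 * 1024  # usable shared memory per SM on A100, in bytes

for head_dim in (64, 128, 256):
    used = tile_sram_bytes(block_q=128, block_kv=128, head_dim=head_dim)
    print(f"head_dim={head_dim}: ~{used / 1024:.0f} KB "
          f"({used / SRAM_PER_SM:.0%} of the 163 KB budget)")
```

With 128x128 tiles, head dim 256 alone would need roughly 192 KB, which already exceeds the 163 KB budget before counting any other buffers, so the kernel would have to shrink the tiles and lose efficiency.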

@Sanger2000
Author

Ahh, yeah I see. Looks like I'll have to find another pretrained code model to use for long sequences 😔

Again this code has been fantastic. My mind was blown when I saw how high I could set batch sizes without running out of memory.

@tridao tridao closed this as completed Nov 5, 2022
@ZhongYingMatrix

> (e.g. 163 KB per streaming multiprocessor on A100)

Hi, is the SRAM on A100 163 KB or 192 KB? I noticed the latter in the paper.

@tridao
Member

tridao commented Nov 9, 2022

The total SRAM per multiprocessor is 192 KB, but only 163 KB is usable by the programmer (the rest is reserved for L1 cache, I believe).

@ZhongYingMatrix

Thanks for your reply!
So can we assume that M in the paper is 163 * 1024 = 166912, and that B_c = M/(4d) = 652 (assuming d = 64) while B_r = 64?
I'm a bit confused because B_c and B_r are both fixed at 128 in the Triton version, and I haven't figured out where the CUDA version configures them. Is any choice of B_c/B_r fine as long as the blocks fit in SRAM?

@tridao
Member

tridao commented Nov 9, 2022

In practice the block sizes are set to optimize for speed and ease of implementation. For example, the Triton version sets the block sizes to (128, 128) because that's what the Triton compiler supports (other shapes lead to wrong results or compiler errors). As another example, block sizes are always powers of 2; otherwise it's much harder to implement.
The CUDA version sets B_c = 128 or 256, and B_r = 16.
For example, the dispatch here corresponds to B_c = 256 and B_r = 16.
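
To connect the paper's formulas with the fixed sizes above, here is a small sketch; the formulas come from the FlashAttention paper and the fixed sizes from the comment above, while the helper function itself is just illustrative:

```python
import math

# Block sizes as derived in the FlashAttention paper:
#   B_c = ceil(M / (4d)),  B_r = min(ceil(M / (4d)), d)
# where M is the on-chip SRAM size and d is the head dimension.
def paper_block_sizes(M: int, d: int) -> tuple[int, int]:
    b_c = math.ceil(M / (4 * d))
    b_r = min(b_c, d)
    return b_c, b_r

M = 163 * 1024  # usable SRAM per SM on A100, as discussed above
print(paper_block_sizes(M, d=64))  # (652, 64) -- the numbers from the question

# In practice the implementations use fixed, power-of-two block sizes instead
# of the formula's maximum:
#   Triton version: (B_r, B_c) = (128, 128)
#   CUDA version:   B_r = 16, B_c = 128 or 256 (picked by the dispatch logic)
```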

@shijie-wu

Out of curiosity, what does the performance trade-off curve look like for head dim > 128? I understand that for 256 the trade-off is not worth it, but would a slightly bigger head dim be worth supporting?

@tridao
Member

tridao commented Dec 6, 2022

I have not implemented anything for headdim > 128 (it takes effort to implement the low-level shared memory loading/writing for different head dimensions, for best efficiency).
If you have info on this, that would be very helpful.

@shijie-wu

Thanks! I don't have any info on this either. I assumed it wasn't supported because you had tried it and the results were mixed. We will explore it when we get a chance.

@tridao
Member

tridao commented Aug 14, 2023

As of v2 we support all head dimensions up to 256. The backward pass for head dim > 192 requires A100/A800 or H100/H800.
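
For reference, a minimal usage sketch with head dim 256 against the v2 Python interface (flash_attn_func), assuming fp16/bf16 CUDA tensors in (batch, seqlen, nheads, headdim) layout; check the repo's README for the authoritative API:

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 2048, 16, 256  # head dim 256, supported in v2

# flash-attn expects (batch, seqlen, nheads, headdim) tensors in fp16/bf16 on GPU.
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # same shape as q
print(out.shape)  # torch.Size([2, 2048, 16, 256])
```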
