Enable flash attention for gemma #1454
Conversation
All flash-attention-related flag combinations passed with batch size 1. But with batch size 8, the flash_attention + causal_mask cases generate junk output. Need to investigate why this happens in the multi-batch scenario.
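For context, here is a minimal sketch (not the optimum-habana implementation, and not Gaudi's fused kernel) of how a `use_flash_attention` / `flash_attention_causal_mask` switch can change which mask the attention call actually sees. If the fused causal path builds its own mask and skips the padding-aware `attention_mask`, batched left-padded inputs (e.g. batch size 8) are one plausible way to get junk while batch size 1 still passes. All names below are illustrative, and PyTorch's generic `scaled_dot_product_attention` stands in for the fused SDPA used on HPU.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def gemma_like_attention(
    query: torch.Tensor,                             # (batch, heads, q_len, head_dim)
    key: torch.Tensor,                               # (batch, heads, kv_len, head_dim)
    value: torch.Tensor,                             # (batch, heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor] = None,   # additive mask, e.g. padding-aware
    use_flash_attention: bool = False,
    flash_attention_causal_mask: bool = False,
) -> torch.Tensor:
    if use_flash_attention and flash_attention_causal_mask:
        # Fused path that builds the causal mask inside the kernel.
        # Note: the explicit attention_mask (which carries padding info for
        # batched inputs) is not applied on this branch.
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
    if use_flash_attention:
        # Fused path with the explicit mask passed through.
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    # Eager fallback: plain scaled-dot-product attention with the explicit mask.
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attention_mask is not None:
        scores = scores + attention_mask
    return torch.softmax(scores, dim=-1) @ value
```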
Force-pushed from 29d4814 to 14c442e
Add missing flag handling to gemma: --reuse_cache --use_flash_attention --flash_attention_recompute --flash_attention_causal_mask
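To illustrate what "missing flag handling" means here, the following is a hypothetical sketch of how these four command-line flags could be parsed and forwarded to generation. The flag names come from the commit message; the wiring below is illustrative and is not the actual run_generation.py code.

```python
import argparse

parser = argparse.ArgumentParser()
for flag in (
    "--reuse_cache",
    "--use_flash_attention",
    "--flash_attention_recompute",
    "--flash_attention_causal_mask",
):
    parser.add_argument(flag, action="store_true")
args = parser.parse_args()

# A flag that is parsed here but never forwarded to generate() is silently
# ignored by the model; forwarding all of them is the kind of gap this PR closes.
generation_kwargs = {
    "reuse_cache": args.reuse_cache,
    "use_flash_attention": args.use_flash_attention,
    "flash_attention_recompute": args.flash_attention_recompute,
    "flash_attention_causal_mask": args.flash_attention_causal_mask,
}
# outputs = model.generate(**inputs, **generation_kwargs)  # model/inputs assumed to exist
```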
Force-pushed from 14c442e to 5a2ee0e
@atakaha, is the PR ready for review yet, or is it waiting on something?
Regarding the missing flags: the interface is fixed, and I confirmed output quality and a small memory-usage improvement for BF16 on single and multiple cards with the flags enabled.
Force-pushed from f9bad35 to 5a2ee0e
@atakaha, can you paste the commands and outputs (throughput, generated text) for 1 and 8 HPUs with bf16 and fp8 using these changes, as Thanaji had requested?
1 HPU sanity testing at my end: (results table omitted)
FP8 is the same quality on my side, and FP8 with flash attention drops throughput.
@regisss, could you take a look? A ticket has been filed for the pending issue (FP8 with flash attention drops throughput).
lgtm
For FP8, we need to use …
@regisss, please review this PR.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Add missing flash attention flags to gemma
What does this PR do?
Fixes # (issue)
Before submitting