
Enable flash attention for gemma #1454

Merged
merged 1 commit into huggingface:main from gemma_flash_attention on Nov 15, 2024

Conversation


@atakaha atakaha commented Oct 23, 2024

Add missing flash attention flags to gemma

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@atakaha atakaha requested a review from regisss as a code owner October 23, 2024 18:52

atakaha commented Oct 23, 2024

@tthakkal , @libinta @mandy-li , please review this PR.

@tthakkal

@tthakkal , @libinta @mandy-li , please review this PR.

@atakaha Have you verified for accuracy and performance with these commands added?
  • bf16: single card & multi card
  • fp8: single card & multi card


atakaha commented Oct 24, 2024

@tthakkal , @libinta @mandy-li , please review this PR.

@atakaha Have you verified for accuracy and performance with these commands added? bf16 single card & multi card

fp8 single card and multi card

All flash-related flag combinations passed with batch size 1. But batch size 8 with flash_attention + causal_mask generates junk output. Need to investigate why this happens in the multi-batch scenario.

@atakaha force-pushed the gemma_flash_attention branch 4 times, most recently from 29d4814 to 14c442e on October 30, 2024 at 20:33
@atakaha atakaha marked this pull request as draft October 31, 2024 01:01
Add missing flag handling to gemma
   --reuse_cache
   --use_flash_attention
   --flash_attention_recompute
   --flash_attention_causal_mask
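The commit message above lists the four CLI flags being wired through to gemma. As a minimal sketch (a hypothetical helper, not the actual optimum-habana implementation), forwarding such store-true flags into the generation kwargs only when they are set keeps models without flash-attention support unaffected:

```python
import argparse

def build_generation_kwargs(argv=None):
    # Hypothetical sketch: flag names mirror run_generation.py above,
    # but the real wiring inside optimum-habana may differ.
    parser = argparse.ArgumentParser()
    parser.add_argument("--reuse_cache", action="store_true")
    parser.add_argument("--use_flash_attention", action="store_true")
    parser.add_argument("--flash_attention_recompute", action="store_true")
    parser.add_argument("--flash_attention_causal_mask", action="store_true")
    args = parser.parse_args(argv)
    # Forward only the flags that were actually passed on the command line,
    # so the model code only sees features it was asked to enable.
    return {name: value for name, value in vars(args).items() if value}

print(build_generation_kwargs(["--use_flash_attention", "--reuse_cache"]))
```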
@atakaha force-pushed the gemma_flash_attention branch from 14c442e to 5a2ee0e on November 2, 2024 at 01:12

vidyasiv commented Nov 4, 2024

@atakaha, is the PR ready for review yet, or waiting on something?
Update: I found the internal ticket and will track that. Thanks.


atakaha commented Nov 4, 2024

@atakaha , is the PR ready for review yet or waiting on something?

Regarding the missing flags: the interface is fixed, and I confirmed output quality and a small memory-usage improvement for BF16 on single and multi card with the flags.
I'm observing an FP8 output-quality issue with the original code (without this code change). I'm not sure whether this is expected behavior; if it is not, we need to investigate and fix it.

@atakaha atakaha marked this pull request as ready for review November 4, 2024 17:58

atakaha commented Nov 5, 2024

@tthakkal, @vidyasiv , please review this PR.

@atakaha force-pushed the gemma_flash_attention branch from f9bad35 to 5a2ee0e on November 5, 2024 at 01:45

vidyasiv commented Nov 5, 2024

@atakaha can you paste commands and outputs (throughput, text) for 1 and 8 HPUs w/ bf16 and fp8 with these changes, as Thanaji had requested?
As mentioned on the ticket, perhaps you can file new tickets for the issues you discovered.


vidyasiv commented Nov 5, 2024

1 HPU sanity testing at my end:

  • bf16 (works)
  • bf16 w/ flash attention (works, improves throughput)
  • fp8 w/ flash attention (inaccurate text outputs)
  • bf16 w/ reuse_cache(works)
python run_generation.py --model_name_or_path google/gemma-7b \
--attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --max_new_tokens 64 \
--bf16 --batch_size 8
Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models',)

input 2: ('He is working on',)
output 1: ('He is working on a new project, which is a sequel to his 2016 film, <em>The Legend of Michael Mishra</em>.\n\n“I am working on a sequel to <em>The Legend of Michael Mishra</em>. It is a comedy film. I am writing the script and I will start shooting for it in the',)

input 3: ('He has a',)
output 1: ('He has a very good knowledge of the market and is very professional. He is very helpful and always available to answer any questions.\n\nI would highly recommend him to anyone looking to buy or sell a property.\n\nWe were very happy with the service provided by the team at Ray White. They were very professional and knowledgeable, and they',)

input 4: ('He got all',)
output 1: ('He got all the way to the final of the 2019 edition of the show, but this year he’s back with a bang.\n\nThe 26-year-old from the Isle of Wight is a professional dancer and choreographer who has worked with the likes of Little Mix, Olly Murs and Fleur',)

input 5: ('Everyone is happy and I can',)
output 1: ('Everyone is happy and I can’t wait to see what the future holds for us.\n\nI’m so happy to have found a place that I can call home.\n\nI’m so happy to have found a place that I can call home.\n\nI’m so happy to have found a place that I can call home.\n\n',)

input 6: ('The new movie that got Oscar this year',)
output 1: ('The new movie that got Oscar this year is a movie that is based on a true story. The movie is called “The Imitation Game”. The movie is about a man named Alan Turing who was a mathematician and a code breaker. He was a very smart man and he was able to break the code that the Germans were using to communicate with each other. He',)

input 7: ('In the far far distance from our galaxy,',)
output 1: ('In the far far distance from our galaxy, there is a planet called Earth. On this planet, there are many different species of animals. One of them is the human.\n\nThe human is a very special species. They have a very high intelligence and they can create many things. They can create a lot of things that can help them to survive.\n\nOne',)

input 8: ('Peace is the only way',)
output 1: ('Peace is the only way to solve the conflict in South Sudan, the country’s President Salva Kiir has said.\n\nKiir made the remarks on Wednesday during the 10th anniversary of the Comprehensive Peace Agreement (CPA) in Juba.\n\n“The only way to solve the conflict in South Sudan is through peace. We have',)


Stats:
----------------------------------------------------------------------------------
Input tokens
Throughput (including tokenization) = 788.1560337240456 tokens/second
Memory allocated                    = 18.53 GB
Max memory allocated                = 18.66 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 2.8224462040000162 seconds
----------------------------------------------------------------------------------


python run_generation.py --model_name_or_path google/gemma-7b \
 --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --max_new_tokens 64 \
 --bf16 --batch_size 8 --use_flash_attention

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework that enables the training of large-scale models on commodity hardware. It is designed to be flexible and extensible, allowing researchers to easily add new algorithms and optimizations to the framework. DeepSpeed is also designed to be efficient, using techniques such as data parallelism and mixed-precision training to reduce the amount of time and resources required',)

input 2: ('He is working on',)
output 1: ('He is working on a new project, which is a sequel to his 2016 film, <em>The Legend of Michael Mishra</em>.\n\n“I am working on a sequel to <em>The Legend of Michael Mishra</em>. It is a comedy film. I am writing the script and I will start shooting for it in the',)

input 3: ('He has a',)
output 1: ('He has a very good knowledge of the market and is very professional. He is very helpful and always available to answer any questions.\n\nI would highly recommend him to anyone looking to buy or sell a property.\n\nHe is very professional and knowledgeable. He was always available to answer any questions we had and made the process of buying a',)

input 4: ('He got all',)
output 1: ('He got all the way to the final of the 2019 edition of the show, but this year he’s back with a bang.\n\nThe 26-year-old from the Isle of Wight is a professional dancer and choreographer who has worked with the likes of Little Mix, Olly Murs and Fleur',)

input 5: ('Everyone is happy and I can',)
output 1: ('Everyone is happy and I can’t wait to see what the future holds for us.\n\nI’m so happy to have found a place that I can call home.\n\nI’m so happy to have found a place that I can call home.\n\nI’m so happy to have found a place that I can call home.\n\n',)

input 6: ('The new movie that got Oscar this year',)
output 1: ('The new movie that got Oscar this year is a movie that is based on a true story. The movie is called “The Imitation Game”. The movie is about a man named Alan Turing who was a mathematician and a code breaker. He was a very smart man and he was able to break the code that the Germans were using to communicate with each other. He',)

input 7: ('In the far far distance from our galaxy,',)
output 1: ('In the far far distance from our galaxy, there is a planet called Earth. On this planet, there are many different species of animals. One of them is the human.\n\nThe human is a very special species. They have a very high intelligence and they can create many things. They can create a lot of things that can help them to survive.\n\nOne',)

input 8: ('Peace is the only way',)
output 1: ('Peace is the only way to end the war in Ukraine, the Russian president, Vladimir Putin, has said, as he accused the west of trying to “dismember” his country.\n\nIn a speech to mark the 80th anniversary of the Soviet victory over Nazi Germany in the second world war, Putin said the west was trying to',)


Stats:
----------------------------------------------------------------------------------
Input tokens
Throughput (including tokenization) = 817.3428140417304 tokens/second
Memory allocated                    = 18.55 GB
Max memory allocated                = 18.72 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 2.686410730000034 seconds
----------------------------------------------------------------------------------

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path google/gemma-7b \
--attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --max_new_tokens 64 \
--bf16 --batch_size 1

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models',)


Stats:
-----------------------------------------------------------------------------------
Input tokens
Throughput (including tokenization) = 107.65460798022197 tokens/second
Memory allocated                    = 19.16 GB
Max memory allocated                = 20.52 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 2.4955870979999872 seconds
-----------------------------------------------------------------------------------

QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path google/gemma-7b \
--attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --max_new_tokens 64 \
--bf16 --batch_size 8 --use_flash_attention

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework that suspic suspic suspicispecially unifore unif unifore enthusi unif unif enthusi enthusi enthusi infinites enthusi enthusi infinites enthusi infinites enthusi infinites enthusi infinites infinites premia enthusi infinites enthusi infinites enthusi infinites premia enthusi infinites infinites premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia premia',)

input 2: ('He is working on',)
output 1: ('He is working on my imago. He has my imago, my imago antem Idem, my imago. My imago imago imago imago imago imago. fepdhdhd madonna my imago. My imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago imago',)

input 3: ('He has a',)
output 1: ('He has a mysterical past',)

input 4: ('He got all',)
output 1: ('He got all the upvotes, upvotes, and upvotes, but when it came time to get his upvotes upvotes ① upvotes, ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ① ①',)

input 5: ('Everyone is happy and I can',)
output 1: ('Everyone is happy and I can’t wait for my niece’s first birthday party, my daughter’s first day of kindergarten or my son’s first day of exorbitantly profanely alphabe smartypants mef alphabe alphabe alphabe alphabe alphabe smartypants alphabe alphabe alphabe smartypants alphabe smartypants alphabe alphabe smartypants smartypants alphabe',)

input 6: ('The new movie that got Oscar this year',)
output 1: ('The new movie that got Oscar this year, The indestructibles, has alre manikul than the alphabe disadpecially disespecially alphabe alphabe alphabe alphabe alphabe encre alphabe alphabe alphabe alphabe alphabe alphabe alphabe alphabe encre alphabe encre disespecially alphabe alphabe encre dises manikul alphabe alphabe alphabe encre dises alphabe alphabe dises encre dises alphabe manikul alphabe manufact alphabe alphabe alphabe alphabe encre dises alphabe alphabe',)

input 7: ('In the far far distance from our galaxy,',)
output 1: ('In the far far distance from our galaxy, we can see that the milky way has a prominant bump',)

input 8: ('Peace is the only way',)
output 1: ('Peace is the only way, my friend,\n',)


Stats:
-----------------------------------------------------------------------------------
Input tokens
Throughput (including tokenization) = 1336.5589212865011 tokens/second
Memory allocated                    = 10.46 GB
Max memory allocated                = 11.18 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 8.916372483999794 seconds
-----------------------------------------------------------------------------------

python run_generation.py --model_name_or_path google/gemma-7b \
 --attn_softmax_bf16 --use_hpu_graphs --trim_logits --use_kv_cache --max_new_tokens 64 \
 --bf16 --batch_size 8 --reuse_cache

Input/outputs:
input 1: ('DeepSpeed is a machine learning framework',)
output 1: ('DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models',)

input 2: ('He is working on',)
output 1: ('He is working on a new project, which is a sequel to his 2016 film, <em>The Legend of Michael Mishra</em>.\n\n“I am working on a sequel to <em>The Legend of Michael Mishra</em>. It is a comedy film. I am writing the script and I will start shooting for it in the',)

input 3: ('He has a',)
output 1: ('He has a very good knowledge of the market and is very professional. He is very helpful and always available to answer any questions.\n\nI would highly recommend him to anyone looking to buy or sell a property.\n\nWe were very happy with the service provided by the team at Ray White. They were very professional and knowledgeable, and they',)

input 4: ('He got all',)
output 1: ('He got all the way to the final of the 2019 edition of the show, but this year he’s back with a bang.\n\nThe 26-year-old from the Isle of Wight is a professional dancer and choreographer who has worked with the likes of Little Mix, Olly Murs and Fleur',)

input 5: ('Everyone is happy and I can',)
output 1: ('Everyone is happy and I can’t wait to see what the future holds for us.\n\nI’m so happy to have found a place that I can call home.\n\nI’m so happy to have found a place that I can call home.\n\nI’m so happy to have found a place that I can call home.\n\n',)

input 6: ('The new movie that got Oscar this year',)
output 1: ('The new movie that got Oscar this year is a movie that is based on a true story. The movie is called “The Imitation Game”. The movie is about a man named Alan Turing who was a mathematician and a code breaker. He was a very smart man and he was able to break the code that the Germans were using to communicate with each other. He',)

input 7: ('In the far far distance from our galaxy,',)
output 1: ('In the far far distance from our galaxy, there is a planet called Earth. On this planet, there are many different species of animals. One of them is the human.\n\nThe human is a very special species. They have a very high intelligence and they can create many things. They can create a lot of things that can help them to survive.\n\nOne',)

input 8: ('Peace is the only way',)
output 1: ('Peace is the only way to solve the conflict in South Sudan, the country’s President Salva Kiir has said.\n\nKiir made the remarks on Wednesday during the 10th anniversary of the Comprehensive Peace Agreement (CPA) in Juba.\n\n“The only way to solve the conflict in South Sudan is through peace. We have',)


Stats:
----------------------------------------------------------------------------------
Input tokens
Throughput (including tokenization) = 786.3954218674475 tokens/second
Memory allocated                    = 18.53 GB
Max memory allocated                = 18.66 GB
Total memory available              = 94.62 GB
Graph compilation duration          = 2.820120850999956 seconds
----------------------------------------------------------------------------------


atakaha commented Nov 5, 2024

FP8 shows the same quality on my side. And FP8 with flash attention drops throughput.

  • BF16 base command line
    python run_generation.py --model_name_or_path google/gemma-7b --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --max_input_tokens 128 --max_new_tokens 128 --bf16 --batch_size 128
| quantize | batch_size | max_input_tokens | max_new_tokens | use_flash_attention | flash_attention_recompute | flash_attention_causal_mask | attn_softmax_bf16 | Throughput | Memory allocated | Max memory allocated |
|----------|------------|------------------|----------------|---------------------|---------------------------|-----------------------------|-------------------|------------|------------------|----------------------|
| bf16 | 128 | 128 | 128 | | | | | 4515.519 | 79 | 80.97 |
| bf16 | 128 | 128 | 128 | | | | | 4514.927 | 79.02 | 81 |
| bf16 | 128 | 128 | 128 | | | | | 4540.38 | 78.99 | 80.97 |
| bf16 | 128 | 128 | 128 | | | | | 4535.465 | 78.98 | 80.97 |
  • FP8 measurements with/without flash attention are done separately, since the script path is different and mixing them causes an error.

    • without flash attention
      QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path google/gemma-7b --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --max_input_tokens 128 --max_new_tokens 128 --bf16 --batch_size 1
    • with flash attention
      QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path google/gemma-7b --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --max_input_tokens 128 --max_new_tokens 128 --bf16 --batch_size 1 --use_flash_attention
  • FP8 base command line.
    QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_generation.py --model_name_or_path google/gemma-7b --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache --max_input_tokens 128 --max_new_tokens 128 --bf16 --batch_size 128

| quantize | batch_size | max_input_tokens | max_new_tokens | use_flash_attention | flash_attention_recompute | flash_attention_causal_mask | attn_softmax_bf16 | Throughput | Memory allocated | Max memory allocated |
|----------|------------|------------------|----------------|---------------------|---------------------------|-----------------------------|-------------------|------------|------------------|----------------------|
| fp8 | 128 | 128 | 128 | | | | | 8029.392 | 64.04 | 65.9 |
| fp8 | 128 | 128 | 128 | | | | | 8043.598 | 64.04 | 65.9 |
| fp8 | 128 | 128 | 128 | | | | | 3596.054 | 64.03 | 65.88 |
| fp8 | 128 | 128 | 128 | | | | | 3593.654 | 64.03 | 65.88 |
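Assuming the two ~3.6k tokens/second rows in the fp8 table are the flash-attention configurations (the flag columns were not preserved in this capture, so this pairing is an assumption), the drop relative to the ~8k rows works out to roughly 55%:

```python
# Throughput figures copied from the fp8 table above; which rows carry
# flash attention is an assumption, since the flag columns are blank.
fp8_no_flash = 8029.392   # tokens/second
fp8_flash = 3596.054      # tokens/second

drop = 1 - fp8_flash / fp8_no_flash
print(f"throughput drop with flash attention: {drop:.1%}")
```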


atakaha commented Nov 5, 2024

@atakaha can you paste commands and outputs(throughput, text) for 1 and 8 HPU w/ bf16 and fp8 with these changes as Thanaji had requested? As mentioned on ticket perhaps you can file new ones for issues you discovered.

@vidyasiv, tickets are created.


vidyasiv commented Nov 6, 2024

@regisss could you take a look? The pending issue (FP8 with flash attention drops throughput) has a ticket filed.

@vidyasiv left a comment

lgtm


atakaha commented Nov 6, 2024

For FP8, we need to use quantization_config/maxabs_quant_gemma.json for measurement; then we get accurate FP8 output.
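A hedged example of what that invocation might look like, following the same command pattern as the earlier runs in this thread and swapping in the gemma-specific config file (paths assumed relative to the text-generation example directory):

```shell
QUANT_CONFIG=./quantization_config/maxabs_quant_gemma.json python run_generation.py \
  --model_name_or_path google/gemma-7b \
  --use_hpu_graphs --trim_logits --use_kv_cache --reuse_cache \
  --max_input_tokens 128 --max_new_tokens 128 \
  --bf16 --batch_size 128 --use_flash_attention
```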


atakaha commented Nov 14, 2024

@regisss , Please review this PR.

@libinta libinta added the run-test Run CI for PRs from external contributors label Nov 14, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@regisss regisss merged commit ef83544 into huggingface:main Nov 15, 2024
3 of 5 checks passed
Luca-Calabria pushed a commit to Luca-Calabria/optimum-habana that referenced this pull request Nov 25, 2024
Liangyx2 pushed a commit to HabanaAI/optimum-habana-fork that referenced this pull request Jan 20, 2025
6 participants