
Enable Gemma2 Inference on Gaudi (#1504)
Signed-off-by: Wang, Yi A <[email protected]>
Signed-off-by: Ye, Xinyu <[email protected]>
Signed-off-by: Mengni Wang <[email protected]>
Signed-off-by: Daniel Socek <[email protected]>
Co-authored-by: billishyahao <[email protected]>
Co-authored-by: Harish Subramony <[email protected]>
Co-authored-by: Yeonsil Yoon <[email protected]>
Co-authored-by: Seunghyuk Park (shepark) <[email protected]>
Co-authored-by: regisss <[email protected]>
Co-authored-by: Sun Choi <[email protected]>
Co-authored-by: xinhe <[email protected]>
Co-authored-by: Mohit Deopujari <[email protected]>
Co-authored-by: Wang, Yi <[email protected]>
Co-authored-by: Soila Kavulya <[email protected]>
Co-authored-by: Iman Gohari <[email protected]>
Co-authored-by: ZhengHongming888 <[email protected]>
Co-authored-by: XinyuYe-Intel <[email protected]>
Co-authored-by: Vivek Goel <[email protected]>
Co-authored-by: Akihiro Takahashi <[email protected]>
Co-authored-by: Miroslav Goncharenko <[email protected]>
Co-authored-by: Wang, Mengni <[email protected]>
Co-authored-by: Daniel Socek <[email protected]>
Co-authored-by: Adam Stachowicz <[email protected]>
Co-authored-by: Vidya Galli <[email protected]>
Co-authored-by: deepak-gowda-narayana <140652370+deepak-gowda-narayana@users.noreply.github.com>
1 parent 82a1c96 commit 9a49200
Showing 8 changed files with 1,105 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -214,6 +214,7 @@ The following model architectures, tasks and device distributions have been validated
| Qwen2 | <div style="text-align:left"><li>Single card</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-MoE | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma | :heavy_check_mark: | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma2 | | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| XGLM | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Cohere | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| T5 / Flan T5 | :heavy_check_mark: | :heavy_check_mark: | <li>[summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)</li><li>[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)</li><li>[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20)</li> |
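With Gemma2 added to the validated-models table, checkpoints can be run through the existing text-generation example. A minimal invocation sketch — the checkpoint name and flag set below are illustrative assumptions, not taken from this commit, and running it requires a Gaudi (HPU) device:

```bash
python examples/text-generation/run_generation.py \
    --model_name_or_path google/gemma-2-9b \
    --use_hpu_graphs \
    --use_kv_cache \
    --max_new_tokens 100 \
    --bf16
```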
1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -57,6 +57,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all been validated
| Phi | ✅ | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Mixtral | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma | ✅ | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Gemma2 | | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2 | <div style="text-align:left"><li>Single card</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-MoE | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Persimmon | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
10 changes: 9 additions & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -94,6 +94,7 @@
"phi",
"mixtral",
"gemma",
"gemma2",
"blip_text_model",
"seamless_m4t",
"starcoder2",
@@ -961,6 +962,7 @@ def generate(
- [`transformers.generation.GenerateEncoderDecoderOutput`],
- [`transformers.generation.GenerateBeamEncoderDecoderOutput`]
"""

if iteration_times is not None:
hb_gen_time = HabanaGenerationtime(iteration_times=iteration_times)
hb_gen_time.start()
@@ -1077,15 +1079,19 @@ def generate(
"starcoder2",
"qwen2_moe",
"gemma",
"gemma2",
]
-        ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma and starcoder2 at the moment"
+        ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2 and starcoder2 at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
), "please set bucket_internal along with reuse_cache and bucket_size"
else:
assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal"

if self.config.model_type == "gemma2":
generation_config.cache_implementation = None

if generation_config.static_shapes:
# Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
# In encoder_decoder models, Inputs are already padded
@@ -1190,6 +1196,7 @@ def generate(
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None

generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
@@ -1278,6 +1285,7 @@ def generate(
"gptj",
"starcoder2",
"gemma",
"gemma2",
"qwen2_moe",
]:
if self.config.max_position_embeddings < calculated_max_length:
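The generation changes above do two things for gemma2: add it to the models allowed to use `reuse_cache`, and clear `generation_config.cache_implementation` so the HPU cache path is used instead of transformers' default cache implementation. A simplified, standalone sketch of that gating (the helper name is hypothetical):

```python
# Models permitted to use reuse_cache, per the assertion in generate().
REUSE_CACHE_MODELS = {
    "llama", "mistral", "falcon", "mixtral", "phi",
    "qwen2", "qwen2_moe", "gemma", "gemma2", "starcoder2",
}

def check_generation_config(model_type, reuse_cache, cache_implementation):
    """Sketch of the generate()-time checks this diff adds."""
    if reuse_cache and model_type not in REUSE_CACHE_MODELS:
        raise ValueError(f"reuse_cache is not supported for model type {model_type!r}")
    if model_type == "gemma2":
        # Mirrors `generation_config.cache_implementation = None` in the diff.
        cache_implementation = None
    return cache_implementation

print(check_generation_config("gemma2", True, "hybrid"))  # None
```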
14 changes: 14 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -48,6 +48,12 @@
GaudiFalconForCausalLM,
GaudiFalconMLP,
GaudiFalconModel,
GaudiGemma2Attention,
GaudiGemma2DecoderLayer,
GaudiGemma2ForCausalLM,
GaudiGemma2MLP,
GaudiGemma2Model,
GaudiGemma2RotaryEmbedding,
GaudiGemmaAttention,
GaudiGemmaDecoderLayer,
GaudiGemmaForCausalLM,
@@ -503,6 +509,14 @@ def adapt_transformers_to_gaudi():
transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = GaudiGemmaDecoderLayer
transformers.models.gemma.modeling_gemma.GemmaModel = GaudiGemmaModel

# Optimization for gemma2 on Gaudi
transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM = GaudiGemma2ForCausalLM
transformers.models.gemma2.modeling_gemma2.Gemma2MLP = GaudiGemma2MLP
transformers.models.gemma2.modeling_gemma2.Gemma2Attention = GaudiGemma2Attention
transformers.models.gemma2.modeling_gemma2.Gemma2DecoderLayer = GaudiGemma2DecoderLayer
transformers.models.gemma2.modeling_gemma2.Gemma2Model = GaudiGemma2Model
transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GaudiGemma2RotaryEmbedding

# Optimization for blip Text model on Gaudi
transformers.models.blip.BlipTextModel.forward = gaudi_BlipTextModel_forward
transformers.models.blip.modeling_blip_text.BlipTextLMHeadModel.forward = gaudi_BlipTextLMHead_forward
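The `adapt_transformers_to_gaudi()` hunk above relies on module-level monkey-patching: each Gaudi class is assigned over the stock attribute on the transformers modeling module, so any code that resolves the class through the module picks up the optimized version transparently. A toy sketch of the pattern — every name below is a stand-in, not the real implementation:

```python
import types

# Stand-in for transformers.models.gemma2.modeling_gemma2.
modeling_gemma2 = types.ModuleType("modeling_gemma2")

class Gemma2MLP:                        # stand-in for the stock class
    def forward(self, x):
        return x

class GaudiGemma2MLP(Gemma2MLP):        # stand-in for the Gaudi drop-in
    def forward(self, x):
        return x * 2                    # placeholder for HPU-optimized compute

modeling_gemma2.Gemma2MLP = Gemma2MLP   # original binding

def adapt_to_gaudi(module):
    # Same shape as the real patch: rebind the attribute on the module.
    module.Gemma2MLP = GaudiGemma2MLP

adapt_to_gaudi(modeling_gemma2)
mlp = modeling_gemma2.Gemma2MLP()       # now constructs the Gaudi class
print(mlp.forward(3))                   # 6
```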
8 changes: 8 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -77,6 +77,14 @@
GaudiGemmaMLP,
GaudiGemmaModel,
)
from .gemma2 import (
GaudiGemma2Attention,
GaudiGemma2DecoderLayer,
GaudiGemma2ForCausalLM,
GaudiGemma2MLP,
GaudiGemma2Model,
GaudiGemma2RotaryEmbedding,
)
from .gpt2 import (
GaudiGPT2Attention,
GaudiGPT2Block,
8 changes: 8 additions & 0 deletions optimum/habana/transformers/models/gemma2/__init__.py
@@ -0,0 +1,8 @@
from .modeling_gemma2 import (
GaudiGemma2Attention,
GaudiGemma2DecoderLayer,
GaudiGemma2ForCausalLM,
GaudiGemma2MLP,
GaudiGemma2Model,
GaudiGemma2RotaryEmbedding,
)