diff --git a/README.md b/README.md
index 9426e94f25..c0acbccfae 100644
--- a/README.md
+++ b/README.md
@@ -214,6 +214,7 @@ The following model architectures, tasks and device distributions have been vali
| Qwen2 | Single card | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2-MoE | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Gemma | :heavy_check_mark: | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| Gemma2 | | :heavy_check_mark: | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| XGLM | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Cohere | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| T5 / Flan T5 | :heavy_check_mark: | :heavy_check_mark: | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 0bd27464d1..73ff5f72db 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -57,6 +57,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Phi | ✅ | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Mixtral | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Gemma | ✅ | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| Gemma2 | | ✅ | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2 | Single card | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2-MoE | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Persimmon | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index cab34fdc27..012fcccfd1 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -94,6 +94,7 @@
"phi",
"mixtral",
"gemma",
+ "gemma2",
"blip_text_model",
"seamless_m4t",
"starcoder2",
@@ -961,6 +962,7 @@ def generate(
- [`transformers.generation.GenerateEncoderDecoderOutput`],
- [`transformers.generation.GenerateBeamEncoderDecoderOutput`]
"""
+
if iteration_times is not None:
hb_gen_time = HabanaGenerationtime(iteration_times=iteration_times)
hb_gen_time.start()
@@ -1077,8 +1079,9 @@ def generate(
"starcoder2",
"qwen2_moe",
"gemma",
+ "gemma2",
]
- ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma and starcoder2 at the moment"
+ ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2 and starcoder2 at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
@@ -1086,6 +1089,9 @@ def generate(
else:
assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal"
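+            # Gemma2 configs default to a non-None cache_implementation; reset it here so
+            # generation relies on the Gaudi KV-cache handling rather than HF cache classes.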
+ if self.config.model_type == "gemma2":
+ generation_config.cache_implementation = None
+
if generation_config.static_shapes:
# Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
# In encoder_decoder models, Inputs are already padded
@@ -1190,6 +1196,7 @@ def generate(
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
@@ -1278,6 +1285,7 @@ def generate(
"gptj",
"starcoder2",
"gemma",
+ "gemma2",
"qwen2_moe",
]:
if self.config.max_position_embeddings < calculated_max_length:
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index d28f3b5ca5..b43e283595 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -48,6 +48,12 @@
GaudiFalconForCausalLM,
GaudiFalconMLP,
GaudiFalconModel,
+ GaudiGemma2Attention,
+ GaudiGemma2DecoderLayer,
+ GaudiGemma2ForCausalLM,
+ GaudiGemma2MLP,
+ GaudiGemma2Model,
+ GaudiGemma2RotaryEmbedding,
GaudiGemmaAttention,
GaudiGemmaDecoderLayer,
GaudiGemmaForCausalLM,
@@ -503,6 +509,14 @@ def adapt_transformers_to_gaudi():
transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = GaudiGemmaDecoderLayer
transformers.models.gemma.modeling_gemma.GemmaModel = GaudiGemmaModel
+ # Optimization for gemma2 on Gaudi
+ transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM = GaudiGemma2ForCausalLM
+ transformers.models.gemma2.modeling_gemma2.Gemma2MLP = GaudiGemma2MLP
+ transformers.models.gemma2.modeling_gemma2.Gemma2Attention = GaudiGemma2Attention
+ transformers.models.gemma2.modeling_gemma2.Gemma2DecoderLayer = GaudiGemma2DecoderLayer
+ transformers.models.gemma2.modeling_gemma2.Gemma2Model = GaudiGemma2Model
+ transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GaudiGemma2RotaryEmbedding
+
# Optimization for blip Text model on Gaudi
transformers.models.blip.BlipTextModel.forward = gaudi_BlipTextModel_forward
transformers.models.blip.modeling_blip_text.BlipTextLMHeadModel.forward = gaudi_BlipTextLMHead_forward
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index b752663386..c02e9588f3 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -77,6 +77,14 @@
GaudiGemmaMLP,
GaudiGemmaModel,
)
+from .gemma2 import (
+ GaudiGemma2Attention,
+ GaudiGemma2DecoderLayer,
+ GaudiGemma2ForCausalLM,
+ GaudiGemma2MLP,
+ GaudiGemma2Model,
+ GaudiGemma2RotaryEmbedding,
+)
from .gpt2 import (
GaudiGPT2Attention,
GaudiGPT2Block,
diff --git a/optimum/habana/transformers/models/gemma2/__init__.py b/optimum/habana/transformers/models/gemma2/__init__.py
new file mode 100644
index 0000000000..4112ec1ed4
--- /dev/null
+++ b/optimum/habana/transformers/models/gemma2/__init__.py
@@ -0,0 +1,8 @@
+from .modeling_gemma2 import (
+ GaudiGemma2Attention,
+ GaudiGemma2DecoderLayer,
+ GaudiGemma2ForCausalLM,
+ GaudiGemma2MLP,
+ GaudiGemma2Model,
+ GaudiGemma2RotaryEmbedding,
+)
diff --git a/optimum/habana/transformers/models/gemma2/modeling_gemma2.py b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
new file mode 100755
index 0000000000..4196775c19
--- /dev/null
+++ b/optimum/habana/transformers/models/gemma2/modeling_gemma2.py
@@ -0,0 +1,1061 @@
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Gemma2 model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.models.gemma2.modeling_gemma2 import (
+ Gemma2Attention,
+ Gemma2Config,
+ Gemma2DecoderLayer,
+ Gemma2ForCausalLM,
+ Gemma2MLP,
+ Gemma2Model,
+ apply_rotary_pos_emb,
+)
+from transformers.utils import logging
+
+from ....distributed.strategy import DistributedStrategy, NoOpStrategy
+from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
+
+
+try:
+ from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+
+ has_fused_rope = True
+except ImportError:
+ has_fused_rope = False
+ print("Not using HPU fused kernel for apply_rotary_pos_emb")
+
+
+try:
+ from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+ print("Not using HPU fused scaled dot-product attention kernel.")
+ FusedSDPA = None
+
+try:
+ from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm
+
+ has_fused_rms_norm = True
+except ImportError:
+ has_fused_rms_norm = False
+ print("Not using HPU fused kernel for RMSNorm")
+
+import habana_frameworks.torch.core as htcore
+
+
+logger = logging.get_logger(__name__)
+
+
+class GaudiGemma2RotaryEmbedding(torch.nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[Gemma2Config] = None,
+ ):
+ super().__init__()
+
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+                "`Gemma2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.45"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def _dynamic_frequency_update(self, seq_len, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ # seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(seq_len, device=x.device)
+
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ if self.attention_scaling == 1.0:
+ return (
+ self._cos_cached[:seq_len].to(dtype=x.dtype),
+ self._sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+ else:
+ return (
+ self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+ self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+ )
+
+
+def gaudi_gemma2_repeat_kv(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ n_rep: int,
+):
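+    """
+    Gaudi-friendly variant of repeat_kv: rather than materializing n_rep copies of the
+    key/value tensors, reshape keys/values to (batch, num_key_value_heads, 1, kv_len, head_dim)
+    and queries to (batch, num_key_value_heads, n_rep, q_len, head_dim) so that the grouped
+    attention matmuls broadcast over the shared KV heads. The attention mask gets a matching
+    singleton group dimension.
+    """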
+ batch, num_key_value_heads, kv_len, head_dim = key_states.shape
+ if n_rep == 1 or num_key_value_heads == 1:
+ return query_states, key_states, value_states, attention_mask
+
+ new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
+ key_states = key_states.reshape(new_kv_shape)
+ value_states = value_states.reshape(new_kv_shape)
+
+ batch, _, q_len, head_dim = query_states.shape
+ new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
+ query_states = query_states.reshape(new_q_shape)
+
+ if attention_mask is not None:
+ # Add groups dim and set to 1
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return query_states, key_states, value_states, attention_mask
+
+
+class Matmul(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return torch.matmul(x, y)
+
+
+class KVCache(torch.nn.Module):
+ def __init__(self):
+ super(KVCache, self).__init__()
+ self.cache = None
+ self.inp_seq_len = -1
+
+ def allocate(self, inp_seq_len, dtype, device, shape):
+ if self.cache is None or self.cache.shape != shape:
+ self.inp_seq_len = inp_seq_len
+ self.cache = torch.zeros(shape, dtype=dtype, device=device)
+ else:
+ assert (
+ self.inp_seq_len == inp_seq_len
+ ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+ self.cache.fill_(0)
+
+ def update(self, prev, cur, dim, idx, inp_seq_len):
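+        # Three cases:
+        #   - prev and cur have the same shape: copy cur over the whole cache,
+        #   - cur is a prompt (length > 1) that fits in prev: copy it into the first inp_seq_len slots,
+        #   - cur is a single decode token: write it at index idx - 1, or concatenate if no index is given.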
+ orig_cur = cur
+ if prev.shape == cur.shape:
+ prev.copy_(cur)
+ return orig_cur
+ if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+ # Initialize
+ prev[:, :, :inp_seq_len, :].copy_(cur)
+ return orig_cur
+ assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
+ if idx is not None:
+ prev.index_copy_(dim, idx - 1, cur)
+ return prev
+ else:
+ return torch.cat((prev, cur), dim=dim)
+
+ def get_shape(self):
+ if self.cache is None:
+ return None
+ return self.cache.shape
+
+ def forward(self, cur, dim, idx):
+ return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
+
+
+class GaudiGemma2Attention(Gemma2Attention):
+ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+
+ self.matmul_qk = Matmul()
+ self.matmul_av = Matmul()
+ self.k_cache = KVCache()
+ self.v_cache = KVCache()
+ self.inp_seq_len = -1
+ self.norm_factor = 1.0 / math.sqrt(self.head_dim)
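+        # Query tile size used by gaudi_flash_attn_v1 when splitting very long prompts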
+ self.block_size = 4096
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
+ device = self.k_proj.weight.device
+ dtype = self.config.torch_dtype
+ self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
+ self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
+
+ def update_sincos_cache(self, seq_len):
+        # Call rotary emb forward() to update the cos/sin cache when inferring beyond self.max_position_embeddings.
+        # This avoids creating these caches during the actual model forward pass, which
+        # reduces memory consumption and improves performance.
+ if seq_len > self.max_position_embeddings:
+ self.max_position_embeddings = seq_len
+ _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
+
+ def reorder(self, tensor, beam_idx, dim_a, dim_b):
+ updated = tensor.index_select(0, beam_idx)
+ tensor.copy_(updated)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ if self.k_cache.cache is None:
+ return (None, None)
+
+ head_dim = self.k_cache.cache.size(-1)
+ seq_length = self.k_cache.cache.size(-2)
+ self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim)
+ self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
+ return (self.k_cache.cache.shape, self.v_cache.cache.shape)
+
+ def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size):
+ """
+        Gaudi version of Flash Attention V1 to support long sequences in the prompt phase.
+        The causal mask is not supported in this optimization.
+ """
+ q_len = query_layer.size(-2)
+ q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
+ q_padding = q_tiles * q_block_size - q_len
+ query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
+ if attention_mask is not None:
+ attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)
+
+ row_o_list = []
+ for i in range(q_tiles):
+ s, e = i * q_block_size, (i + 1) * q_block_size
+ row_q = query_layer[:, :, s:e, :]
+ row_mask = attention_mask[:, :, s:e, :]
+ attn_output_partial = FusedSDPA.apply(row_q, key_layer, value_layer, row_mask, dropout_rate, False, None)
+ row_o_list.append(attn_output_partial)
+ attn_output = torch.cat(row_o_list, dim=-2)
+
+ if q_padding != 0:
+ attn_output = attn_output[:, :, :-q_padding, :]
+
+ return attn_output
+
+ def pre_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ The only differences are:
+        - add new arg token_idx
+        - optimize KV cache
+        - add new arg attn_softmax_bf16
+        - add new arg reuse_cache
+        - add new arg use_flash_attention
+        - add new arg flash_attention_recompute
+ """
+ if "padding_mask" in kwargs:
+ logger.warning_once(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if token_idx is None:
+ if hasattr(past_key_value, "get_usable_length"):
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ else:
+ kv_seq_len += past_key_value[0].shape[-2]
+ else:
+ if reuse_cache and not isinstance(past_key_value[0], torch.Tensor):
+ kv_seq_len = past_key_value[0][-2]
+ else:
+ kv_seq_len = past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)
+
+ if use_cache:
+ # reuse k, v, self_attention
+ if reuse_cache:
+ key_states = self.k_cache(key_states, 2, token_idx)
+ value_states = self.v_cache(value_states, 2, token_idx)
+ past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
+ else:
+ if past_key_value is None:
+ past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
+ past_value = torch.zeros(
+ key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
+ )
+ past_key_value = (past_key, past_value)
+ key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
+ value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
+ if token_idx is None:
+ past_key_value = (key_states, value_states)
+
+ if cache_idx is not None and q_len == 1:
+ key_states = key_states[:, :, :cache_idx, :]
+ value_states = value_states[:, :, :cache_idx, :]
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, :, :, :cache_idx]
+ kv_seq_len = key_states.shape[-2]
+ else:
+ past_key_value = None
+
+ if use_flash_attention and FusedSDPA:
+ import habana_frameworks.torch.hpu as ht
+
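+            # FusedSDPA takes the softmax mode as a string: "fast" selects the faster,
+            # reduced-precision softmax, "None" keeps the default behavior.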
+ softmax_mode = "fast" if flash_attention_fast_softmax else "None"
+
+ if q_len == 1:
+ # next token
+ with ht.sdp_kernel(enable_recompute=False):
+ attn_output = FusedSDPA.apply(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
+ )
+ else:
+ # first token
+ if flash_attention_causal_mask:
+ # causal masking on first token requires inputs to be of the same length
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
+ else:
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ if q_len > 16384:
+ attn_output = self.gaudi_flash_attn_v1(
+ query_states, key_states, value_states, attention_mask, 0.0, self.block_size
+ )
+ htcore.mark_step()
+ else:
+ attn_output = FusedSDPA.apply(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
+ )
+
+ else:
+ query_states, key_states, value_states, attention_mask = gaudi_gemma2_repeat_kv(
+ query_states, key_states, value_states, attention_mask, self.num_key_value_groups
+ )
+
+ attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ if cache_position is not None:
+ causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+ query_states.dtype
+ )
+ attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = self.matmul_av(attn_weights, value_states)
+ attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
+ # Return only past key value shapes and not the tensors during decode phase (q len is 1)
+            # to avoid making past key values persistent output tensors of HPU graphs.
+ past_key_value = (past_key_value[0].shape, past_key_value[1].shape)
+
+ return attn_output, attn_weights, past_key_value
+
+ def attention_all_reduce(self, attn_output):
+ if hasattr(self.o_proj, "all_reduce"):
+ self.o_proj.all_reduce(attn_output)
+
+ def post_attn_forward(self, attn_output):
+ if hasattr(self.o_proj, "post_all_reduce"):
+ self.o_proj.post_all_reduce(attn_output)
+ return attn_output
+
+
+class GaudiGemma2MLP(Gemma2MLP):
+ def pre_mlp_forward(self, x):
+ inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
+ output = self.down_proj(inputs)
+ return output
+
+ def mlp_all_reduce(self, x):
+ if hasattr(self.down_proj, "all_reduce"):
+ self.down_proj.all_reduce(x)
+
+ def post_mlp_forward(self, x):
+ if hasattr(self.down_proj, "post_all_reduce"):
+ return self.down_proj.post_all_reduce(x)
+ return x
+
+
+class GaudiGemma2DecoderLayer(Gemma2DecoderLayer):
+ def __init__(self, config: Gemma2Config, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.self_attn = GaudiGemma2Attention(config, layer_idx)
+ self.mlp = GaudiGemma2MLP(config)
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ return self.self_attn.reorder_kv_cache(beam_idx)
+
+ def update_sincos_cache(self, seq_len):
+ self.self_attn.update_sincos_cache(seq_len)
+
+ def pre_attn(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ )
+ return hidden_states, attn_weights, present_key_value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Copied from GemmaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
+ The only differences are:
+        - add new arg token_idx
+ """
+ residual = hidden_states
+
+ hidden_states, self_attn_weights, present_key_value = self.pre_attn(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ use_cache,
+ cache_position,
+ token_idx,
+ attn_softmax_bf16,
+ reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ )
+
+ self.self_attn.attention_all_reduce(hidden_states)
+
+ hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
+
+ self.mlp.mlp_all_reduce(hidden_states)
+
+ hidden_states = self.post_mlp(hidden_states, residual)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ def post_attn_pre_mlp(self, hidden_states, residual):
+ hidden_states = self.self_attn.post_attn_forward(hidden_states)
+ hidden_states = self.post_attention_layernorm(hidden_states)
+
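+        # Training uses an out-of-place residual add so autograd can track it; inference
+        # adds in place to avoid allocating an extra tensor per layer.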
+ if self.training:
+ hidden_states = hidden_states + residual
+ residual = hidden_states
+ else:
+ residual.add_(hidden_states)
+ hidden_states = residual
+
+ residual = hidden_states
+ hidden_states = self.pre_feedforward_layernorm(hidden_states)
+ hidden_states = self.mlp.pre_mlp_forward(hidden_states)
+ return hidden_states, residual
+
+ def post_mlp(self, hidden_states, residual):
+ hidden_states = self.mlp.post_mlp_forward(hidden_states)
+ hidden_states = self.post_feedforward_layernorm(hidden_states)
+
+ if self.training:
+ hidden_states = hidden_states + residual
+ else:
+ residual.add_(hidden_states)
+ hidden_states = residual
+
+ return hidden_states
+
+
+class GaudiGemma2Model(Gemma2Model):
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ for layer in self.layers:
+ layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers)
+
+ def update_sincos_cache(self, seq_len):
+ for layer in self.layers:
+ layer.update_sincos_cache(seq_len)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ lazy_mode: Optional[bool] = True,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ """
+ Copied from GemmaModel.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
+ The only differences are:
+        - add new arg token_idx
+ - replace _update_causal_mask with _gaudi_prepare_4d_causal_attention_mask
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
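+        # Force the eager attention path so the Gaudi 4D causal mask prepared below is used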
+ self._attn_implementation = "eager"
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ ignore_cache_position = True # Ignoring cache position for HPU
+ use_new_cache = False # Ignoring new Cache path for HPU
+
+ past_seen_tokens = 0
+
+ if past_key_values is not None and use_cache: # kept for BC (cache positions)
+ if reuse_cache:
+ if isinstance(past_key_values[0][0], torch.Tensor):
+ past_seen_tokens = past_key_values[0][0].shape[2]
+ else:
+ past_seen_tokens = past_key_values[0][0][2]
+ else:
+ if use_new_cache:
+ if not isinstance(past_key_values, StaticCache):
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+ else:
+ past_seen_tokens = past_key_values[0][0].shape[2]
+
+ if ignore_cache_position is False:
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None and cache_position:
+ position_ids = cache_position.unsqueeze(0)
+ else:
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device
+ )
+ position_ids = position_ids.unsqueeze(0)
+ cache_position = None
+
+ # HPU specific mask generation
+ if ignore_cache_position:
+ causal_mask = _gaudi_prepare_4d_causal_attention_mask(
+ attention_mask,
+ input_ids.shape if input_ids is not None else (batch_size, seq_length),
+ inputs_embeds,
+ past_seen_tokens,
+ )
+ else:
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
+
+ # embed positions
+ hidden_states = inputs_embeds
+
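+        # Gemma2 scales the input embeddings by sqrt(hidden_size) before the decoder stack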
+ normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=inputs_embeds.device)
+ hidden_states = hidden_states * normalizer
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if not use_new_cache else None
+
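+        # In HPU lazy mode, mark_step() flushes the ops queued so far, breaking the graph
+        # before (and, inside the loop, between) decoder layers.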
+ if lazy_mode:
+ htcore.mark_step()
+
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ if (
+ lazy_mode
+ and not self.training
+ and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
+ ):
+ htcore.mark_step()
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ None,
+ attn_softmax_bf16,
+ False,
+ use_flash_attention,
+ flash_attention_recompute,
+ flash_attention_causal_mask,
+ flash_attention_fast_softmax,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=None if past_key_values is None else past_key_values[layer_idx],
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = (
+ next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+ )
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class GaudiGemma2ForCausalLM(Gemma2ForCausalLM):
+ def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy):
+ config.parallel_strategy = parallel_strategy
+ super().__init__(config)
+
+ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
+ self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
+
+ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
+ return self.model.reorder_kv_cache(beam_idx)
+
+ def update_sincos_cache(self, seq_len):
+ self.model.update_sincos_cache(seq_len)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ trim_logits: Optional[bool] = False,
+ attn_softmax_bf16: Optional[bool] = False,
+ reuse_cache: Optional[bool] = False,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ flash_attention_causal_mask: Optional[bool] = False,
+ flash_attention_fast_softmax: Optional[bool] = False,
+ cache_idx: int = None,
+ lazy_mode: Optional[bool] = True,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ """
+ Inherits from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
+ The only differences are:
+        - add new arg token_idx
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ token_idx=token_idx,
+ attn_softmax_bf16=attn_softmax_bf16,
+ reuse_cache=reuse_cache,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ flash_attention_causal_mask=flash_attention_causal_mask,
+ flash_attention_fast_softmax=flash_attention_fast_softmax,
+ cache_idx=cache_idx,
+ lazy_mode=lazy_mode,
+ )
+
+ hidden_states = outputs[0]
+ _, seq_len, _ = hidden_states.shape
+
+ if seq_len > 1 and trim_logits and not self.training:
+ if token_idx is not None:
+ hidden_states = hidden_states.index_select(1, token_idx - 1)
+ else:
+ hidden_states = hidden_states[:, -1, :]
+
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ """
+ Inherits from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
+ The only differences are:
+        - add new arg token_idx
+        - add token_idx into model_inputs
+        - from step 2 onwards, when the KV cache is enabled, slice next_input_ids from input_ids based on token_idx
+        - from step 2 onwards, when the KV cache is enabled, slice next_position_ids from position_ids based on token_idx
+ """
+
+ reuse_cache = kwargs.get("reuse_cache")
+ bucket_internal = kwargs.get("bucket_internal")
+
+ token_idx = kwargs.get("token_idx", None)
+
+ if past_key_values is not None:
+ if token_idx is None:
+ if inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif (
+ input_ids.shape[1] != cache_position.shape[0]
+ ): # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+ else:
+ # past_length += token_idx
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ elif (reuse_cache or bucket_internal) and token_idx is not None:
+            # The KV cache is pre-allocated (reuse_cache) or will be padded (bucket_internal),
+            # so for the first token we can slice the inputs up to token_idx for the forward pass.
+ input_ids = input_ids[:, :token_idx]
+ attention_mask = attention_mask[:, :token_idx]
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ if token_idx is None:
+ if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+ # generation with static cache
+ past_length = past_key_value.get_seq_length()
+ input_ids = input_ids[:, past_length:]
+ position_ids = position_ids[:, past_length:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids.contiguous()}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ "token_idx": token_idx,
+ "trim_logits": kwargs.get("trim_logits"),
+ "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
+ "reuse_cache": reuse_cache,
+ "use_flash_attention": kwargs.get("use_flash_attention"),
+ "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+ "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
+ "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
+ "cache_idx": kwargs.get("cache_idx"),
+ "lazy_mode": kwargs.get("lazy_mode"),
+ }
+ )
+ return model_inputs
+
+
+def apply_customized_rope(q, k, cos, sin, position_ids):
+ if q.device.type == "hpu" and has_fused_rope:
+ # TODO: remove `.clone()` when it is fixed in SynapseAI
+ if k.dtype == torch.bfloat16:
+ return FusedRoPE.apply(
+ q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+ ), FusedRoPE.apply(
+ k,
+ cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+ sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+ position_ids,
+ )
+ return FusedRoPE.apply(
+ q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+ ), FusedRoPE.apply(
+ k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+ )
+ else:
+ # keep the same implementation as Transformers v4.37.2
+ return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 9a6bdbb8f2..e217e340f2 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -48,6 +48,7 @@
("adept/persimmon-8b-base", 4, False, 366.73968820698406, False),
("Qwen/Qwen1.5-7B", 4, False, 490.8621617893209, False),
("google/gemma-7b", 1, False, 109.70751574382221, True),
+ ("google/gemma-2-9b", 1, False, 92.302359446567, True),
("state-spaces/mamba-130m-hf", 1536, False, 5385.511100161605, False),
("Deci/DeciLM-7B", 1, False, 120, False),
("Qwen/Qwen2-7B", 512, False, 9669.45787, True),
@@ -85,6 +86,7 @@
("meta-llama/Llama-2-70b-hf", 8, 1, 64.10514998902435),
("meta-llama/Meta-Llama-3-70B-Instruct", 8, 1, 64),
("facebook/opt-66b", 2, 1, 28.48069266504111),
+ ("google/gemma-2-9b", 8, 1, 110.12610917383735),
],
"torch_compile": [
("meta-llama/Llama-2-7b-hf", 102.27823420713148),
@@ -103,6 +105,7 @@
"bigcode/starcoder": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_twice():\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_thrice():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_four_times():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n ',
"bigcode/starcoder2-3b": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_with_name(name):\n print("Hello World, " + name)\n\ndef print_hello_world_with_name_and_age(name, age):\n print("Hello World, " + name + ", " + str(age))\n\ndef print_hello_world_with_name_and_age_and_gender(name, age, gender):\n print("Hello',
"google/gemma-7b": "DeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and it is compatible with the existing PyTorch ecosystem. DeepSpeed is designed to be easy to use, and it provides a number of features that make it easy to train large-scale models.\n\nDeepSpeed is a machine learning framework that enables training of large-scale models on commodity hardware. It is designed to be a drop-in replacement for PyTorch, and",
+ "google/gemma-2-9b": "DeepSpeed is a machine learning framework that enables training of large-scale deep learning models on a single GPU or across multiple GPUs. It is designed to be easy to use and highly scalable, making it a powerful tool for researchers and practitioners working with large-scale deep learning models.\n\nDeepSpeed is built on top of PyTorch, a popular deep learning framework, and provides a set of tools and libraries that make it easy to train large-scale models. It includes features such as zero-shot inference, which allows models to be",
"meta-llama/Llama-2-7b-hf": "DeepSpeed is a machine learning framework for deep learning. It is designed to be fast and efficient, while also being easy to use. DeepSpeed is based on the TensorFlow framework, and it uses the TensorFlow library to perform computations.\nDeepSpeed is a deep learning framework that is designed to be fast and efficient. It is based on the TensorFlow library and uses the TensorFlow library to perform computations. DeepSpeed is designed to be easy to use and to provide a high level of flex",
"mistralai/Mistral-7B-v0.1": "DeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system.\n\nDeepSpeed is a machine learning framework that accelerates training of large models on a single machine or distributed systems. It is designed to be compatible with PyTorch and TensorFlow, and can be used to train models on a single machine or on a distributed system",
"mistralai/Mixtral-8x7B-v0.1": "DeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## Introduction\n\nDeepSpeed is a machine learning framework that enables training of large models on a single machine with a single GPU. It is designed to be easy to use and efficient, and it can be used to train models on a variety of tasks.\n\n## What is DeepSpeed",