Enable flash attention and reuse_cache for gemma
Add missing flag handling to gemma
   --reuse_cache
   --use_flash_attention
   --flash_attention_recompute
   --flash_attention_causal_mask
atakaha committed Oct 30, 2024
1 parent 03fa6dd commit 29d4814
Showing 1 changed file with 32 additions and 26 deletions.
optimum/habana/transformers/models/gemma/modeling_gemma.py: 32 additions, 26 deletions
@@ -214,7 +214,7 @@ def pre_attn_forward(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: int = None,
+        cache_idx: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
@@ -407,23 +407,23 @@ def pre_attn(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: int = None,
+        cache_idx: Optional[int] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         hidden_states = self.input_layernorm(hidden_states)
         hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
-            hidden_states,
-            attention_mask,
-            position_ids,
-            past_key_value,
-            output_attentions,
-            use_cache,
-            cache_position,
-            token_idx,
-            attn_softmax_bf16,
-            reuse_cache,
-            use_flash_attention,
-            flash_attention_recompute,
-            flash_attention_causal_mask,
-            cache_idx,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            token_idx=token_idx,
+            attn_softmax_bf16=attn_softmax_bf16,
+            reuse_cache=reuse_cache,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+            flash_attention_causal_mask=flash_attention_causal_mask,
+            cache_idx=cache_idx,
         )
         return hidden_states, attn_weights, present_key_value
@@ -443,7 +443,7 @@ def forward(
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
-        cache_idx: int = None,
+        cache_idx: Optional[int] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
        Copied from GemmaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
@@ -453,16 +453,16 @@
         residual = hidden_states

         hidden_states, self_attn_weights, present_key_value = self.pre_attn(
-            hidden_states,
-            attention_mask,
-            position_ids,
-            past_key_value,
-            output_attentions,
-            use_cache,
-            cache_position,
-            token_idx,
-            attn_softmax_bf16,
-            reuse_cache,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            token_idx=token_idx,
+            attn_softmax_bf16=attn_softmax_bf16,
+            reuse_cache=reuse_cache,
             use_flash_attention=use_flash_attention,
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
@@ -717,6 +717,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
+        reuse_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -746,6 +747,7 @@
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
+            reuse_cache=reuse_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -859,9 +861,13 @@ def prepare_inputs_for_generation(
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
+                "reuse_cache": kwargs.get("reuse_cache"),
                 "attention_mask": attention_mask,
                 "num_logits_to_keep": num_logits_to_keep,
                 "token_idx": token_idx,
+                "use_flash_attention": kwargs.get("use_flash_attention"),
+                "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+                "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
             }
         )
         return model_inputs
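
Usage note (not part of the commit): below is a minimal sketch of how the new flags might be exercised end to end, relying on the kwargs.get(...) lookups added to prepare_inputs_for_generation above. It assumes a Gaudi/SynapseAI environment with optimum-habana installed; the checkpoint id, prompt, and flag values are placeholders, and the assumption that generate() forwards these extra keyword arguments through to prepare_inputs_for_generation is inferred from the diff, not shown by it. Real runs typically also configure static shapes and HPU graphs as in optimum-habana's text-generation example, where these options are exposed as the CLI flags listed in the commit message.

# Hypothetical usage sketch on a Gaudi device (not part of this commit).
import torch
import habana_frameworks.torch.core  # noqa: F401 (registers the "hpu" device; Gaudi-only)
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swap in the Gaudi model implementations, incl. the Gemma ones patched here

model_id = "google/gemma-2b"  # assumed checkpoint, used only for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("hpu").eval()

inputs = tokenizer("The key benefit of flash attention is", return_tensors="pt").to("hpu")
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    reuse_cache=True,                  # reuse the internally allocated KV-cache buffers across decode steps
    use_flash_attention=True,          # route attention through the fused SDPA kernel
    flash_attention_recompute=True,    # memory-saving recompute mode of the fused kernel
    flash_attention_causal_mask=True,  # let the kernel apply the causal mask itself
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))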
