gpt_bigcode: added internal bucketing fix (#1526)
mgonchar authored Dec 9, 2024
1 parent 491e8ab commit 9a4c6de
Showing 2 changed files with 45 additions and 19 deletions.
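
In short, the attention forward now takes a `cache_idx` hint: when it is set and the query length is 1 (the single-token decode stage with internal bucketing), the new key/value are written in place into the preallocated `layer_past` buffer at `token_idx - 1`, and the key, value, and attention mask are sliced down to the first `cache_idx` positions so attention only runs over the active bucket. The snippet below is a minimal, self-contained sketch of that idea with toy shapes; the helper name `bucketed_decode_step` and all tensor sizes are illustrative and not taken from the repository.

import torch

# Toy shapes only: one sequence, a multi-query cache stored as (batch, max_seq_len, 2 * head_dim),
# mirroring the key/value split used in the diff.
head_dim, max_seq_len = 4, 16
layer_past = torch.zeros(1, max_seq_len, 2 * head_dim)    # preallocated KV cache for one layer
attention_mask = torch.zeros(1, 1, 1, max_seq_len)        # additive mask over cached positions


def bucketed_decode_step(key, value, token_idx, cache_idx):
    # Decode-stage path sketched from the diff: write the current token's K/V in place
    # at position token_idx - 1, then slice cache and mask to the active bucket.
    past_key, past_value = layer_past.split((head_dim, head_dim), dim=-1)
    key = past_key.index_copy_(1, token_idx - 1, key)       # in-place write into the cache
    value = past_value.index_copy_(1, token_idx - 1, value)
    return key[:, :cache_idx, :], value[:, :cache_idx, :], attention_mask[:, :, :, :cache_idx]


k, v, mask = bucketed_decode_step(
    key=torch.randn(1, 1, head_dim),
    value=torch.randn(1, 1, head_dim),
    token_idx=torch.tensor([5]),   # the 5th generated position
    cache_idx=8,                   # current bucket boundary
)
print(k.shape, v.shape, mask.shape)  # torch.Size([1, 8, 4]) torch.Size([1, 8, 4]) torch.Size([1, 1, 1, 8])
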
@@ -294,6 +294,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
@@ -334,18 +335,37 @@ def forward(

key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

-        if layer_past is not None:
+        _, q_len, _ = hidden_states.size()
+        bucket_internal_decode_stage = cache_idx is not None and q_len == 1
+
+        if not bucket_internal_decode_stage:
+            if layer_past is not None:
+                past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
+                if token_idx is not None:
+                    # Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled.
+                    key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1))
+                    value = past_value.index_add(
+                        1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1)
+                    )
+                else:
+                    key = torch.cat((past_key, key), dim=-2)
+                    value = torch.cat((past_value, value), dim=-2)
+            present = torch.cat((key, value), dim=-1) if use_cache else None
+        else:
+            assert token_idx is not None, "Invalid parameters: token_idx is None at decode stage with bucket_internal"
+            assert (
+                layer_past is not None
+            ), "Invalid parameters: layer_past is None at decode stage with bucket_internal"
+
             past_key, past_value = layer_past.split((self.head_dim, self.head_dim), dim=-1)
-            if token_idx is not None:
-                # Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled.
-                key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1))
-                value = past_value.index_add(
-                    1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1)
-                )
-            else:
-                key = torch.cat((past_key, key), dim=-2)
-                value = torch.cat((past_value, value), dim=-2)
-        present = torch.cat((key, value), dim=-1) if use_cache else None
+            key = past_key.index_copy_(1, token_idx - 1, key)
+            value = past_value.index_copy_(1, token_idx - 1, value)
+            present = layer_past
+
+        if bucket_internal_decode_stage:
+            key = key[:, :cache_idx, :]
+            value = value[:, :cache_idx, :]
+            attention_mask = attention_mask[:, :, :, :cache_idx]

if not output_attentions and head_mask is None and use_flash_attention:
# Difference with the original implementation: there is no need to transpose the key here,
@@ -367,6 +387,11 @@ def forward(
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

if bucket_internal_decode_stage:
# Return only past key value shapes and not the tensors during decode phase (q len is 1)
# to avoid making past key values as persistent output tensors of HPU graphs.
present = present.shape

outputs = (attn_output, present)
if output_attentions:
if self.multi_query:
@@ -392,6 +417,7 @@ def gaudi_gpt_bigcode_block_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
@@ -413,6 +439,7 @@ def gaudi_gpt_bigcode_block_forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@@ -475,6 +502,7 @@ def gaudi_gpt_bigcode_model_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
"""
Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -638,6 +666,7 @@ def gaudi_gpt_bigcode_model_forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)

hidden_states = outputs[0]
@@ -750,6 +779,7 @@ def prepare_inputs_for_generation(
"flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False),
"cache_idx": kwargs.get("cache_idx", None),
}
)
return model_inputs
@@ -775,6 +805,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: Optional[int] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -803,6 +834,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
hidden_states = transformer_outputs[0]

10 changes: 2 additions & 8 deletions tests/test_text_generation_example.py
@@ -27,21 +27,15 @@
("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354, False),
("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076, True),
("tiiuae/falcon-40b", 1, True, 25.202450111088346, False),
-(
-    "bigcode/starcoder",
-    256,
-    True,
-    6846.575763562658,
-    False,
-),  # TODO: Enable check_output after model bigcode/starcoder is fixed
+("bigcode/starcoder", 256, True, 6846.575763562658, True),
("Salesforce/codegen2-1B", 1, False, 446.4029486883532, False),
("mosaicml/mpt-30b", 1, False, 36.06464336116623, False),
("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782, True),
("mistralai/Mixtral-8x7B-v0.1", 1, False, 23.7931001677926, True),
("microsoft/phi-2", 1, False, 224.72307766211117, False),
("meta-llama/Meta-Llama-3-8B", 1, True, 129, False),
("meta-llama/Llama-2-7b-hf", 512, True, 12808, False),
-("meta-llama/Llama-2-7b-hf", 512, False, 8711, False),  # in some cases like TGI, reuse_cache isnt used
+("meta-llama/Llama-2-7b-hf", 512, False, 8711, False),  # in some cases like TGI, reuse_cache isn't used
("stabilityai/stablelm-2-12b", 1, False, 74.8904496532218, False),
("codellama/CodeLlama-34b-hf", 1, True, 32.644, False),
("bigcode/starcoder2-3b", 1, False, 261.07213776344133, True),
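
For readers comparing the two cache-update styles in the attention hunk above: the non-bucketed path keeps the existing out-of-place `index_add` so intermediate tensors are not lost when HPU graphs are enabled, and it concatenates a fresh `present`, while the new bucket-internal decode path writes straight into the preallocated cache with the in-place `index_copy_` and reuses `layer_past` as `present`. The toy check below, with made-up shapes, only shows that both updates place the same values at `token_idx - 1`; it is not code from the repository.

import torch

past = torch.randn(1, 8, 4)      # stand-in for one half of layer_past (e.g. the key cache)
new = torch.randn(1, 1, 4)       # current token's projection
token_idx = torch.tensor([3])

# Out-of-place update (non-bucketed path): returns a new tensor, `past` itself is untouched.
out_of_place = past.index_add(1, token_idx - 1, new - torch.index_select(past, 1, token_idx - 1))

# In-place update (bucket_internal decode path): mutates `past` directly and returns it.
in_place = past.index_copy_(1, token_idx - 1, new)

assert torch.allclose(out_of_place, in_place)       # same resulting cache contents
assert in_place.data_ptr() == past.data_ptr()       # only the in-place variant reuses the buffer
assert out_of_place.data_ptr() != past.data_ptr()
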
