
XLA optimized Implementation of StaticCache with Tensor Indexing API #31129

Closed
wants to merge 14 commits

Conversation


@huzama huzama commented May 30, 2024

Use the index_copy_ method to update the static cache in place and avoid recompilation during each iteration in XLA

What does this PR do?

This PR avoids repeated recompilation at each iteration under XLA by performing in-place updates on the static cache. Specifically, it replaces direct tensor-indexing assignments with index_copy_ method calls, so that cache updates are executed without triggering full graph recompilation, improving runtime performance.

Code Changes

Original Code:

k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

Updated Code:

k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
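
For reference, both forms write the same values into the cache; here is a minimal standalone check with plain PyTorch (illustrative only, not part of the PR, using made-up shapes):

import torch

# Hypothetical cache layout: (batch, heads, max_cache_len, head_dim)
k_out_a = torch.zeros(1, 2, 8, 4)
k_out_b = torch.zeros(1, 2, 8, 4)
key_states = torch.randn(1, 2, 3, 4)        # three new positions to write
cache_position = torch.arange(5, 8)         # write into slots 5, 6, 7

k_out_a[:, :, cache_position] = key_states            # direct indexing assignment
k_out_b.index_copy_(2, cache_position, key_states)    # in-place index_copy_

assert torch.equal(k_out_a, k_out_b)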


@zucchini-nlp @gante

Edit:

  • Discussion: #31126

  • The code was further updated to perform out-of-place updates that rebind the cache list entries, instead of in-place updates, to avoid recompilation in XLA.

@huzama huzama marked this pull request as draft May 30, 2024 05:13
@huzama huzama marked this pull request as ready for review May 30, 2024 09:55
@huzama huzama marked this pull request as draft May 31, 2024 04:52
@LysandreJik (Member)

cc @ArthurZucker as well

@ArthurZucker ArthurZucker (Collaborator) left a comment

@huzama (Author) commented Jun 10, 2024

Summary of Changes

@tengomucho, I conducted experiments with both .index_copy_ and Python slicing methods on the GPU and found no significant performance differences. Consequently, I have reverted the StaticCache code to its original form.

In LLaMA, the KV-cache tensor slices are updated in-place; this leads to recompilation events every time a token is generated. To address this issue, we use index tensors and tensor.index_copy() ops to replace the in-place slice updates. Attention masks and output sequences also benefit from the same optimization. [Ref]

Based on @ArthurZucker's suggestions and the above-mentioned guide, I created a new class, StaticCacheXLA, with several key differences from the plain StaticCache class:

  1. Updating the Cache with Out-of-Place index_copy Operations
    XLA lazy tensors perform well with out-of-place operations (a combined sketch of both patterns follows this list):

    k_out = self.key_cache[layer_idx]
    v_out = self.value_cache[layer_idx]
    k_out = k_out.index_copy(2, cache_position, key_states)
    v_out = v_out.index_copy(2, cache_position, value_states)
    self.key_cache[layer_idx] = k_out
    self.value_cache[layer_idx] = v_out

  2. Get seq_len Out-of-Place with index_select

    item = key_cache.index_select(0, torch.tensor(0, device=device))
    head = item.index_select(1, torch.tensor(0, device=device))
    return head.any(dim=-1).sum()
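
Putting the two patterns above together, a minimal self-contained sketch of such a cache (illustrative only; the actual StaticCacheXLA in this PR plugs into the transformers Cache API and derives its shapes from the model config):

import torch

class StaticCacheXLASketch:
    """Illustrative stand-in for StaticCacheXLA: out-of-place updates only."""

    def __init__(self, num_layers, batch_size, num_heads, max_cache_len, head_dim,
                 dtype=torch.bfloat16, device="cpu"):
        shape = (batch_size, num_heads, max_cache_len, head_dim)
        self.key_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(shape, dtype=dtype, device=device) for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx, cache_position):
        # Out-of-place index_copy produces new (lazy) tensors, which XLA can trace
        # without recompiling; the per-layer list entries are then rebound to them.
        k_out = self.key_cache[layer_idx].index_copy(2, cache_position, key_states)
        v_out = self.value_cache[layer_idx].index_copy(2, cache_position, value_states)
        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def get_seq_length(self, layer_idx=0):
        # Same idea as item 2: read via index_select instead of Python slicing,
        # then count the cache slots that contain any non-zero value.
        key_cache = self.key_cache[layer_idx]
        idx0 = torch.tensor([0], device=key_cache.device)
        head = key_cache.index_select(0, idx0).index_select(1, idx0)
        return head.any(dim=-1).sum()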

Performance Improvements

Architecture Design

Using SPMD on Llama 3 (following the guide from PyTorch's high-performance Llama 2 blog) and a simple generate function with greedy decoding, I obtained the following results:

Results

These experiments were run on a TPU v3-8. The input sequence length was 256, and the maximum number of new tokens was 512. I can upload numbers on TPU v4-128 later.

As a result of these changes and optimizations, the new StaticCacheXLA class achieved a generation rate of 6.35 iterations per second at 3.1% TPU utilization, compared with 5.5 iterations per second at 2.7% TPU utilization for the original StaticCache.

  • StaticCacheXLA: 6.35 it/s, TPU utilization: 3.1%
  • StaticCache: 5.5 it/s, TPU utilization: 2.7%
  • Without Cache: 3 s/it, TPU utilization: 79%

Edit:
The significant disparity in TPU utilization arises from the difference between single token generation during cache-enabled runs and the full input processing during non-cache runs. Despite the low TPU utilization, cache-enabled generation is 10 times faster. Additionally, when employing in-place cache updates or Python slicing, the model recompiles the graph in proportion to the number of layers it contains. A similar pattern is observed with XLA on GPUs. I can provide further details, including the relevant code and performance metrics if needed.
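
One way to observe this recompilation behavior (a sketch, not part of the PR; it assumes the torch_xla debug-metrics module) is to dump the XLA metrics report between decoding steps and watch the CompileTime entry:

import torch_xla.debug.metrics as met

# Print XLA metrics between decoding steps. With in-place slicing updates the
# number of CompileTime samples keeps growing per generated token; with the
# out-of-place cache it should stay flat once the graph is warm.
print(met.metrics_report())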

@huzama huzama marked this pull request as ready for review June 10, 2024 12:12
@huzama huzama changed the title Use the Index_copy method to update static cache inplace Out-of-Place updates to StaticCache for XLA Jun 10, 2024
@ArthurZucker ArthurZucker (Collaborator) left a comment


In general this looks OK. We need to make sure generate can use it, so update the mapping to detect whether XLA is used and, if so, use this class. WDYT?

@huzama (Author) commented Jun 13, 2024

@ArthurZucker Thank you for the feedback. If I understand correctly, you are suggesting that we should dynamically map to the appropriate cache implementation based on whether XLA is available. Would the following update reflect your suggestion?

class StaticCacheDefault(Cache):
    """Default static cache implementation."""
    pass

class StaticCacheXLA(Cache):
    """XLA-optimized static cache implementation."""
    pass

# Determine which StaticCache implementation to use based on the availability of XLA.
StaticCache = StaticCacheXLA if is_torch_xla_available() else StaticCacheDefault
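
On the generate side, ArthurZucker's mapping suggestion could then look something like the following sketch (NEED_SETUP_CACHE_CLASSES_MAPPING is assumed to be the cache_implementation registry in the generation utilities at the time of this PR; the two classes are the ones defined in the snippet above):

from transformers.utils import is_torch_xla_available

# Hypothetical wiring: whenever torch_xla is available, cache_implementation="static"
# resolves to the XLA-friendly class instead of the default one.
NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCacheXLA if is_torch_xla_available() else StaticCacheDefault,
}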

@tengomucho tengomucho (Contributor) left a comment


Before merging this, I would suggest one more change: I think you can remove all methods except update and get_seq_length, which are the only two functions that differ from the parent class.

Anyway, this seems great; a quick test on a TPU showed me a 10% inference speedup on gemma-2b!

@huzama huzama requested a review from tengomucho June 14, 2024 01:33
@huzama huzama changed the title Out-of-Place updates to StaticCache for XLA XLA optimized Implementation of StaticCache with Tensor Indexing API Jun 14, 2024
@tengomucho tengomucho (Contributor) left a comment


LGTM!

@huzama huzama requested a review from tengomucho June 14, 2024 12:14
@gante gante (Member) left a comment


The core of the PR LGTM, thank you for working on it! 💛

Open questions before merging:

  1. Utilization in generate: as we are adding more Cache classes, we are noticing that automagically initializing a Cache inside generate is becoming tricky. From my end, all we need is to confirm that we can use this class as
xla_cache = StaticCacheXLA(...)
generate_outputs = model.generate(**inputs, past_key_values=xla_cache)
  2. Tests :D Can we add a test to confirm that the new class doesn't change the model outputs? We can also test that the API described in 1. works in that test (an example of a similar test: test_dynamic_cache_hard).

@huzama (Author) commented Jun 14, 2024

Hello @gante,

A quick test on GPUs (as my TPUs are running jobs) with the following script shows that it can be used this way and that it produces the same results as the original StaticCache.

from transformers import LlamaForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCacheXLA, StaticCache
import torch

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
).to(1)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

xla_cache = StaticCacheXLA(model.config, 1, 128, dtype=torch.bfloat16, device=1)
cache = StaticCache(model.config, 1, 128, dtype=torch.bfloat16, device=1)

input_ids = tokenizer("This is a test", return_tensors="pt").to(1)

out = model.generate(
    **input_ids, max_new_tokens=32, do_sample=False, past_key_values=cache
)

out_cache = model.generate(
    **input_ids, max_new_tokens=32, do_sample=False, past_key_values=xla_cache
)

torch.all(out == out_cache) 
# True
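
One possible shape for the test @gante asked for, based on the script above (a sketch only; the tiny test checkpoint and the standalone pytest style are assumptions, since the existing cache tests in transformers are unittest-based):

import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers.cache_utils import StaticCache, StaticCacheXLA

def test_static_cache_xla_matches_static_cache():
    # Hypothetical tiny checkpoint to keep the test fast; any Llama-style model works.
    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    model = LlamaForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("This is a test", return_tensors="pt")

    cache = StaticCache(model.config, 1, 64, device=model.device, dtype=model.dtype)
    xla_cache = StaticCacheXLA(model.config, 1, 64, device=model.device, dtype=model.dtype)

    out = model.generate(**inputs, max_new_tokens=8, do_sample=False, past_key_values=cache)
    out_xla = model.generate(**inputs, max_new_tokens=8, do_sample=False, past_key_values=xla_cache)

    # Both cache implementations must produce identical greedy generations.
    assert torch.equal(out, out_xla)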

@gante (Member) commented Jun 14, 2024

@huzama awesome! Can we add it as a test? 💛 After the test is added, I'm more than happy to approve the PR!

@tengomucho (Contributor)

@gante I even wonder if it would be possible just to patch the original StaticCache class, if they behave the same?

@gante (Member) commented Jun 14, 2024

@tengomucho if all that's needed is to replace slicing by .index_copy_, I'm happy with changing StaticCache itself -- as long as we add a comment like "k_out.index_copy_(2, cache_position, key_states) is equivalent to k_out[:, :, cache_position] = key_states, but with better generalized support", for readability

If that's the case, no new tests would be needed 🤗 @huzama

@huzama (Author) commented Jun 14, 2024

@gante I agree that combining both classes into a single implementation is a better approach. I'll add comments to explain the code and will push the changes with the merged implementation.

@gante gante (Member) left a comment


LGTM, thank you for iterating 🤗

@gante (Member) commented Jun 14, 2024

@huzama you'll need to run make fixup on your transformers root dir and push the changes to make our CI happy 🤗

@gante gante requested a review from ArthurZucker June 14, 2024 14:33
@gante (Member) commented Jun 14, 2024

(unrelated CI failure also observed on fresh PRs from main)

@tengomucho tengomucho (Contributor) left a comment


LGTM!

@ArthurZucker ArthurZucker (Collaborator) left a comment


LGTM, could you make sure you run the slow tests for Llama compile, Gemma compile, etc.?

Comment on lines +875 to +885
key_cache = self.key_cache[layer_idx]
device = key_cache.device

# index_select(dim, index) performs the same operation as item = tensor[..., index, ...]
# but it is used for better generality and flexibility.
# For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html

item = key_cache.index_select(0, torch.tensor(0, device=device))
head = item.index_select(1, torch.tensor(0, device=device))

return head.any(dim=-1).sum()
Collaborator

TBH this will be deprecated anyways!


@tengomucho (Contributor)

Hey @huzama, are you planning to clean this up to get it merged? Otherwise let me know; I will be happy to take over to make StaticCache more efficient on XLA!

tengomucho added a commit to huggingface/optimum-tpu that referenced this pull request Jul 8, 2024
This is actually a ripoff of the work originally done as a contribution
to transformers:

huggingface/transformers#31129

The original contribution has not been merged yet, but it shows lower
memory usage and better performance on XLA. So I think it's worth adding
it here.
@huzama (Author) commented Jul 9, 2024

@tengomucho Thank you for taking over. I have been running some other experiments, and I lost track of this pull request. I appreciate your help in making StaticCache more efficient on XLA!

tengomucho added a commit to huggingface/optimum-tpu that referenced this pull request Jul 9, 2024
This is actually a ripoff of the work originally done as a contribution
to transformers:

huggingface/transformers#31129

The original contribution has not been merged yet, but it shows lower
memory usage and better performance on XLA. So I think it's worth adding
it here, to be integrated on optimum-tpu.
@tengomucho (Contributor)

I opened another PR so I could push on the branch. Closing this!

@tengomucho tengomucho closed this Jul 9, 2024