XLA optimized Implementation of StaticCache with Tensor Indexing API
#31129
Conversation
cc @ArthurZucker as well
FYI @gante and @zucchini-nlp
…d recompilation during each iteration in XLA". This reverts commit 1ad0a9a.
…tion. This is necessary for XLA as tensors are not materialized yet
torch.arange(past_length) where past_length keeps changing causes recompilation in XLA
… isinstance(past_key_value, StaticCache)
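A small aside on the torch.arange(past_length) point above: the snippet below is illustrative only (not taken from the PR) and shows why a size that changes every step forces recompilation under a tracing backend such as torch_xla.

import torch

# Illustration only: past_length grows with every decoding step, so
# torch.arange(past_length) produces tensors of different shapes. A tracing
# backend such as torch_xla compiles one graph per distinct shape, i.e. one
# recompilation per decoding step.
for past_length in (5, 6, 7):
    positions = torch.arange(past_length)
    print(positions.shape)  # torch.Size([5]), then ([6]), then ([7]) -> three graphs under XLA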
Summary of Changes
@tengomucho, I conducted experiments with both
Based on @ArthurZucker's suggestions and the above-mentioned guide, I created a new class, StaticCacheXLA.
Performance Improvements

Architecture Design
Using SPMD on LLAMA 3, following the guide from PyTorch's high-performance LLAMA 2 blog, and a simple generate function with greedy decoding, I achieved the following results:

Results
These experiments were run on a TPU v3-8. The input sequence length was 256, and the maximum number of new tokens was 512. I can upload numbers for TPU v4-128 later. As a result of these changes and optimizations, the new StaticCacheXLA implementation avoids recompilation and shows better performance on XLA.
In general looks ok. We need to make sure generate can use it, so update the mapping to detect if XLA is used and then use this class. WDYT?
@ArthurZucker Thank you for the feedback. If I understand correctly, you are suggesting that we should dynamically map to the appropriate cache implementation based on whether XLA is available. Would the following update reflect your suggestion?

from transformers.cache_utils import Cache
from transformers.utils import is_torch_xla_available


class StaticCacheDefault(Cache):
    """Default static cache implementation."""
    pass


class StaticCacheXLA(Cache):
    """XLA-optimized static cache implementation."""
    pass


# Determine which StaticCache implementation to use based on the availability of XLA.
StaticCache = StaticCacheXLA if is_torch_xla_available() else StaticCacheDefault
Before merging this, I would suggest one more change: I think you can remove all methods except update and get_seq_length, as those are the only two functions that differ from the parent class.
Anyway, this seems great: a quick test on a TPU showed me a speedup of 10% in inference on gemma-2b!
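To make that suggestion concrete, here is a minimal sketch (not the PR's exact code) of the trimmed-down class. It assumes the class subclasses the existing StaticCache, that the per-layer cache tensors live in self.key_cache / self.value_cache with layout [batch, num_heads, max_cache_len, head_dim], and that cache_kwargs carries the cache_position index as in the parent implementation of that time.

import torch
from transformers.cache_utils import StaticCache


class StaticCacheXLA(StaticCache):
    """Sketch: only the two methods that differ from the parent are overridden."""

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # In-place index_copy_ along the sequence dimension instead of slice
        # assignment, so the traced graph keeps a fixed shape under XLA.
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
        return k_out, v_out

    def get_seq_length(self, layer_idx=0):
        # Count the positions already written for the first batch item / first head,
        # using index_select instead of data-dependent slicing (see the review snippet below).
        key_cache = self.key_cache[layer_idx]
        device = key_cache.device
        item = key_cache.index_select(0, torch.tensor(0, device=device))
        head = item.index_select(1, torch.tensor(0, device=device))
        return head.any(dim=-1).sum()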
… remove unnecessary code
Tensor Indexing API
LGTM!
The core of the PR LGTM, thank you for working on it! 💛
Open questions before merging:
- Utilization in generate: as we are adding more Cache classes, we are noticing that automagically initializing a Cache inside generate is becoming tricky. From my end, all that we need is to confirm that we can use this class as

xla_cache = StaticCacheXLA(...)
generate_outputs = model.generate(**inputs, past_key_values=xla_cache)

- Tests :D Can we add a test to confirm that the new class doesn't change the model outputs? We can also test that the API described in 1. works in that test (example of a similar test: def test_dynamic_cache_hard(self) in transformers/tests/test_cache_utils.py, line 171 at 43ee585).
Hello @gante, a quick test on GPUs (as my TPUs are running jobs) with the following script revealed it can be used and it produces the same results as the original:

from transformers import LlamaForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCacheXLA, StaticCache
import torch

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
).to(1)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

xla_cache = StaticCacheXLA(model.config, 1, 128, dtype=torch.bfloat16, device=1)
cache = StaticCache(model.config, 1, 128, dtype=torch.bfloat16, device=1)

input_ids = tokenizer("This is a test", return_tensors="pt").to(1)

out = model.generate(
    **input_ids, max_new_tokens=32, do_sample=False, past_key_values=cache
)
out_cache = model.generate(
    **input_ids, max_new_tokens=32, do_sample=False, past_key_values=xla_cache
)

torch.all(out == out_cache)
# True
@huzama awesome! Can we add it as a test? 💛 After the test is added, I'm more than happy to approve the PR!
@gante I even wonder if it would be possible just to patch the original StaticCache class, if they behave the same?
@tengomucho if all that's needed is to replace slicing by index_copy_, then yes, we could patch the original StaticCache class. If that's the case, no new tests would be needed 🤗 @huzama
@gante I agree that combining both classes into a single implementation is a better approach. I'll add comments to explain the code and will push the changes with the merged implementation.
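For illustration, one way a merged single implementation could behave is to try the in-place index_copy_ first and fall back to plain slice assignment on backends that do not implement it. The helper below is a sketch under that assumption; the function name and tensor names are made up, and the cache layout is assumed to be [batch, num_heads, max_cache_len, head_dim].

import torch


def update_static_cache(k_out, v_out, key_states, value_states, cache_position):
    # Hypothetical helper, not the merged transformers code.
    try:
        # In-place copy along the sequence dimension: shape-stable and XLA-friendly.
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
    except NotImplementedError:
        # Fall back to slice assignment on backends where index_copy_ is not implemented.
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
    return k_out, v_out


# Example usage with toy tensors shaped [batch, num_heads, max_cache_len, head_dim]:
k = torch.zeros(1, 2, 8, 4)
v = torch.zeros(1, 2, 8, 4)
new_k = torch.randn(1, 2, 3, 4)
new_v = torch.randn(1, 2, 3, 4)
update_static_cache(k, v, new_k, new_v, torch.arange(3))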
LGTM, thank you for iterating 🤗
@huzama you'll need to run the code quality scripts.
(unrelated CI failure also observed on fresh PRs from main)
LGTM!
LGTM, could you make sure you run the slow tests for Llama compile, gemma compile, etc.?
key_cache = self.key_cache[layer_idx]
device = key_cache.device

# index_select(dim, index) performs the same operation as item = tensor[..., index, ...]
# but it is used for better generality and flexibility.
# For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
item = key_cache.index_select(0, torch.tensor(0, device=device))
head = item.index_select(1, torch.tensor(0, device=device))

return head.any(dim=-1).sum()
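As a note on the snippet above: index_select with a single-element index picks out the same values as plain integer indexing, but it keeps the result's rank fixed, which is what makes it convenient for XLA tracing. A small self-contained illustration, assuming the [batch, num_heads, max_cache_len, head_dim] cache layout:

import torch

# Toy "key cache" with the assumed layout [batch, num_heads, max_cache_len, head_dim].
key_cache = torch.zeros(2, 4, 16, 8)
key_cache[0, 0, :5] = 1.0  # pretend 5 positions have been written so far

# Plain indexing vs. the index_select form used in the diff above:
head_a = key_cache[0, 0]                                                               # shape [16, 8]
head_b = key_cache.index_select(0, torch.tensor([0])).index_select(1, torch.tensor([0]))  # shape [1, 1, 16, 8]

print(torch.equal(head_a, head_b.squeeze(0).squeeze(0)))  # True
print(head_b.any(dim=-1).sum())                           # tensor(5): positions written so far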
TBH this will be deprecated anyways!
Hey @huzama, are you planning to clean this up to get it merged? Otherwise let me know, I will be happy to take over to make StaticCache more efficient on XLA.
This is actually a ripoff of the work originally done as a contribution to transformers: huggingface/transformers#31129 The original contribution has not been merged yet, but it shows lower memory usage and better performance on XLA. So I think it's worth adding it here, to be integrated on optimum-tpu.
@tengomucho Thank you for taking over. I have been running some other experiments, and I lost track of this pull request. I appreciate your help in making StaticCache more efficient on XLA!
I opened another PR so I could push on the branch. Closing this!
Use the index_copy_ method to update the static cache in place and avoid recompilation during each iteration in XLA

What does this PR do?
The PR focuses on avoiding repeated recompilation during each iteration in XLA by performing in-place updates on the static cache. Specifically, it replaces direct tensor-indexing assignments with index_copy_ method calls. This technique ensures that cache updates are executed without triggering recompilation, improving runtime performance.

Code Changes
Original Code:
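A sketch of the slice-assignment pattern described above, with illustrative tensor names and shapes (not the PR's actual diff):

import torch

# Illustrative stand-ins: a static cache of shape [batch, num_heads, max_cache_len, head_dim]
# and new key states for 3 positions.
k_out = torch.zeros(1, 2, 8, 4)
key_states = torch.randn(1, 2, 3, 4)
cache_position = torch.arange(3)

# Direct tensor-indexing assignment along the sequence dimension (the pattern being replaced).
k_out[:, :, cache_position] = key_states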
Updated Code:
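A sketch of the same update expressed with the in-place index_copy_ the PR switches to, using the same illustrative setup as above:

import torch

# Same illustrative setup as in the sketch above.
k_out = torch.zeros(1, 2, 8, 4)
key_states = torch.randn(1, 2, 3, 4)
cache_position = torch.arange(3)

# In-place copy along dim 2 (the sequence dimension): avoids data-dependent advanced
# indexing and keeps the traced graph shape-stable under XLA.
k_out.index_copy_(2, cache_position, key_states)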
@zucchini-nlp @gante
Edit:
Discussion: 31126