Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XLA optimized Implementation of StaticCache with Tensor Indexing API #31129

Closed
wants to merge 14 commits into from
Closed
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, logging


if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4

Expand Down Expand Up @@ -935,3 +934,53 @@ def get_max_length(self) -> Optional[int]:
def reset(self):
self.key_cache.zero_()
self.value_cache.zero_()

class StaticCacheXLA(StaticCache):
""" XLA optimized Implementation of StaticCache """
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
to know how where to write in the cache.

Return:
A tuple containing the updated key and value states.
"""
cache_position = cache_kwargs.get("cache_position")

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)

return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
key_cache = self.key_cache[layer_idx]
device = key_cache.device

item = key_cache.index_select(0, torch.tensor(0, device=device))
head = item.index_select(1, torch.tensor(0, device=device))

return head.any(dim=-1).sum()