diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 04ba337ef436..f5967e5e803b 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -858,8 +858,12 @@ def update(
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
 
-        k_out[:, :, cache_position] = key_states
-        v_out[:, :, cache_position] = value_states
+        # `index_copy_(dim, index, source)` functions similarly to `tensor[index] = source`,
+        # but is used here for its better generality and flexibility.
+        # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
+
+        k_out.index_copy_(2, cache_position, key_states)
+        v_out.index_copy_(2, cache_position, value_states)
 
         return k_out, v_out
 
@@ -868,7 +872,17 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
         # TODO: deprecate this function in favor of `cache_position`
-        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+        key_cache = self.key_cache[layer_idx]
+        device = key_cache.device
+
+        # `index_select(dim, index)` performs the same operation as `item = tensor[..., index, ...]`,
+        # but is used here for its better generality and flexibility.
+        # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
+
+        item = key_cache.index_select(0, torch.tensor(0, device=device))
+        head = item.index_select(1, torch.tensor(0, device=device))
+
+        return head.any(dim=-1).sum()
 
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states."""
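
For reviewers who want to sanity-check the equivalence the patch comments claim, here is a minimal standalone sketch (not part of the patch; the tensor shapes are illustrative dummies, not real `StaticCache` dimensions) showing that `index_copy_` and `index_select` match the advanced-indexing forms they replace:

```python
import torch

# Dummy cache with layout (batch, num_heads, max_seq_len, head_dim) -- shapes are illustrative only.
k_out = torch.zeros(2, 4, 16, 8)
key_states = torch.randn(2, 4, 3, 8)
cache_position = torch.tensor([5, 6, 7])

# Old path: in-place write via advanced indexing on the sequence dim.
expected = k_out.clone()
expected[:, :, cache_position] = key_states

# New path: in-place write via `index_copy_` along dim 2 (the sequence dim).
k_out.index_copy_(2, cache_position, key_states)
assert torch.equal(k_out, expected)

# Old path: `[0, 0]` drops the batch and head dims, giving shape (max_seq_len, head_dim).
old_count = k_out[0, 0].any(dim=-1).sum()

# New path: `index_select` keeps each selected dim with size 1, giving shape
# (1, 1, max_seq_len, head_dim). A 1-D index tensor is used here for clarity.
item = k_out.index_select(0, torch.tensor([0]))
head = item.index_select(1, torch.tensor([0]))
new_count = head.any(dim=-1).sum()
assert torch.equal(old_count, new_count)
```

Note the shape difference in `get_seq_length`: the old `key_cache[layer_idx][0, 0]` yields `(max_seq_len, head_dim)` while the chained `index_select` yields `(1, 1, max_seq_len, head_dim)`, but the trailing `.any(dim=-1).sum()` reduces both to the same scalar count of occupied slots.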