Remove sliding window attention from Mistral's attention layer
JAX complains about dynamic slicing when compiled with XLA. This is unavoidable since, at runtime, the slice of the current key/value array to use for that iteration is determined by `cache_update_index`, which is itself a JAX traced array. Any workaround would lead to using dynamic shapes at some point. Hence, I had to remove this and instead use vanilla caching for now. For some reason, TensorFlow doesn't complain with XLA; I think this might be because TensorFlow isn't as stringent about static shapes as JAX. In any case, adding sliding window attention that is XLA-compatible is a story for the future.
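To illustrate the constraint described above, here is a minimal sketch (not the actual Mistral attention code; the array names are hypothetical) showing what XLA-compiled JAX does and does not allow. A slice with a *static* size but a traced start index compiles fine via `jax.lax.dynamic_slice`; a slice whose *length* depends on a traced value, as a sliding window over the cache would require, fails at trace time because the result would have a dynamic shape.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fixed_size_slice(cache, cache_update_index):
    # OK under jit: lax.dynamic_slice requires static slice *sizes*,
    # but the *start* index may be a traced value.
    return jax.lax.dynamic_slice(cache, (cache_update_index, 0), (2, 4))


@jax.jit
def dynamic_length_slice(cache, cache_update_index):
    # NOT OK under jit: the shape of cache[:cache_update_index] depends
    # on a runtime value, which XLA cannot compile.
    return cache[:cache_update_index]


cache = jnp.arange(32.0).reshape(8, 4)  # toy stand-in for a key/value cache

print(fixed_size_slice(cache, 3).shape)  # static shape: (2, 4)

try:
    dynamic_length_slice(cache, 3)
except Exception as err:
    # JAX raises a tracer error because the slice length is not static.
    print("traced-length slice failed:", type(err).__name__)
```

This is why the sliding window (which needs a window whose extent is computed from `cache_update_index`) cannot be expressed with static shapes, while vanilla caching (fixed-size cache, fixed-size updates) compiles cleanly.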
1 parent 2e2e2e5 · commit 19b0b89
Showing 1 changed file with 17 additions and 61 deletions.