Sliding window fixes #1738

Merged: 2 commits, Aug 6, 2024

@mattdangerw mattdangerw commented Aug 6, 2024

We have an issue, particularly on the tensorflow backend, when computing the sliding window mask during generation.

  • For tf, this would affect any sequence length.
  • For jax and torch, this would only affect generations longer than 4096 tokens.

Before: https://colab.research.google.com/gist/mattdangerw/3d7ab7fd0f2a1169e67d3f4d43d40701/keras-tf-bug.ipynb
After: https://colab.research.google.com/gist/mattdangerw/b48e47107a2513c61d8e70a4652df468/keras-tf-bug-with-fix.ipynb

@mattdangerw mattdangerw requested review from SamanehSaadat and grasskin and removed request for grasskin August 6, 2024 00:20
@mattdangerw mattdangerw force-pushed the sliding-window-fixes branch from a2ccea3 to 1c92bb7 Compare August 6, 2024 00:43

@SamanehSaadat SamanehSaadat left a comment

Thanks for the fix, Matt!

Are these the two main bugs?

  • It's been assumed that key_len==query_len.
  • Caching hasn't been handled.

@grasskin grasskin left a comment

LGTM Thank you for this fix!

@mattdangerw

@SamanehSaadat

> It's been assumed that key_len==query_len.
> Caching hasn't been handled.

Kind of? TensorFlow was taking min(query_len, sliding_window_size) as the effective sliding window size, which basically turned the model into something that could only look 1 token behind, because the general shape for generation is query_len=1, key_len=max_length.

And no backend was taking the generation index (the cache index) into account to make sure the sliding window was correct for the current position.
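
To make the two points above concrete, here is a minimal pure-Python sketch of a sliding window mask that takes the cache index into account. It is an illustration only, not the code from this PR: the function name `sliding_window_mask`, its parameters, and the choice of a causal (look-behind only) window are assumptions, and a real implementation would use backend tensor ops rather than Python lists.

```python
def sliding_window_mask(query_len, key_len, window, cache_index=0):
    """Return a query_len x key_len list of booleans; True means "may attend".

    Hypothetical helper for illustration. Query row i is assumed to sit at
    absolute position cache_index + i, and may attend to key k when k is not
    in the future and is fewer than `window` positions behind it.
    """
    mask = []
    for i in range(query_len):
        pos = cache_index + i  # absolute position of this query token
        mask.append([pos - window < k <= pos for k in range(key_len)])
    return mask


# Decode step: one new token at position 6, cache (key) length 8, window 4.
step = sliding_window_mask(query_len=1, key_len=8, window=4, cache_index=6)
print([int(v) for v in step[0]])  # [0, 0, 0, 1, 1, 1, 1, 0]

# Clamping the window to min(query_len, window) collapses it to one position.
clamped = sliding_window_mask(1, 8, window=min(1, 4), cache_index=6)
print([int(v) for v in clamped[0]])  # [0, 0, 0, 0, 0, 0, 1, 0]

# Ignoring the cache index treats the new token as if it were at position 0.
no_index = sliding_window_mask(1, 8, window=4)
print([int(v) for v in no_index[0]])  # [1, 0, 0, 0, 0, 0, 0, 0]
```

In this sketch the decode-step row matches row 6 of the full 8 x 8 mask, which is the property the cache index is meant to guarantee.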

@mattdangerw mattdangerw merged commit 94283d6 into keras-team:master Aug 6, 2024
10 checks passed
mattdangerw added a commit that referenced this pull request Aug 6, 2024
* Add tests for sliding window issues

* Fix for sliding window issues