The attention scores are always None in CachedMultiHeadAttention #2055

Open
apehex opened this issue Jan 23, 2025 · 0 comments
apehex commented Jan 23, 2025

Describe the bug

The local variable attention_scores introduced at line 111 of CachedMultiHeadAttention is always None.

To Reproduce

Since attention_scores is an internal variable, I copied the CachedMultiHeadAttention (CMHA) subclass into this Colab script:
https://colab.research.google.com/drive/1ZUS4mjDQktovKiJ8TQ7zYtm4PGjesXvG?usp=sharing
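A minimal sketch of the same reproduction, for reference. The import path is assumed, and the _query_dense / _key_dense / _value_dense / _compute_attention calls simply replay what the layer's call() already does internally:

```python
import numpy as np
import keras
# Import path assumed; adjust to wherever CachedMultiHeadAttention lives
# in your keras-hub install.
from keras_hub.src.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

layer = CachedMultiHeadAttention(num_heads=2, key_dim=8)
x = np.random.normal(size=(1, 4, 16)).astype("float32")
layer(query=x, value=x)  # build the layer and its projection sub-layers

# Replay the projections the same way call() does, then inspect the scores:
q = layer._query_dense(x)
k = layer._key_dense(x)
v = layer._value_dense(x)
_, scores = layer._compute_attention(q, k, v)
print(scores)  # None with recent Keras versions
```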

Expected behavior

The variable attention_scores should contain the attention weights between query and key (the softmaxed dot products), which are useful for debugging a model IMHO.

Additional context

In recent Keras versions, the parent class MultiHeadAttention saves the argument return_attention_scores in self._return_attention_scores.

Then, the method _compute_attention checks this private attribute to decide whether to return the scores.
Since CachedMultiHeadAttention.call never sets this attribute, the attention scores are never returned.
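
For contrast, here is a short sketch showing that the flag works through the public call() argument on the plain Keras layer, which is exactly the path CMHA no longer exposes (shapes and layer sizes are arbitrary):

```python
import numpy as np
import keras

mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
x = np.random.normal(size=(1, 4, 16)).astype("float32")

# Public path: the flag is forwarded by call(), so real scores come back.
_, scores = mha(x, x, return_attention_scores=True)
print(scores.shape)  # (1, 2, 4, 4) -> (batch, heads, query_len, key_len)

# CachedMultiHeadAttention.call has no such argument and never sets
# self._return_attention_scores, so its internal _compute_attention call
# always returns (output, None).
```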

I'll also open an issue on Keras proposing to turn the private attribute _return_attention_scores into an explicit argument.

Would you like to help us fix it?

Yes, I have two potential fixes:

  1. drop the attention scores entirely, which would be consistent since the corresponding return_attention_scores argument has already been removed from CMHA's call signature
  2. re-introduce the argument and set the private attribute _return_attention_scores accordingly (see the sketch after this list)

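A rough sketch of option 2. The keyword list mirrors CMHA's current call() signature as I understand it; only the return_attention_scores handling is new:

```python
import keras


class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
    def call(
        self,
        query,
        value,
        key=None,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        return_attention_scores=False,  # re-introduced argument
        training=None,
    ):
        # Mirror what MultiHeadAttention.call does in recent Keras, so that
        # the subsequent self._compute_attention(...) call actually keeps
        # the scores instead of dropping them.
        self._return_attention_scores = return_attention_scores
        # ... rest of the existing call() body unchanged, including
        # attention_output, attention_scores = self._compute_attention(...)
        ...
```
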
WDYT?
