Feat/make transformer decoder callable without causal mask #1083
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
I did not add a test specifically for this, since the current testing strategy does not test any of the other "_mask" arguments to the call() function. Because this argument sits at the same abstraction level, I decided to follow that convention.
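For reference, a test for the new flag could look roughly like the sketch below. This is a hypothetical example, not code from the PR; it only checks that disabling the causal mask changes the layer's output on random inputs.

```python
import tensorflow as tf
from keras_nlp.layers import TransformerDecoder


def test_use_causal_mask_flag():
    # Hypothetical sketch, not part of this PR: with the causal mask disabled,
    # earlier positions can attend to future tokens, so outputs should differ.
    decoder = TransformerDecoder(intermediate_dim=8, num_heads=2)
    x = tf.random.uniform(shape=(1, 4, 8))
    causal_out = decoder(x, use_causal_mask=True)
    full_out = decoder(x, use_causal_mask=False)
    assert not tf.reduce_all(tf.equal(causal_out, full_out))
```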
use_causal_mask: bool, defaults to True. If true, a causal mask
    (masking out future input) is applied on the decoder sequence.
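For context, a minimal usage sketch of the documented flag, assuming it lands on `TransformerDecoder.call()` as described above:

```python
import tensorflow as tf
import keras_nlp

decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=64, num_heads=4)
x = tf.random.uniform(shape=(2, 10, 32))  # (batch, sequence, feature)

causal_out = decoder(x)                       # default: causal mask applied
full_out = decoder(x, use_causal_mask=False)  # positions may attend to future tokens
```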
Naming and documentation are taken from the previous version, where this lived at: https://github.com/keras-team/keras-nlp/blob/cb0fa028971475879911ddf042a1473037775ee6/keras_nlp/layers/transformer_decoder.py#L191-L192
this lgtm, and matches the multi-head attention layer in core Keras
Looks great, thanks for the contribution!
Looks great! Thank you. A few minor comments.
if decoder_mask is not None:
    self_attention_mask = tf.minimum(decoder_mask, self_attention_mask)
if use_causal_mask:
    batch_size = tf.shape(decoder_sequence)[0]
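For context on the snippet above: `tf.minimum` keeps an attention link only when both the padding mask and the causal mask allow it. A standalone sketch with toy shapes (the variable names here are illustrative, not the layer's internals):

```python
import tensorflow as tf

seq_len = 4
# Causal mask: lower-triangular, position i may attend to positions <= i.
causal_mask = tf.linalg.band_part(tf.ones((1, seq_len, seq_len)), -1, 0)
# Padding mask for one sequence whose last token is padding.
padding_mask = tf.constant([[1.0, 1.0, 1.0, 0.0]])
decoder_mask = padding_mask[:, tf.newaxis, :] * padding_mask[:, :, tf.newaxis]
# Element-wise minimum: a position stays visible only if both masks say so.
combined_mask = tf.minimum(decoder_mask, causal_mask)
print(combined_mask)
```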
I think this branching is getting complex enough that we should split it out into a private method.
self_attention_mask = self._compute_self_attention_mask(
    decoder_sequence,
    self_attention_cache,
    self_attention_cache_update_index,
    use_causal_mask,
)
Good point, done.
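For readers following along, a simplified sketch of what such a private helper could look like is below. It is hypothetical, ignores the cache arguments the real call above passes in, and shows only the causal/padding mask merging.

```python
import tensorflow as tf


def _compute_self_attention_mask(
    self, decoder_sequence, decoder_mask, use_causal_mask
):
    # Hypothetical sketch of the helper's core logic, not the merged code.
    self_attention_mask = decoder_mask
    if use_causal_mask:
        batch_size = tf.shape(decoder_sequence)[0]
        seq_len = tf.shape(decoder_sequence)[1]
        # Lower-triangular causal mask: position i attends to positions <= i.
        causal_mask = tf.linalg.band_part(
            tf.ones((batch_size, seq_len, seq_len)), -1, 0
        )
        if self_attention_mask is None:
            self_attention_mask = causal_mask
        else:
            self_attention_mask = tf.minimum(self_attention_mask, causal_mask)
    return self_attention_mask
```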
/gcbrun
Fixes #1062