
Feat/make transformer decoder callable without causal mask #1083

Conversation

@ferraric (Contributor) commented Jun 17, 2023

Fixes #1062

google-cla bot commented Jun 17, 2023

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

@ferraric (Contributor, Author) left a comment

I did not add a test specifically for this, since the current testing strategy does not test any of the other "_mask" arguments to call(). Because this argument sits at the same abstraction level, I followed that convention.

Comment on lines 255 to 256
use_causal_mask: bool, defaults to True. If true, causal mask
    (masking out future input) is applied on the decoder sequence.

Member

this lgtm, and matches the multi-head attention layer in core Keras
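
For readers following along, a minimal usage sketch of the new call argument; the hyperparameters and shapes are illustrative assumptions, and it presumes a keras-nlp build that includes this change:

import keras_nlp
import tensorflow as tf

# Decoder-only call with the causal mask disabled: the layer attends
# bidirectionally over the decoder sequence instead of masking out
# future positions.
decoder = keras_nlp.layers.TransformerDecoder(intermediate_dim=64, num_heads=4)
decoder_sequence = tf.random.uniform(shape=(2, 10, 32))  # (batch, seq_len, hidden_dim)
outputs = decoder(decoder_sequence, use_causal_mask=False)
print(outputs.shape)  # (2, 10, 32)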

keras_nlp/layers/transformer_decoder.py (review thread resolved)
@ferraric marked this pull request as ready for review June 17, 2023 10:51
@jbischof requested a review from mattdangerw June 17, 2023 14:31
@jbischof (Contributor) left a comment

Looks great, thanks for the contribution!

@mattdangerw (Member) left a comment

Looks great! Thank you. A few minor comments.

keras_nlp/layers/transformer_decoder.py (review thread resolved)
keras_nlp/layers/transformer_decoder.py (outdated review thread, resolved)

if decoder_mask is not None:
    self_attention_mask = tf.minimum(decoder_mask, self_attention_mask)
if use_causal_mask:
    batch_size = tf.shape(decoder_sequence)[0]

Member

I think this branching is getting complex enough that we should split out a private method for this:

self_attention_mask = self._compute_self_attention_mask(
    decoder_sequence,
    self_attention_cache,
    self_attention_cache_update_index,
    use_causal_mask,
)

@ferraric (Contributor, Author)

good point, done
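
For context, a rough standalone sketch of what such a helper could look like, pieced together from the snippets quoted above; it is not the merged keras-nlp implementation, and the self_attention_cache arguments from the suggestion are omitted for brevity:

import tensorflow as tf

def compute_self_attention_mask(decoder_sequence, decoder_mask, use_causal_mask):
    """Merge an optional padding/attention mask with an optional causal mask.

    Sketch only: names follow the quoted snippet and cache handling is omitted.
    """
    self_attention_mask = decoder_mask
    if use_causal_mask:
        batch_size = tf.shape(decoder_sequence)[0]
        seq_len = tf.shape(decoder_sequence)[1]
        # Lower-triangular matrix: position i may only attend to positions <= i.
        causal_mask = tf.linalg.band_part(
            tf.ones(tf.stack([seq_len, seq_len]), dtype="int32"), -1, 0
        )
        causal_mask = tf.tile(causal_mask[tf.newaxis, ...], [batch_size, 1, 1])
        if self_attention_mask is not None:
            # A position is attendable only if both masks allow it.
            self_attention_mask = tf.minimum(
                tf.cast(self_attention_mask, "int32"), causal_mask
            )
        else:
            self_attention_mask = causal_mask
    return self_attention_mask

With decoder_mask=None and use_causal_mask=False this returns None, which Keras MultiHeadAttention treats as "no mask".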

keras_nlp/layers/transformer_decoder.py (review thread resolved)
@mattdangerw (Member) commented

/gcbrun

@mattdangerw merged commit 71a6a2a into keras-team:master on Jun 21, 2023

Successfully merging this pull request may close these issues.

Make causal mask in TransformerDecoder optional
3 participants