Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/make transformer decoder callable without causal mask #1083

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions keras_nlp/layers/transformer_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TransformerDecoder(keras.layers.Layer):
paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
can instantiate multiple instances of this class to stack up a decoder.

This layer will always apply a causal mask to the decoder attention layer.
jbischof marked this conversation as resolved.
Show resolved Hide resolved
By default, this layer will apply a causal mask to the decoder attention layer.
mattdangerw marked this conversation as resolved.
Show resolved Hide resolved
This layer will correctly compute an attention mask from an implicit
Keras padding mask (for example, by passing `mask_zero=True` to a
`keras.layers.Embedding` layer). See the Masking and Padding
Expand Down Expand Up @@ -218,6 +218,7 @@ def call(
self_attention_cache_update_index=None,
cross_attention_cache=None,
cross_attention_cache_update_index=None,
use_causal_mask=True,
mattdangerw marked this conversation as resolved.
Show resolved Hide resolved
):
"""Forward pass of the TransformerDecoder.

Expand Down Expand Up @@ -252,6 +253,8 @@ def call(
at which to update the `cross_attention_cache`. Usually, this is
either `0` (compute the entire `cross_attention_cache`), or
`None` (reuse a previously computed `cross_attention_cache`).
use_causal_mask: bool, defaults to `True`. If true, a causal mask
(masking out future input) is applied `on the decoder sequence.
Returns:
One of three things, depending on call arguments:
- `outputs`, if `self_attention_cache` is `None.
Expand Down Expand Up @@ -305,27 +308,14 @@ def call(

x = decoder_sequence # Intermediate result.

# Compute self attention mask.
batch_size = tf.shape(decoder_sequence)[0]
input_length = output_length = tf.shape(decoder_sequence)[1]
# We need to handle a rectangular causal mask when doing cached
# decoding. For generative inference, `decoder_sequence` will
# generally be length 1, and `cache` will be the full generation length.
if self_attention_cache is not None:
input_length = tf.shape(self_attention_cache)[2]
self_attention_mask = compute_causal_mask(
batch_size,
input_length,
output_length,
0
if self_attention_cache_update_index is None
else self_attention_cache_update_index,
)
decoder_mask = merge_padding_and_attention_mask(
decoder_sequence, decoder_padding_mask, decoder_attention_mask
self_attention_mask = self._compute_self_attention_mask(
decoder_sequence=decoder_sequence,
decoder_padding_mask=decoder_padding_mask,
decoder_attention_mask=decoder_attention_mask,
use_causal_mask=use_causal_mask,
self_attention_cache=self_attention_cache,
self_attention_cache_update_index=self_attention_cache_update_index,
)
if decoder_mask is not None:
self_attention_mask = tf.minimum(decoder_mask, self_attention_mask)

# Self attention block.
residual = x
Expand Down Expand Up @@ -385,6 +375,42 @@ def call(
else:
return x

def _compute_self_attention_mask(
self,
decoder_sequence,
decoder_padding_mask,
decoder_attention_mask,
use_causal_mask,
self_attention_cache,
self_attention_cache_update_index,
):
decoder_mask = merge_padding_and_attention_mask(
decoder_sequence, decoder_padding_mask, decoder_attention_mask
)
if use_causal_mask:
batch_size = tf.shape(decoder_sequence)[0]
input_length = output_length = tf.shape(decoder_sequence)[1]
# We need to handle a rectangular causal mask when doing cached
# decoding. For generative inference, `decoder_sequence` will
# generally be length 1, and `cache` will be the full generation length.
if self_attention_cache is not None:
input_length = tf.shape(self_attention_cache)[2]

causal_mask = compute_causal_mask(
batch_size,
input_length,
output_length,
0
if self_attention_cache_update_index is None
else self_attention_cache_update_index,
)
return (
tf.minimum(decoder_mask, causal_mask)
if decoder_mask is not None
else causal_mask
)
return decoder_mask

def get_config(self):
config = super().get_config()
config.update(
Expand Down