diff --git a/README.md b/README.md index 7d01c2cb9d..4d5073fc6d 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ KerasNLP provides access to pre-trained models via the `keras_nlp.models` API. These pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind. The following underlying models are provided by third parties, and subject to separate licenses: -BART, DeBERTa, DistilBERT, GPT-2, OPT, RoBERTa, and XLM-RoBERTa. +BART, DeBERTa, DistilBERT, GPT-2, OPT, RoBERTa, Whisper, and XLM-RoBERTa. ## Citing KerasNLP diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index ceac30391a..dda0965b44 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -83,6 +83,7 @@ ) from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_classifier import ( XLMRobertaClassifier, diff --git a/keras_nlp/models/whisper/__init__.py b/keras_nlp/models/whisper/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/models/whisper/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py
new file mode 100644
index 0000000000..83b9324010
--- /dev/null
+++ b/keras_nlp/models/whisper/whisper_backbone.py
@@ -0,0 +1,296 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Whisper backbone model."""
+
+
+import tensorflow as tf
+from tensorflow import keras
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.position_embedding import PositionEmbedding
+from keras_nlp.layers.token_and_position_embedding import (
+    TokenAndPositionEmbedding,
+)
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.whisper.whisper_decoder import WhisperDecoder
+from keras_nlp.models.whisper.whisper_encoder import WhisperEncoder
+
+# We hardcode the number of mel-frequency filters:
+# https://github.com/openai/whisper/blob/v20230124/whisper/audio.py#L101-L102.
+# TODO: If needed, we can make it configurable.
+NUM_MELS = 80
+
+
+def whisper_kernel_initializer(stddev=0.02):
+    """Return a `TruncatedNormal` initializer with the given `stddev`."""
+    return keras.initializers.TruncatedNormal(stddev=stddev)
+
+
+@keras_nlp_export("keras_nlp.models.WhisperBackbone")
+class WhisperBackbone(Backbone):
+    """Whisper encoder-decoder network for speech.
+
+    This class implements a Transformer-based encoder-decoder model as
+    described in
+    ["Robust Speech Recognition via Large-Scale Weak Supervision"](https://arxiv.org/abs/2212.04356).
+    It includes the embedding lookups and transformer layers, but not the head
+    for predicting the next token.
+
+    The default constructor gives a fully customizable, randomly initialized Whisper
+    model with any number of layers, heads, and embedding dimensions. To load
+    preset architectures and weights, use the `from_preset` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind. The underlying model is provided by a
+    third party and subject to a separate license, available
+    [here](https://github.com/openai/whisper).
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer encoder layers and
+            transformer decoder layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The size of the embeddings and transformer layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        dropout: float. Dropout probability for the encoder and the decoder.
+        max_encoder_sequence_length: int. The maximum sequence length that the
+            audio encoder can consume. Since the second convolutional layer in
+            the encoder reduces the sequence length by half (stride of 2), we
+            use `max_encoder_sequence_length // 2` as the sequence length for the
+            positional embedding layer.
+        max_decoder_sequence_length: int. The maximum sequence length that the
+            text decoder can consume.
+
+    Examples:
+    ```python
+    input_data = {
+        "encoder_features": tf.ones(shape=(1, 12, 80), dtype=tf.float32),
+        "decoder_token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
+        "decoder_padding_mask": tf.constant(
+            [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], shape=(1, 12)
+        ),
+    }
+
+    # Randomly initialized Whisper encoder-decoder model with a custom config
+    model = keras_nlp.models.WhisperBackbone(
+        vocabulary_size=51864,
+        num_layers=6,
+        num_heads=8,
+        hidden_dim=512,
+        intermediate_dim=2048,
+        max_encoder_sequence_length=128,
+        max_decoder_sequence_length=64,
+    )
+    output = model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        dropout=0.0,
+        max_encoder_sequence_length=3000,
+        max_decoder_sequence_length=448,
+        **kwargs,
+    ):
+        # Encoder inputs. Note that the encoder does not have a padding mask:
+        # https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L132.
+        encoder_feature_input = keras.Input(
+            shape=(None, NUM_MELS), dtype="float32", name="encoder_features"
+        )
+
+        # Decoder inputs.
+        decoder_token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_token_ids"
+        )
+        decoder_padding_mask = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_padding_mask"
+        )
+
+        # ====== Encoder ======
+
+        # Embed the input features. This consists of two 1D convolutional
+        # layers.
+        # For the first layer, we use `padding="same"` since that corresponds to
+        # a padding size of 1.
+        encoder_conv_layer_1 = keras.layers.Conv1D(
+            filters=hidden_dim,
+            kernel_size=3,
+            strides=1,
+            padding="same",
+            name="encoder_token_embedding_conv_layer_1",
+        )
+        embedded_features = keras.activations.gelu(
+            encoder_conv_layer_1(encoder_feature_input),
+            approximate=False,
+        )
+
+        # For the second conv. layer, we cannot use `padding="same"` since
+        # that corresponds to a padding size of 1.5 (since stride is 2). Hence,
+        # we will manually pad the input.
+        embedded_features = tf.pad(
+            embedded_features, paddings=[[0, 0], [1, 1], [0, 0]]
+        )
+        encoder_conv_layer_2 = keras.layers.Conv1D(
+            filters=hidden_dim,
+            kernel_size=3,
+            strides=2,
+            padding="valid",
+            name="encoder_token_embedding_conv_layer_2",
+        )
+        embedded_features = keras.activations.gelu(
+            encoder_conv_layer_2(embedded_features),
+            approximate=False,
+        )
+
+        # The position embedding layer for the encoder is a sinusoidal embedding
+        # layer: https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L137.
+        # Hence, we set it to be non-trainable.
+        # TODO: We can use `keras_nlp.layers.SinePositionEncoding` layer.
+        position_embedding = PositionEmbedding(
+            initializer=whisper_kernel_initializer(),
+            sequence_length=max_encoder_sequence_length // 2,
+            name="encoder_position_embedding",
+            trainable=False,
+        )(embedded_features)
+
+        # Sum and apply dropout to embeddings.
+        x = keras.layers.Add()((embedded_features, position_embedding))
+        x = keras.layers.Dropout(
+            dropout,
+            name="encoder_embeddings_dropout",
+        )(x)
+
+        # Apply successive transformer encoder blocks.
+        for i in range(num_layers):
+            x = WhisperEncoder(
+                num_heads=num_heads,
+                intermediate_dim=intermediate_dim,
+                activation=lambda x: keras.activations.gelu(
+                    x, approximate=False
+                ),
+                layer_norm_epsilon=1e-5,
+                dropout=dropout,
+                kernel_initializer=whisper_kernel_initializer(),
+                normalize_first=True,
+                name=f"transformer_encoder_layer_{i}",
+            )(x)
+
+        x = keras.layers.LayerNormalization(
+            name="encoder_layer_norm",
+            axis=-1,
+            epsilon=1e-5,
+            dtype=tf.float32,
+        )(x)
+        encoder_output = x
+
+        # ====== Decoder ======
+
+        # Embed tokens and positions.
+        x = TokenAndPositionEmbedding(
+            vocabulary_size=vocabulary_size,
+            sequence_length=max_decoder_sequence_length,
+            embedding_dim=hidden_dim,
+            embeddings_initializer=whisper_kernel_initializer(),
+            name="decoder_token_and_position_embedding",
+        )(decoder_token_id_input)
+
+        # Apply dropout to embeddings.
+        x = keras.layers.Dropout(
+            dropout,
+            name="decoder_embeddings_dropout",
+        )(x)
+
+        # Apply successive transformer decoder blocks.
+        for i in range(num_layers):
+            transformer_decoder_layer = WhisperDecoder(
+                intermediate_dim=intermediate_dim,
+                num_heads=num_heads,
+                dropout=dropout,
+                activation=lambda x: keras.activations.gelu(
+                    x, approximate=False
+                ),
+                layer_norm_epsilon=1e-5,
+                kernel_initializer=whisper_kernel_initializer(),
+                normalize_first=True,
+                name=f"transformer_decoder_layer_{i}",
+                has_cross_attention=True,
+            )
+            x = transformer_decoder_layer(
+                decoder_sequence=x,
+                encoder_sequence=encoder_output,
+                decoder_padding_mask=decoder_padding_mask,
+            )
+
+        x = keras.layers.LayerNormalization(
+            name="decoder_layer_norm",
+            axis=-1,
+            epsilon=1e-5,
+            dtype=tf.float32,
+        )(x)
+        decoder_output = x
+
+        # Instantiate using Functional API Model constructor
+        super().__init__(
+            inputs={
+                "encoder_features": encoder_feature_input,
+                "decoder_token_ids": decoder_token_id_input,
+                "decoder_padding_mask": decoder_padding_mask,
+            },
+            outputs={
+                "encoder_sequence_output": encoder_output,
+                "decoder_sequence_output": decoder_output,
+            },
+            **kwargs,
+        )
+
+        # All references to `self` below this line
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.max_encoder_sequence_length = max_encoder_sequence_length
+        self.max_decoder_sequence_length = max_decoder_sequence_length
+
+    def get_config(self):
+        """Return the base config plus this model's constructor arguments."""
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_encoder_sequence_length": self.max_encoder_sequence_length,
+                "max_decoder_sequence_length": self.max_decoder_sequence_length,
+            }
+        )
+        return config
+
+    @property
+    def token_embedding(self):
+        """The token embedding layer of the decoder."""
+        return self.get_layer(
+            "decoder_token_and_position_embedding"
+        ).token_embedding
diff --git a/keras_nlp/models/whisper/whisper_backbone_test.py b/keras_nlp/models/whisper/whisper_backbone_test.py
new file mode 100644
index 0000000000..df4f25742d
--- /dev/null
+++ b/keras_nlp/models/whisper/whisper_backbone_test.py
@@ -0,0 +1,179 @@
+# Copyright 2022 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test for Whisper backbone models."""
+
+import os
+
+import pytest
+import tensorflow as tf
+from absl.testing import parameterized
+from tensorflow import keras
+
+from keras_nlp.models.whisper.whisper_backbone import NUM_MELS
+from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone
+
+
+class WhisperBackboneTest(tf.test.TestCase, parameterized.TestCase):
+    def setUp(self):
+        self.model = WhisperBackbone(
+            vocabulary_size=1000,
+            num_layers=2,
+            num_heads=2,
+            hidden_dim=64,
+            intermediate_dim=128,
+            max_encoder_sequence_length=128,
+            max_decoder_sequence_length=96,
+        )
+        self.batch_size = 8
+        self.input_batch = {
+            "encoder_features": tf.ones(
+                (
+                    self.batch_size,
+                    self.model.max_encoder_sequence_length,
+                    NUM_MELS,
+                ),
+                dtype="float32",
+            ),
+            "decoder_token_ids": tf.ones(
+                (self.batch_size, self.model.max_decoder_sequence_length),
+                dtype="int32",
+            ),
+            "decoder_padding_mask": tf.ones(
+                (self.batch_size, self.model.max_decoder_sequence_length),
+                dtype="int32",
+            ),
+        }
+
+        self.input_dataset = tf.data.Dataset.from_tensor_slices(
+            self.input_batch
+        ).batch(2)
+
+    def test_valid_call_whisper(self):
+        self.model(self.input_batch)
+
+        # Check default name passed through
+        self.assertRegex(self.model.name, "whisper_backbone")
+
+    def test_variable_sequence_length_call_whisper(self):
+        for seq_length in (25, 50, 75):
+            input_data = {
+                "encoder_features": tf.ones(
+                    (self.batch_size, seq_length, NUM_MELS),
+                    dtype="float32",
+                ),
+                "decoder_token_ids": tf.ones(
+                    (self.batch_size, seq_length), dtype="int32"
+                ),
+                "decoder_padding_mask": tf.ones(
+                    (self.batch_size, seq_length), dtype="int32"
+                ),
+            }
+            self.model(input_data)
+
+    @parameterized.named_parameters(
+        ("jit_compile_false", False), ("jit_compile_true", True)
+    )
+    def test_compile(self, jit_compile):
+        self.model.compile(jit_compile=jit_compile)
+        self.model.predict(self.input_batch)
+
+    @parameterized.named_parameters(
+        ("jit_compile_false", False), ("jit_compile_true", True)
+    )
+    def test_compile_batched_ds(self, jit_compile):
+        self.model.compile(jit_compile=jit_compile)
+        self.model.predict(self.input_dataset)
+
+    def test_key_projection_bias_absence(self):
+        # Check only for the first encoder layer and first decoder layer.
+        self.assertIsNone(
+            self.model.get_layer(
+                "transformer_encoder_layer_0"
+            )._self_attention_layer._key_dense.bias
+        )
+        self.assertIsNone(
+            self.model.get_layer(
+                "transformer_decoder_layer_0"
+            )._self_attention_layer._key_dense.bias
+        )
+        self.assertIsNone(
+            self.model.get_layer(
+                "transformer_decoder_layer_0"
+            )._cross_attention_layer._key_dense.bias
+        )
+
+    @parameterized.named_parameters(
+        ("tf_format", "tf", "model"),
+        ("keras_format", "keras_v3", "model.keras"),
+    )
+    def test_saved_model(self, save_format, filename):
+        model_output = self.model(self.input_batch)
+        save_path = os.path.join(self.get_temp_dir(), filename)
+        self.model.save(save_path, save_format=save_format)
+        restored_model = keras.models.load_model(save_path)
+
+        # Check we got the real object back.
+        self.assertIsInstance(restored_model, WhisperBackbone)
+
+        # Check that output matches.
+        restored_output = restored_model(self.input_batch)
+        self.assertAllClose(
+            model_output["encoder_sequence_output"],
+            restored_output["encoder_sequence_output"],
+        )
+        self.assertAllClose(
+            model_output["decoder_sequence_output"],
+            restored_output["decoder_sequence_output"],
+        )
+
+
+@pytest.mark.tpu
+@pytest.mark.usefixtures("tpu_test_class")
+class WhisperBackboneTPUTest(tf.test.TestCase, parameterized.TestCase):
+    def setUp(self):
+        with self.tpu_strategy.scope():
+            self.model = WhisperBackbone(
+                vocabulary_size=1000,
+                num_layers=2,
+                num_heads=2,
+                hidden_dim=64,
+                intermediate_dim=128,
+                max_encoder_sequence_length=128,
+                max_decoder_sequence_length=64,
+            )
+
+        self.input_batch = {
+            "encoder_features": tf.ones(
+                (
+                    8,
+                    self.model.max_encoder_sequence_length,
+                    NUM_MELS,
+                ),
+                dtype="float32",
+            ),
+            "decoder_token_ids": tf.ones(
+                (8, self.model.max_decoder_sequence_length), dtype="int32"
+            ),
+            "decoder_padding_mask": tf.ones(
+                (8, self.model.max_decoder_sequence_length), dtype="int32"
+            ),
+        }
+
+        self.input_dataset = tf.data.Dataset.from_tensor_slices(
+            self.input_batch
+        ).batch(2)
+
+    def test_predict(self):
+        self.model.compile()
+        self.model.predict(self.input_dataset)
diff --git a/keras_nlp/models/whisper/whisper_decoder.py b/keras_nlp/models/whisper/whisper_decoder.py
new file mode 100644
index 0000000000..d7d7d2867e
--- /dev/null
+++ b/keras_nlp/models/whisper/whisper_decoder.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Whisper decoder block."""
+
+from tensorflow import keras
+
+from keras_nlp.layers.transformer_decoder import TransformerDecoder
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class WhisperDecoder(TransformerDecoder):
+    """Whisper decoder.
+
+    Inherits from `keras_nlp.layers.TransformerDecoder`, and overrides the
+    `_build` method so as to remove the bias term from the key projection layer.
+    """
+
+    def _build(self, input_shape, has_cross_attention):
+        super()._build(input_shape, has_cross_attention)
+
+        # Since there is no exposed option for this in MHA, we will reach into
+        # the internals of the layer for now.
+        self._self_attention_layer._key_dense.bias_axes = None
+        if has_cross_attention:
+            self._cross_attention_layer._key_dense.bias_axes = None
diff --git a/keras_nlp/models/whisper/whisper_encoder.py b/keras_nlp/models/whisper/whisper_encoder.py
new file mode 100644
index 0000000000..ec1040891c
--- /dev/null
+++ b/keras_nlp/models/whisper/whisper_encoder.py
@@ -0,0 +1,34 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Whisper encoder block."""
+
+from tensorflow import keras
+
+from keras_nlp.layers.transformer_encoder import TransformerEncoder
+
+
+@keras.utils.register_keras_serializable(package="keras_nlp")
+class WhisperEncoder(TransformerEncoder):
+    """Whisper encoder.
+
+    Inherits from `keras_nlp.layers.TransformerEncoder`, and overrides the
+    `_build` method so as to remove the bias term from the key projection layer.
+    """
+
+    def _build(self, input_shape):
+        super()._build(input_shape)
+
+        # Since there is no exposed option for this in MHA, we will reach into
+        # the internals of the layer for now.
+        self._self_attention_layer._key_dense.bias_axes = None