From b48e807de7dd3309801af272dd91014a34994d15 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 28 Mar 2023 10:24:32 +0800 Subject: [PATCH] Fix failing TPU tests --- keras_nlp/models/t5/t5_backbone_test.py | 2 +- keras_nlp/models/whisper/whisper_backbone_test.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py index 53c283959e..813052b11e 100644 --- a/keras_nlp/models/t5/t5_backbone_test.py +++ b/keras_nlp/models/t5/t5_backbone_test.py @@ -136,7 +136,7 @@ def setUp(self): ).batch(2) def test_predict(self): - self.model.compile() + self.backbone.compile() outputs = self.backbone.predict(self.input_dataset) self.assertIn("encoder_sequence_output", outputs) self.assertIn("decoder_sequence_output", outputs) diff --git a/keras_nlp/models/whisper/whisper_backbone_test.py b/keras_nlp/models/whisper/whisper_backbone_test.py index b71bbfdaf0..456cb7b564 100644 --- a/keras_nlp/models/whisper/whisper_backbone_test.py +++ b/keras_nlp/models/whisper/whisper_backbone_test.py @@ -142,16 +142,16 @@ def setUp(self): "encoder_features": tf.ones( ( 8, - self.model.max_encoder_sequence_length, + self.backbone.max_encoder_sequence_length, NUM_MELS, ), dtype="int32", ), "decoder_token_ids": tf.ones( - (8, self.model.max_decoder_sequence_length), dtype="int32" + (8, self.backbone.max_decoder_sequence_length), dtype="int32" ), "decoder_padding_mask": tf.ones( - (8, self.model.max_decoder_sequence_length), dtype="int32" + (8, self.backbone.max_decoder_sequence_length), dtype="int32" ), }