diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 1e342e791c..09053ff893 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -230,6 +230,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 9c8cdaa60e..867616da69 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -28,6 +28,15 @@ def __init__(self, *args, dtype=None, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True + if dtype is not None: + # Keras 2 and Keras 3 handle setting policy differently. + if config.keras_3(): + if isinstance(dtype, keras.DTypePolicy): + self.dtype_policy = dtype + else: + self.dtype_policy = keras.DTypePolicy(dtype) + else: + self._set_dtype_policy(dtype) def __dir__(self): if config.keras_3(): @@ -67,7 +76,7 @@ def token_embedding(self): This layer embeds integer token ids to the hidden dim of the model. """ - return self._token_embedding + return getattr(self, "_token_embedding", None) @token_embedding.setter def token_embedding(self, value): diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py index 803d5a2a9f..f100133d25 100644 --- a/keras_nlp/models/bart/bart_backbone.py +++ b/keras_nlp/models/bart/bart_backbone.py @@ -232,6 +232,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 2248260da7..320dc1c2ee 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -196,6 +196,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index 5737dcc889..5c6f81ca5b 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -149,6 +149,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py index e7bd8ca20a..9063b11df5 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py @@ -178,6 +178,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py index 1ae0840ea8..73634b4216 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py @@ -159,6 +159,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index 13be2d8eb8..f4f2a23b69 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -202,6 +202,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py index 309f312a17..ab056c84c7 100644 --- a/keras_nlp/models/f_net/f_net_backbone.py +++ b/keras_nlp/models/f_net/f_net_backbone.py @@ -206,6 +206,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/falcon/falcon_backbone.py b/keras_nlp/models/falcon/falcon_backbone.py index 4951189fe0..5a3a0fccda 100644 --- a/keras_nlp/models/falcon/falcon_backbone.py +++ b/keras_nlp/models/falcon/falcon_backbone.py @@ -130,6 +130,7 @@ def __init__( "padding_mask": padding_mask, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py index e5814940aa..c829aa948f 100644 --- a/keras_nlp/models/gemma/gemma_backbone.py +++ b/keras_nlp/models/gemma/gemma_backbone.py @@ -157,6 +157,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py index d93b2199b0..b7d2b10acf 100644 --- a/keras_nlp/models/gpt2/gpt2_backbone.py +++ b/keras_nlp/models/gpt2/gpt2_backbone.py @@ -170,6 +170,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py index 1955ed5801..415fa56af2 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py @@ -137,6 +137,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index cc628ad7a5..733d9ef434 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -127,6 +127,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py index 3e2cfae148..52de945760 100644 --- a/keras_nlp/models/mistral/mistral_backbone.py +++ b/keras_nlp/models/mistral/mistral_backbone.py @@ -166,6 +166,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index 0b98a6c64e..16fe4a0218 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -146,6 +146,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index 1ab61eeeb7..09fe753762 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -156,6 +156,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index cf747c503c..862c4766f4 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -224,6 +224,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 783cc0b41b..0656d2194e 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -36,6 +36,12 @@ def __init__(self, *args, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True + if self.backbone is not None: + # Keras 2 and Keras 3 handle setting policy differently. + if config.keras_3(): + self.dtype_policy = self._backbone.dtype_policy + else: + self._set_dtype_policy(self._backbone.dtype_policy) def __dir__(self): if config.keras_3(): @@ -128,7 +134,7 @@ def __setattr__(self, name, value): @property def backbone(self): """A `keras.Model` instance providing the backbone sub-model.""" - return self._backbone + return getattr(self, "_backbone", None) @backbone.setter def backbone(self, value): @@ -137,7 +143,7 @@ def backbone(self, value): @property def preprocessor(self): """A `keras.layers.Layer` instance used to preprocess inputs.""" - return self._preprocessor + return getattr(self, "_preprocessor", None) @preprocessor.setter def preprocessor(self, value): diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index c66a61d4e5..a2b685544e 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -274,6 +274,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index 0d660bead9..45be1f74e7 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -184,6 +184,7 @@ def __init__( "segment_ids": segment_id_input, }, outputs=output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/tests/test_case.py b/keras_nlp/tests/test_case.py index 0541ae6451..6b88757c64 100644 --- a/keras_nlp/tests/test_case.py +++ b/keras_nlp/tests/test_case.py @@ -329,10 +329,10 @@ def run_precision_test(self, cls, init_kwargs, input_data): for weight in layer.weights: if is_float_dtype(weight.dtype): self.assertDTypeEqual(weight, policy.variable_dtype) - for sublayer in layer._flatten_layers(include_self=False): - if isinstance( - sublayer, (keras.layers.Softmax, keras.layers.InputLayer) - ): + for sublayer in layer._flatten_layers(): + if isinstance(sublayer, keras.layers.Softmax): + continue + if isinstance(sublayer, keras.layers.InputLayer): continue self.assertEqual(policy.compute_dtype, sublayer.compute_dtype) self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)