Skip to content

Commit

Permalink
Fix dtype accessors of tasks/backbones (keras-team#1486)
Browse files (browse the repository at this point in the history)
* Fix dtype accessors of tasks/backbones

* Address comments, minor fixes
Loading branch information
mattdangerw authored and abuelnasr0 committed Apr 2, 2024
1 parent 7f692ca commit 8851624
Show file tree
Hide file tree
Showing 22 changed files with 41 additions and 7 deletions.
1 change: 1 addition & 0 deletions keras_nlp/models/albert/albert_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

Expand Down
11 changes: 10 additions & 1 deletion keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ def __init__(self, *args, dtype=None, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
if dtype is not None:
# Keras 2 and Keras 3 handle setting policy differently.
if config.keras_3():
if isinstance(dtype, keras.DTypePolicy):
self.dtype_policy = dtype
else:
self.dtype_policy = keras.DTypePolicy(dtype)
else:
self._set_dtype_policy(dtype)

def __dir__(self):
if config.keras_3():
Expand Down Expand Up @@ -67,7 +76,7 @@ def token_embedding(self):
This layer embeds integer token ids to the hidden dim of the model.
"""
return self._token_embedding
return getattr(self, "_token_embedding", None)

@token_embedding.setter
def token_embedding(self, value):
Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/bart/bart_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def __init__(
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/bert/bert_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/bloom/bloom_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/deberta_v3/deberta_v3_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/distil_bert/distil_bert_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/electra/electra_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/f_net/f_net_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/falcon/falcon_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
"padding_mask": padding_mask,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/gemma/gemma_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/gpt2/gpt2_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/llama/llama_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/mistral/mistral_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/opt/opt_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/roberta/roberta_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/t5/t5_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def __init__(
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

Expand Down
10 changes: 8 additions & 2 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def __init__(self, *args, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
if self.backbone is not None:
# Keras 2 and Keras 3 handle setting policy differently.
if config.keras_3():
self.dtype_policy = self._backbone.dtype_policy
else:
self._set_dtype_policy(self._backbone.dtype_policy)

def __dir__(self):
if config.keras_3():
Expand Down Expand Up @@ -128,7 +134,7 @@ def __setattr__(self, name, value):
@property
def backbone(self):
"""A `keras.Model` instance providing the backbone sub-model."""
return self._backbone
return getattr(self, "_backbone", None)

@backbone.setter
def backbone(self, value):
Expand All @@ -137,7 +143,7 @@ def backbone(self, value):
@property
def preprocessor(self):
"""A `keras.layers.Layer` instance used to preprocess inputs."""
return self._preprocessor
return getattr(self, "_preprocessor", None)

@preprocessor.setter
def preprocessor(self, value):
Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/whisper/whisper_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions keras_nlp/models/xlnet/xlnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(
"segment_ids": segment_id_input,
},
outputs=output,
dtype=dtype,
**kwargs,
)

Expand Down
8 changes: 4 additions & 4 deletions keras_nlp/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,10 @@ def run_precision_test(self, cls, init_kwargs, input_data):
for weight in layer.weights:
if is_float_dtype(weight.dtype):
self.assertDTypeEqual(weight, policy.variable_dtype)
for sublayer in layer._flatten_layers(include_self=False):
if isinstance(
sublayer, (keras.layers.Softmax, keras.layers.InputLayer)
):
for sublayer in layer._flatten_layers():
if isinstance(sublayer, keras.layers.Softmax):
continue
if isinstance(sublayer, keras.layers.InputLayer):
continue
self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)
Expand Down

0 comments on commit 8851624

Please sign in to comment.