Fix dtype accessors of tasks/backbones #1486

Merged 2 commits on Mar 6, 2024

1 change: 1 addition & 0 deletions keras_nlp/models/albert/albert_backbone.py
@@ -230,6 +230,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

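This same one-line fix repeats in every backbone file below: the user-facing dtype argument was already used when building individual layers, but it was never forwarded to the functional super().__init__(), so the model-level policy fell back to the global default. A minimal sketch of the recurring pattern (ToyBackbone is hypothetical, not the actual ALBERT code):

class ToyBackbone(Backbone):
    def __init__(self, hidden_dim=64, dtype=None, **kwargs):
        token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
        outputs = keras.layers.Embedding(1000, hidden_dim, dtype=dtype)(token_ids)
        super().__init__(
            inputs={"token_ids": token_ids},
            outputs=outputs,
            dtype=dtype,  # the fix: hand the policy to the base class too
            **kwargs,
        )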
11 changes: 10 additions & 1 deletion keras_nlp/models/backbone.py
@@ -28,6 +28,15 @@ def __init__(self, *args, dtype=None, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
if dtype is not None:
A contributor commented:

Thanks for this fix! I have one comment about setting the dtype; if you could give it a look, I would appreciate it.
I think we should follow how keras.layers.Layer sets the dtype policy.
Something like this:

if config.keras_3():
    self.dtype_policy = keras.src.dtype_policies.get(dtype)
else:
    self._set_dtype_policy(dtype)

I think that would be a better user experience: whatever dtype works for a layer should also be expected to work for the backbone, and in Keras 3, for example, keras.src.dtype_policies.get(dtype) covers more scenarios.
Also, Keras 2 does some extra work after setting the dtype policy: it sets an attribute named _compute_dtype_object. I don't know its exact purpose, but it is used by many functions in the class.

@mattdangerw (Member, Author) replied on Mar 5, 2024:

The only thing with keras.src.dtype_policies.get(dtype) is that we want to avoid private functionality. We only have one case where we are using that today, and it's much more self-contained (lora standalone weight saving). I think the only case we are missing here is loading from a serialized dict, which we should not hit because we don't serialize this in the backbone config.

We should probably add public keras.dtype_policies.get and keras.dtype_policies.serialize to core Keras for consistency with other things like initializers.

For Keras 2, sounds good; let's switch to self._set_dtype_policy(dtype). There's less concern over non-exported functionality there, since Keras 2 is quite static.

    # Keras 2 and Keras 3 handle setting policy differently.
    if config.keras_3():
        if isinstance(dtype, keras.DTypePolicy):
            self.dtype_policy = dtype
        else:
            self.dtype_policy = keras.DTypePolicy(dtype)
    else:
        self._set_dtype_policy(dtype)
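In practice, anything a Keras layer accepts for dtype should now configure the backbone as well. A hedged usage sketch (assuming Keras 3 and that from_preset forwards constructor kwargs; the printed values are illustrative):

import keras
import keras_nlp

# A plain dtype string:
backbone = keras_nlp.models.GPT2Backbone.from_preset(
    "gpt2_base_en", dtype="bfloat16"
)
print(backbone.dtype_policy.name)  # "bfloat16"

# An explicit policy object:
backbone = keras_nlp.models.GPT2Backbone.from_preset(
    "gpt2_base_en", dtype=keras.DTypePolicy("mixed_float16")
)
print(backbone.dtype_policy.compute_dtype)  # "float16"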

def __dir__(self):
if config.keras_3():
@@ -67,7 +76,7 @@ def token_embedding(self):

This layer embeds integer token ids to the hidden dim of the model.
"""
-    return self._token_embedding
+    return getattr(self, "_token_embedding", None)

@token_embedding.setter
def token_embedding(self, value):
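The getattr change above is the accessor fix named in the PR title: these properties can now be read before __init__ has assigned the private attribute (for example, Task.__init__ below checks self.backbone while setting the dtype policy), so they must return None instead of raising AttributeError. A minimal sketch of the pattern with a hypothetical class:

class Example:
    @property
    def token_embedding(self):
        # Safe even before __init__ assigns self._token_embedding.
        return getattr(self, "_token_embedding", None)

    @token_embedding.setter
    def token_embedding(self, value):
        self._token_embedding = value

obj = Example.__new__(Example)  # attribute not assigned yet
assert obj.token_embedding is None  # bare self._token_embedding would raise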
1 change: 1 addition & 0 deletions keras_nlp/models/bart/bart_backbone.py
@@ -232,6 +232,7 @@ def __init__(
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/bert/bert_backbone.py
@@ -196,6 +196,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/bloom/bloom_backbone.py
@@ -149,6 +149,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -178,6 +178,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/distil_bert/distil_bert_backbone.py
@@ -159,6 +159,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/electra/electra_backbone.py
@@ -202,6 +202,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/f_net/f_net_backbone.py
@@ -206,6 +206,7 @@ def __init__(
"sequence_output": sequence_output,
"pooled_output": pooled_output,
},
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/falcon/falcon_backbone.py
@@ -130,6 +130,7 @@ def __init__(
"padding_mask": padding_mask,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/gemma/gemma_backbone.py
@@ -157,6 +157,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/gpt2/gpt2_backbone.py
@@ -170,6 +170,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
@@ -137,6 +137,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/llama/llama_backbone.py
@@ -127,6 +127,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/mistral/mistral_backbone.py
@@ -166,6 +166,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/opt/opt_backbone.py
@@ -146,6 +146,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/roberta/roberta_backbone.py
@@ -156,6 +156,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=x,
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/t5/t5_backbone.py
@@ -224,6 +224,7 @@ def __init__(
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

10 changes: 8 additions & 2 deletions keras_nlp/models/task.py
@@ -36,6 +36,12 @@ def __init__(self, *args, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
if self.backbone is not None:
    # Keras 2 and Keras 3 handle setting policy differently.
    if config.keras_3():
        self.dtype_policy = self._backbone.dtype_policy
    else:
        self._set_dtype_policy(self._backbone.dtype_policy)

def __dir__(self):
if config.keras_3():
@@ -128,7 +134,7 @@ def __setattr__(self, name, value):
@property
def backbone(self):
"""A `keras.Model` instance providing the backbone sub-model."""
-    return self._backbone
+    return getattr(self, "_backbone", None)

@backbone.setter
def backbone(self, value):
@@ -137,7 +143,7 @@ def backbone(self, value):
@property
def preprocessor(self):
"""A `keras.layers.Layer` instance used to preprocess inputs."""
-    return self._preprocessor
+    return getattr(self, "_preprocessor", None)

@preprocessor.setter
def preprocessor(self, value):
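With both pieces in place, a task constructed around a backbone adopts that backbone's dtype policy rather than the global default. A hedged usage sketch (assuming Keras 3; the preset name is real, but the flow is illustrative):

import keras_nlp

backbone = keras_nlp.models.BertBackbone.from_preset(
    "bert_tiny_en_uncased", dtype="bfloat16"
)
classifier = keras_nlp.models.BertClassifier(
    backbone=backbone, num_classes=2
)
# The task mirrors the backbone's policy.
assert classifier.dtype_policy.name == backbone.dtype_policy.name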
1 change: 1 addition & 0 deletions keras_nlp/models/whisper/whisper_backbone.py
@@ -274,6 +274,7 @@ def __init__(
"encoder_sequence_output": encoder_output,
"decoder_sequence_output": decoder_output,
},
dtype=dtype,
**kwargs,
)

1 change: 1 addition & 0 deletions keras_nlp/models/xlnet/xlnet_backbone.py
@@ -184,6 +184,7 @@ def __init__(
"segment_ids": segment_id_input,
},
outputs=output,
dtype=dtype,
**kwargs,
)

8 changes: 4 additions & 4 deletions keras_nlp/tests/test_case.py
@@ -329,10 +329,10 @@ def run_precision_test(self, cls, init_kwargs, input_data):
 for weight in layer.weights:
     if is_float_dtype(weight.dtype):
         self.assertDTypeEqual(weight, policy.variable_dtype)
-for sublayer in layer._flatten_layers(include_self=False):
-    if isinstance(
-        sublayer, (keras.layers.Softmax, keras.layers.InputLayer)
-    ):
+for sublayer in layer._flatten_layers():
+    if isinstance(sublayer, keras.layers.Softmax):
+        continue
+    if isinstance(sublayer, keras.layers.InputLayer):
         continue
     self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
     self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)
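The test now walks every sublayer, including the model itself (include_self=False is dropped), asserting the mixed-precision contract; Softmax and InputLayer are skipped as known exceptions that do not follow the policy. As a quick reminder of the invariant being checked (assuming Keras 3):

import keras

policy = keras.DTypePolicy("mixed_float16")
print(policy.compute_dtype)   # "float16": used for forward-pass math
print(policy.variable_dtype)  # "float32": used for stored weights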