Fix dtype accessors of tasks/backbones #1486

Merged: 2 commits, Mar 6, 2024

mattdangerw
Member

@mattdangerw mattdangerw commented Mar 5, 2024

Makes sure we have the correct values for task.compute_dtype, task.variable_dtype, backbone.compute_dtype and backbone.variable_dtype.
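
For readers unfamiliar with the two accessors, here is a minimal sketch of the values they are expected to report under a mixed-precision policy. `SketchPolicy` and `SketchBackbone` are hypothetical stand-ins written for illustration, not the Keras or KerasNLP classes; the only behavior assumed is the usual convention that a `mixed_<dtype>` policy computes in `<dtype>` and keeps variables in `float32`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchPolicy:
    """Hypothetical stand-in for a dtype policy; not the Keras class."""

    name: str

    @property
    def compute_dtype(self) -> str:
        # "mixed_<dtype>" policies compute in the reduced-precision dtype.
        if self.name.startswith("mixed_"):
            return self.name[len("mixed_"):]
        return self.name

    @property
    def variable_dtype(self) -> str:
        # Mixed policies keep variables in float32.
        if self.name.startswith("mixed_"):
            return "float32"
        return self.name


class SketchBackbone:
    """Hypothetical wrapper that forwards dtype accessors to its policy."""

    def __init__(self, dtype: str = "float32"):
        self.dtype_policy = SketchPolicy(dtype)

    @property
    def compute_dtype(self) -> str:
        return self.dtype_policy.compute_dtype

    @property
    def variable_dtype(self) -> str:
        return self.dtype_policy.variable_dtype


backbone = SketchBackbone(dtype="mixed_bfloat16")
print(backbone.compute_dtype, backbone.variable_dtype)  # bfloat16 float32
```

The point of the sketch is only that both accessors should be derived from the one policy the model was constructed with, so they cannot disagree with it.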

@mattdangerw mattdangerw force-pushed the dtype-policy-getters branch from 033f583 to 75ee62c Compare March 5, 2024 02:35
@mattdangerw mattdangerw mentioned this pull request Mar 5, 2024
@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Mar 5, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Mar 5, 2024
Contributor

@tirthasheshpatel tirthasheshpatel left a comment

LGTM! Thanks for getting this fixed @mattdangerw!

@@ -28,6 +28,16 @@ def __init__(self, *args, dtype=None, **kwargs):
            id(layer) for layer in self._flatten_layers()
        )
        self._initialized = True
        if dtype is not None:
Contributor

Thanks for this fix. I have one comment about setting the dtype; I would appreciate it if you could take a look.
I think we should follow how keras.layers.Layer sets the dtype policy.
Something like this:

if config.keras_3():
    self.dtype_policy = keras.src.dtype_policies.get(dtype)
else:
    self._set_dtype_policy(dtype)

I think that will be a better user experience, since whatever dtype works for a layer should be expected to work for the backbone, and in Keras 3, for example, keras.src.dtype_policies.get(dtype) covers more scenarios.
Also, Keras 2 does some extra work after setting the dtype policy: it sets an attribute named _compute_dtype_object. I don't know exactly what it is used for, but many functions in the class rely on it.

Member Author

@mattdangerw mattdangerw Mar 5, 2024

The only concern with keras.src.dtype_policies.get(dtype) is that we want to avoid private functionality. We only have one place where we use it today, and that one is much more self-contained (LoRA standalone weight saving). I think the only case we are missing here is loading from a serialized dict, which we should not hit because we don't serialize this in the backbone config.

We should probably add public keras.dtype_policies.get and keras.dtype_policies.serialize to core Keras for consistency with other things like initializers.

For Keras 2, sounds good, let's switch to self._set_dtype_policy(dtype). There is less concern over non-exported functionality there, since Keras 2 is quite static.
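
To make the trade-off concrete, here is a minimal sketch of which `dtype` inputs are covered when the serialized-dict case is left out. `SketchPolicy` and `resolve_policy` are hypothetical names used only for illustration; they are not Keras APIs and this is not the code in the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchPolicy:
    """Hypothetical stand-in for a dtype policy object; not the Keras class."""

    name: str


def resolve_policy(dtype):
    """Hypothetical resolver for the inputs a backbone is expected to receive."""
    if isinstance(dtype, SketchPolicy):
        # An already-built policy object is passed through unchanged.
        return dtype
    if isinstance(dtype, str):
        # A plain name such as "bfloat16" or "mixed_float16".
        return SketchPolicy(dtype)
    # A serialized dict config is the uncovered case in this sketch. It should
    # not be reached if the dtype is never serialized as a dict.
    raise TypeError(f"Unsupported dtype argument: {dtype!r}")


policy = resolve_policy("mixed_float16")
same_policy = resolve_policy(policy)
```

A public `get` function in core Keras, as suggested above, would let the backbone cover the dict case too without depending on private modules.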

@mattdangerw mattdangerw force-pushed the dtype-policy-getters branch from 9b634e7 to 8b2992a Compare March 6, 2024 01:13
@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Mar 6, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Mar 6, 2024
@mattdangerw mattdangerw merged commit f70b1c0 into keras-team:master Mar 6, 2024
10 checks passed
abuelnasr0 pushed a commit to abuelnasr0/keras-nlp that referenced this pull request Apr 2, 2024
* Fix dtype accessors of tasks/backbones

* Address comments, minor fixes
4 participants