-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix dtype accessors of tasks/backbones #1486
Fix dtype accessors of tasks/backbones #1486
Conversation
033f583
to
75ee62c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for getting this fixed @mattdangerw!
@@ -28,6 +28,16 @@ def __init__(self, *args, dtype=None, **kwargs): | |||
id(layer) for layer in self._flatten_layers() | |||
) | |||
self._initialized = True | |||
if dtype is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this fix, I have one comment about setting the dtype. if you would like to give it a look, I will appreciate it.
I think we should follow how keras.layer
sets the dtype police.
something like this:
if config.keras_3():
self.dtype_policy = keras.src.dtype_policies.get(dtype)
else:
self._set_dtype_policy(dtype)
I think that will be better user experience as whatever dtype
works for layer should be expected to work for the backbone, and for keras3 for example, keras.src.dtype_policies.get(dtype)
covers more scenarios.
also keras2 do some extra stuff after setting the dtype policy, it sets an attribute named _compute_dtype_object
, I don't know its usage specifically, but it's used by a lot of functions in the class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only thing with keras.src.dtype_policies.get(dtype)
is we want to avoid private functionality. We only have one case where we are using that today, and it's much more self contained (lora standalone weight saving). I think the only case we are missing here is loading from a serialized dict, which we should not hit because we don't serialize this in the backbone config.
We should probably add public keras.dtype_policies.get
and keras.dtype_policies.serailize
to core Keras for consistency with other things like initializers.
For Keras 2 sounds good, let's switch to self._set_dtype_policy(dtype)
. Less concern over non-exported functionality there Keras 2 is quite static.
9b634e7
to
8b2992a
Compare
* Fix dtype accessors of tasks/backbones * Address comments, minor fixes
Makes sure we have the correct values for
task.compute_dtype
,task.variable_dtype
,backbone.compute_dtype
andbackbone.variable_dtype
.