Rewrite metrics module with model subclassing #343

Open
hvgazula opened this issue Jul 18, 2024 · 1 comment

hvgazula commented Jul 18, 2024

What would you like changed/added and why?
Rewrite the functions in metrics.py using model subclassing. See here. The drawback of the current approach is that the metrics used during training have to be tracked separately, which is not always possible.

What would be the benefit?
Currently, loading a model from a previously stored checkpoint (warm_start) emits the following warning:

WARNING:tensorflow:Unable to restore custom metric. Please ensure that the layer implements get_config and from_config when saving. In addition, please use the custom_objects arg when calling load_model().

Does the change make something easier to use?
This change will avoid situations where the user has to record which metric was used and pass it again at load time. Model subclassing (with @keras.saving.register_keras_serializable) will greatly simplify this process.

hvgazula self-assigned this Jul 18, 2024
@hvgazula

NOTE: Dice and Tversky losses were introduced in TF 2.16.
