
Sharded weights support #2218


Open

james77777778 wants to merge 2 commits into master from sharded_weights_support

Conversation

james77777778 (Collaborator) commented Apr 19, 2025

Please see the colab for an example using Gemma2 2B:
https://colab.research.google.com/drive/1iF_Psb6aEV2pkajT-q9ZBjpoO4RX4-Qa?usp=sharing

This PR adds support for sharded weights in KerasPresetSaver and KerasPresetLoader.
The default `max_shard_size` is 10 GB.
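Usage might look roughly like the sketch below; the `max_shard_size` argument on `save_to_preset` is an assumption based on this description, not a confirmed signature:

```python
import keras_hub

# Save a preset with sharded weights. Shards are capped at `max_shard_size`
# (in GB); per this PR the default is 10. (Assumed kwarg name and placement.)
backbone = keras_hub.models.GemmaBackbone.from_preset("gemma2_2b_en")
backbone.save_to_preset("./gemma2_2b_sharded", max_shard_size=2)

# Loading a sharded preset is transparent to the caller.
restored = keras_hub.models.GemmaBackbone.from_preset("./gemma2_2b_sharded")
```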

Kindly pinging @divyashreepathihalli @mattdangerw

Note: This feature requires the latest Keras (git+https://github.com/keras-team/keras.git). It is difficult to ensure backward compatibility.

Related to #2084

github-actions bot added the Gemma (Gemma model specific issues) label on Apr 19, 2025
james77777778 added the kokoro:force-run (Runs Tests on GPU) label on Apr 19, 2025
kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Apr 19, 2025
james77777778 force-pushed the sharded_weights_support branch from 9c92ba4 to bf9966a on Apr 20, 2025
james77777778 added the kokoro:force-run (Runs Tests on GPU) label on Apr 20, 2025
kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Apr 20, 2025
mattdangerw self-requested a review on Apr 20, 2025
mattdangerw (Member) commented Apr 20, 2025

@james77777778 thanks, will take a look! We don't need to be backwards compatible here; the error message you have, with an action the user can take, is as good as we can do here, I think.
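For illustration, an actionable error of that kind might look like the hypothetical guard below; the helper name, version cutoff, and message are made up for the sketch, not the PR's actual code:

```python
import keras


def _require_sharding_support():
    """Hypothetical guard: fail fast with an action the user can take."""
    major_minor = tuple(int(v) for v in keras.version().split(".")[:2])
    if major_minor < (3, 9):  # Illustrative cutoff, not the PR's actual check.
        raise ImportError(
            "Sharded weights require a newer Keras. Upgrade with "
            "`pip install git+https://github.com/keras-team/keras.git`."
        )
```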

mattdangerw (Member) left a comment

Thanks! Just a couple comments.

```python
dtype = keras.backend.standardize_dtype(dtype)
dtype_size = int(
    # Strip the alphabetic prefix, leaving the bit width, e.g. "bfloat16" -> "16".
    # (The continuation below is an assumed reconstruction of the truncated excerpt.)
    dtype.replace("bfloat", "").replace("float", "").replace("uint", "").replace("int", "")
    or 8  # An empty string (e.g. "bool") falls back to 8 bits.
)
```
mattdangerw (Member):

Can you explain what's going on here? Maybe flip this to a `dtype_size` function, just so you can add a quick docstring?

james77777778 (Collaborator, Author):

I have updated the code; it should be more explicit now that it uses a regex.
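For reference, a regex-based helper in the spirit of that change might look like this sketch (not necessarily the exact code that landed):

```python
import re


def dtype_size(dtype):
    """Return the number of bits used by `dtype`.

    e.g. "float32" -> 32, "bfloat16" -> 16, "int8" -> 8.
    """
    match = re.search(r"(\d+)$", dtype)
    if match is None:
        # Dtypes with no numeric suffix, e.g. "bool"; assumed to be 8 bits.
        return 8
    return int(match.group(1))
```

Matching the trailing digits states the intent directly, where the chained `str.replace` calls had to be read carefully to see it.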

"use_post_attention_norm": True,
"use_sliding_window_attention": True,
}
backbone = GemmaBackbone(**init_kwargs) # ~4.4MB
mattdangerw (Member):

Can we make this even smaller? Feel free to use bert or something simple if it's easier. Try to make this run as fast as possible while testing the business logic.

james77777778 (Collaborator, Author):

I have changed the config to make the backbone smaller (422 KB). The test now takes only 2.5 seconds from start to end.
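A rough sketch of the shape such a test might take; the tiny config values and the `max_shard_size` kwarg are illustrative assumptions:

```python
import keras_hub


def test_sharded_preset_round_trip(tmp_path):
    # A deliberately tiny config so the test runs in a couple of seconds.
    backbone = keras_hub.models.GemmaBackbone(
        vocabulary_size=256,
        num_layers=2,
        num_query_heads=2,
        num_key_value_heads=1,
        hidden_dim=16,
        intermediate_dim=32,
        head_dim=8,
    )
    preset_dir = str(tmp_path / "tiny_gemma")
    # A very small shard size forces multiple shard files (assumed kwarg).
    backbone.save_to_preset(preset_dir, max_shard_size=0.0001)
    restored = keras_hub.models.GemmaBackbone.from_preset(preset_dir)
    assert restored.count_params() == backbone.count_params()
```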

Labels: Gemma (Gemma model specific issues), stat:awaiting keras-eng
Projects: None yet
4 participants