Sharded weights support #2218
base: master
Conversation
Force-pushed 9c92ba4 to bf9966a
@james77777778 thanks, will take a look! We don't need to be backwards compatible here; the error message you have, with an action the user can take, is as good as we can do here I think.
Thanks! Just a couple comments.
keras_hub/src/utils/preset_utils.py (outdated)

    dtype = keras.backend.standardize_dtype(dtype)
    dtype_size = int(
        (
            dtype.replace("bfloat", "")
Can you explain what's going on here? Maybe flip this to a `dtype_size` function (just so you can add a quick docstring)?
I have updated the code; it should now be more explicit, using a regex.
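The shape such a helper might take (this is a hypothetical sketch, not the PR's actual implementation: it pulls the trailing digits out of a dtype name to get its bit width, which is the behavior the string-stripping code above was approximating):

```python
import re


def dtype_size(dtype: str) -> int:
    """Return the size in bits of a dtype string, e.g. "float32" -> 32.

    Hypothetical sketch of the helper discussed above. Works for names
    ending in a bit count ("float32", "bfloat16", "int8"); dtypes without
    a numeric suffix (e.g. "bool") would need special-casing.
    """
    match = re.search(r"(\d+)$", dtype)
    if match is None:
        raise ValueError(f"Cannot infer bit width from dtype: {dtype}")
    return int(match.group(1))
```

With this, the byte size of a weight is just its element count times `dtype_size(dtype) // 8`.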
"use_post_attention_norm": True, | ||
"use_sliding_window_attention": True, | ||
} | ||
backbone = GemmaBackbone(**init_kwargs) # ~4.4MB |
Can we make this even smaller? Feel free to use BERT or something simple if it's easier. Try to make this run as fast as possible while still testing the business logic.
I have changed the config to make the backbone smaller (422KB). It now takes only 2.5 seconds from the start to the end of the test.
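The checkpoint size scales directly with the test config's dimensions, which is why shrinking the hidden sizes cuts the footprint from megabytes to kilobytes. A rough back-of-the-envelope helper (a sketch for illustration only, assuming float32 weights; not part of the PR):

```python
def approx_weight_bytes(shapes, bytes_per_param=4):
    """Rough checkpoint-size estimate for a list of weight shapes.

    Multiplies out each shape's element count and sums the bytes,
    assuming a fixed dtype width (default 4 bytes for float32).
    """
    total = 0
    for shape in shapes:
        n = 1
        for dim in shape:
            n *= dim
        total += n * bytes_per_param
    return total
```

For example, a single 256x256 kernel plus its bias is about 263KB at float32, so a handful of tiny layers lands comfortably in the hundreds-of-kilobytes range mentioned above.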
Force-pushed bf9966a to 2b70a6f
Please see the colab for an example using Gemma2 2B:
https://colab.research.google.com/drive/1iF_Psb6aEV2pkajT-q9ZBjpoO4RX4-Qa?usp=sharing
This PR adds support for sharded weights in `KerasPresetSaver` and `KerasPresetLoader`. The default `max_shard_size` is set to 10GB.

Kindly ping @divyashreepathihalli @mattdangerw

Note: This feature requires the latest Keras (`git+https://github.com/keras-team/keras.git`). It is difficult to ensure backward compatibility.

Related to #2084
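The core idea behind a `max_shard_size` limit can be sketched as a greedy partition: walk the weights in order and start a new shard file whenever adding the next weight would push the current shard over the limit. This is an illustrative sketch of that idea, not the PR's actual implementation:

```python
def plan_shards(weight_sizes, max_shard_size):
    """Greedily group weights into shards of at most max_shard_size bytes.

    weight_sizes: list of (name, size_in_bytes) pairs, in save order.
    Returns a list of shards, each a list of weight names. A weight larger
    than max_shard_size still gets its own shard rather than being split.
    """
    shards = []
    current, current_size = [], 0
    for name, size in weight_sizes:
        # Close the current shard if this weight would overflow it.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards
```

The loader side then just reads each shard file in turn and assigns the contained weights back to the model by name.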