Fix Mistral memory consumption with JAX and default dtype bug #1460
Conversation
Looks good! One comment on the conversion script.
@@ -300,7 +300,7 @@ def main(_):
     print("-> Saved the model weights in float16")

     # === Save the model config ===
-    keras_nlp_config["dtype"] = "bfloat16"
+    keras_nlp_config.pop("dtype")  # We don't want a default dtype
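For context, a minimal sketch of what the change means for the serialized config (the dict below is illustrative, not the real Mistral config; the fallback behavior is assumed from Keras 3 defaults):

import keras

# Illustrative config; the real backbone config has many more keys.
keras_nlp_config = {"vocabulary_size": 32000, "hidden_dim": 4096, "dtype": "bfloat16"}

# Old behavior: a "dtype" entry in config.json pins every load of the preset
# to bfloat16. New behavior: popping the key lets the loading code (or the
# user) choose the dtype, falling back to the active Keras dtype policy.
keras_nlp_config.pop("dtype", None)
assert "dtype" not in keras_nlp_config
print(keras.config.dtype_policy())  # policy used when no dtype is serialized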
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would actually rename keras_nlp_config -> backbone_kwargs, and save using

keras_nlp.src.utils.preset_utils.save_to_preset(
    keras_nlp_model, preset
)
keras_nlp.src.utils.preset_utils.save_to_preset(
    keras_nlp_tokenizer, preset, config_filename="tokenizer.json"
)

Those will do the right thing and always call model.get_config() to create the config.json file (which does not include dtype for this reason). No need to regenerate presets if things look good in testing.
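One way to sanity-check the saved preset afterwards is to confirm that the generated config.json carries no dtype entry. A small sketch (the local preset directory name and the exact JSON layout are assumptions, so both the top level and the nested "config" mapping are checked):

import json
import os

preset_dir = "mistral_7b_en"  # hypothetical local directory written by save_to_preset

with open(os.path.join(preset_dir, "config.json")) as f:
    saved = json.load(f)

# The serialized backbone config usually sits under a nested "config" key;
# check both levels since the layout here is assumed, not quoted from docs.
nested = saved.get("config", {})
assert "dtype" not in saved and "dtype" not in nested, "preset should not pin a dtype"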
Done, uploaded the new presets to Kaggle.
lgtm!
Fixes #1458
This PR updates the presets for the Mistral model. The configs have been updated on Kaggle to not set a default dtype. The JAX memory consumption bug should also be fixed now.
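With no dtype baked into the preset config, memory use can be controlled at load time instead. A hedged example (the preset name and the dtype pass-through in from_preset are assumptions, not taken from this PR):

import keras_nlp

# Assuming from_preset forwards dtype to the backbone constructor, the weights
# are created directly in bfloat16 instead of float32, roughly halving the
# parameter memory on JAX.
backbone = keras_nlp.models.MistralBackbone.from_preset(
    "mistral_7b_en", dtype="bfloat16"
)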