Port mistral transformer checkpoint #1768
Conversation
Looks good overall! One comment.
Could you include a small Colab showing generation, just to verify this is working? We don't have numerics validation yet.
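For reference, the kind of generation check being asked for might look roughly like this (a sketch; it assumes the converter is reachable through `from_preset` with an `hf://` URI, as in the test added later in this PR):

```python
import keras_nlp

# Load the Hugging Face checkpoint through the KerasNLP conversion path.
causal_lm = keras_nlp.models.MistralCausalLM.from_preset(
    "hf://mistralai/Mistral-7B-v0.1"
)

# Smoke test: generate a short completion and eyeball the output.
print(causal_lm.generate("The capital of France is", max_length=30))
```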
"rope_max_wavelength": transformers_config["rope_theta"], | ||
"layer_norm_epsilon": transformers_config["rms_norm_eps"], | ||
"sliding_window": transformers_config["sliding_window"], | ||
"dtype": transformers_config["torch_dtype"], |
I don't think we should convert dtype. We don't for other models.
We will create a backbone with the default Keras floating point type unless the user supplies their own arg, but we don't restore the saved dtype policy by default.
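In other words (a sketch, assuming the usual `dtype` argument is forwarded by `from_preset`), the load precision is controlled by the caller rather than by the checkpoint:

```python
import keras_nlp

# Loads with the default Keras floating point type (typically float32).
backbone = keras_nlp.models.MistralBackbone.from_preset(
    "hf://mistralai/Mistral-7B-v0.1"
)

# Loads in bfloat16 only because the caller asked for it explicitly.
backbone_bf16 = keras_nlp.models.MistralBackbone.from_preset(
    "hf://mistralai/Mistral-7B-v0.1", dtype="bfloat16"
)
```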
Without dtype conversion, I am getting an error: DTypePromotionError: The DTypes <class 'numpy.dtypes.Float16DType'> and <class 'numpy.dtype[bfloat16]'> do not have a common DType. For example they cannot be stored in a single array unless the dtype is object.
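For context, the promotion failure can be reproduced outside the converter with plain NumPy (a minimal sketch, assuming `ml_dtypes` provides the bfloat16 NumPy dtype involved here):

```python
import numpy as np
import ml_dtypes

a = np.zeros(2, dtype=np.float16)
b = np.zeros(2, dtype=ml_dtypes.bfloat16)

# float16 and bfloat16 have no common promoted dtype, so combining them
# raises DTypePromotionError unless one side is cast first.
np.concatenate([a, b])
```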
Interesting. I think this is something we will have to solve during weight conversion, and not by sticking this value in the config. I will take a look.
@mattdangerw do you think this would be a better place to keep the check?
https://github.com/keras-team/keras-nlp/blob/f80fbfd0eaeee7a9e63a4c98a81ff8aba5506f3e/keras_nlp/src/utils/transformers/safetensor_utils.py#L97
We can check whether the dtypes match here -- if a conversion is needed, warn the user that a type conversion is happening at this stage to port the weights, and then continue?
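Something along these lines, for instance (a hypothetical sketch, not the actual safetensor_utils.py code; `assign_with_cast`, `keras_variable`, and `hf_tensor` are made-up names):

```python
import warnings

import keras


def assign_with_cast(keras_variable, hf_tensor):
    # Hypothetical helper: warn and cast when the checkpoint dtype differs
    # from the dtype of the Keras variable it is loaded into.
    target_dtype = keras_variable.dtype
    if str(hf_tensor.dtype) != str(target_dtype):
        warnings.warn(
            f"Converting weight from {hf_tensor.dtype} to {target_dtype} "
            "while porting the checkpoint."
        )
        hf_tensor = keras.ops.cast(hf_tensor, target_dtype)
    keras_variable.assign(hf_tensor)
```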
@ariG23498 Makes sense. 💡
I would actually think any type conversion should happen inside of the assign call. I tried removing this dtype line and could not reproduce the error. Is this only on a specific backend?
I don't think we need to warn that type conversion is happening. Loading a half precision save at full precision or vice versa is quite common.
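That is, something roughly like this (a hypothetical sketch, same caveats as the helper above; the cast lives next to the assign and no warning is emitted):

```python
import keras


def assign_hf_tensor(keras_variable, hf_tensor):
    # Hypothetical helper: cast to the destination variable's dtype at
    # assignment time, so nothing dtype-related needs to live in the config.
    if str(hf_tensor.dtype) != str(keras_variable.dtype):
        hf_tensor = keras.ops.cast(hf_tensor, keras_variable.dtype)
    keras_variable.assign(hf_tensor)
```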
So we can just remove this line, right? I'll give that a try and land if things look good.
Yeah, I checked it now by removing this line. It works.
class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        model = MistralCausalLM.from_preset("hf://mistralai/Mistral-7B-v0.1")
This is too big to run in our automated testing regularly. @ariG23498 can you detail what you did to make hf://ariG23498/tiny-gemma-test?
Here is detailed code for building a small test model and uploading it to the Hub.
@cosmo3769 could you take a look at it?
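For reference, the recipe is roughly as follows (a sketch using the transformers API; the dimensions and repo name are illustrative, and a tokenizer would still need to be uploaded alongside the weights):

```python
from transformers import MistralConfig, MistralForCausalLM

# A deliberately tiny Mistral so the test checkpoint is a few MB, not ~14 GB.
config = MistralConfig(
    vocab_size=32000,
    hidden_size=8,
    intermediate_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=1,
    sliding_window=512,
)
model = MistralForCausalLM(config)

# Push the randomly initialized model and its config to the Hugging Face Hub.
model.push_to_hub("your-username/tiny-mistral-test")
```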
@ariG23498 Sure.
Main thing we need before we merge is a smaller test case. Left a comment on the big chain, though; still not sure exactly where things are breaking if you remove dtype from the config.
Added tiny-mistral test.
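Presumably the updated test swaps the 7B checkpoint for the tiny one, roughly like this (the Hub path is a placeholder, not the actual repo used in the PR, and the import paths are approximate):

```python
import pytest

from keras_nlp.models import MistralCausalLM
from keras_nlp.src.tests.test_case import TestCase


class TestTask(TestCase):
    @pytest.mark.large
    def test_convert_tiny_preset(self):
        # Placeholder path for the tiny test checkpoint on the Hub.
        model = MistralCausalLM.from_preset("hf://your-username/tiny-mistral-test")
        model.generate("What is Keras?", max_length=15)
```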
LGTM! Will merge after tests run.
Resolved the merge conflict.
The JAX failure is from #1783, but this one looks good. Pulling this in!
* ported mistral
* update test
* fix config
* fix typo
* switched float32 to float16
* tiny-mistral-test
* removed dtype config
Hi @mattdangerw @ariG23498,
Ported the Mistral transformers checkpoint to KerasNLP. Please check. Thank you!