Add quantization support for Gemma, Gemma2 and PaliGemma #1670
Conversation
This PR should be ready for reviewing.
Hi @fchollet @mattdangerw
@james77777778 thanks so much! Sorry for the delay, I was out last week, but just got back in town. Will take a look tomorrow!
No hurry. Please take your time.
Looks good!
General comments:
- Let's try to make the contract between the `ReversibleEmbedding` layer and the `Embedding` layer as minimal as possible. Any private functionality might change in core Keras, and we are using a lot of it here (which is fine, let's just reduce it if we can).
- Let's test this on all models if we can.
        return super()._int8_call(inputs)

    def quantize(self, mode):
Could we chain to `super()` here to keep most of the logic, and just handle the `if mode == "int8" and not self.tie_weights` case below? It would be great to keep as much logic on the super class as we can.
I'm afraid not.
Raising `NotImplementedError` in `keras.layers.Embedding` is intentional and unavoidable. The idea is to prevent undefined behavior when users call `Model.quantize`.
I could introduce an argument like `type_check=True` in `keras.layers.Embedding` to support calling `super()` in the future. However, for now, we can only implement `quantize` from scratch.
EDITED:
keras-team/keras#19949
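(For illustration, a hypothetical sketch of how a `type_check` argument could let the subclass reuse the parent's logic; the names and details below are illustrative, not the actual Keras or KerasNLP implementation.)

```python
import keras


class ReversibleEmbedding(keras.layers.Embedding):
    """Hypothetical sketch only, not the real `keras_nlp` layer."""

    def __init__(self, input_dim, output_dim, tie_weights=True, **kwargs):
        super().__init__(input_dim, output_dim, **kwargs)
        self.tie_weights = tie_weights

    def quantize(self, mode, type_check=True):
        # Assumes `Embedding.quantize` accepts a `type_check` flag so the
        # subclass can skip the parent's strict type check, reuse its int8
        # logic, and handle only the untied reverse embeddings itself.
        if type_check and type(self) is not ReversibleEmbedding:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not support quantization "
                f"mode {mode!r}."
            )
        if mode == "int8" and not self.tie_weights:
            # Quantize the separate reverse embedding matrix here (omitted).
            ...
        return super().quantize(mode, type_check=False)
```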
I see, thanks for the explainer.
Not to solve on this PR, but I wonder if we can make the contract between Keras and downstream here more public and minimal. I see `_int8_call()`, `_int8_build()`, `_quantization_mode_error()`, `_tracker`, and `_untrack_variable()` all used here. That's a pretty significant level of private usage, which could easily break.
Separate question: will this work with older versions of Keras 3? Or are there small changes we could make so we don't break older versions?
I agree that these methods are too verbose for downstream projects. I will try to simplify the contract in the future, but currently I don't have a good idea for it.

> will this work with older versions of Keras 3?

I haven't checked the compatibility. My rough guess is that users will need `keras>=3.4.0` due to the introduction of `DTypePolicyMap`.
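(A rough illustration of what `DTypePolicyMap` provides; the layer path and dtypes below are placeholders, not values from this PR.)

```python
import keras

# Sketch: DTypePolicyMap maps layer paths to dtype policies, so a backbone's
# config can record which sublayers were quantized and how.
policy_map = keras.dtype_policies.DTypePolicyMap()
policy_map["backbone/token_embedding"] = keras.dtype_policies.QuantizedDTypePolicy(
    mode="int8", source_name="bfloat16"
)

# The map serializes like any other dtype policy, which is what lets
# `get_config()` / `from_config()` round-trip a quantized model.
config = keras.dtype_policies.serialize(policy_map)
restored = keras.dtype_policies.deserialize(config)
```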
"name": self.name, | ||
"trainable": self.trainable, | ||
} | ||
|
||
# Add quantization support by utilizing `DTypePolicyMap` |
This is great! This should buy us support for all models, right? If possible we should consider extending our common backbone tests for this...
Doing so would test quantization for the whole library. Seems like it should be doable: call quantize, assert output. WDYT?
If we run into failures for certain models, we could add an option to `run_backbone_test` called `run_quantization_check=True`, and set the option to false if the model fails, with a TODO to investigate.
Yeah, it is doable.
I have added `run_quantization_test` to `run_backbone_test`. Only Bloom and OPT failed the test.
However, there is a significant speed regression after adding this test: CI time increased from ~19 minutes to ~27 minutes. Is this acceptable?
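(For context, a minimal sketch of what such a shared check can look like; the helper name follows the discussion above, but the body is illustrative rather than the actual `keras_nlp` test utility.)

```python
import numpy as np
import keras


def run_quantization_test(backbone_cls, init_kwargs, input_data):
    """Illustrative sketch: quantize a freshly built backbone in each mode
    and check that it still runs and produces finite values."""
    for mode in ("int8", "float8"):
        backbone = backbone_cls(**init_kwargs)
        backbone.quantize(mode)
        output = backbone(input_data)
        # Outputs may be a tensor or a nested structure, so walk it.
        for tensor in keras.tree.flatten(output):
            assert np.all(np.isfinite(keras.ops.convert_to_numpy(tensor)))
```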
I think having the coverage is important. Let's pull this in, and see if we can improve the runtime efficiency as a follow-up.
Saving is slow, so maybe we can do something like:
- Basic quantization tests that do not hit saving: just test `get_config()`/`from_config()` and maybe assigning weights over (see the sketch below).
- Separate quantization testing in our saving test harness, marked `large` and only run on larger/faster hardware.
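(A rough sketch of the lighter-weight check from the first bullet, assuming a generic backbone class and input data; illustrative only, not the actual test harness.)

```python
def run_quantization_config_test(backbone_cls, init_kwargs, input_data):
    """Illustrative sketch: exercise quantization without touching disk."""
    backbone = backbone_cls(**init_kwargs)
    backbone.quantize("int8")

    # Round-trip through the config instead of saving to disk; with
    # `DTypePolicyMap` in the config, the restored model is built quantized.
    config = backbone.get_config()
    restored = backbone_cls.from_config(config)

    # Assign the quantized weights over and make sure the clone still runs.
    restored.set_weights(backbone.get_weights())
    restored(input_data)
```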
Will try this in another PR.
@james77777778 thanks for the changes! As soon as testing is all green I will pull this in, especially since the US is about to go on holiday until next Monday. I think the coverage is worth it, but let's keep thinking of ways to speed up this testing with decent coverage as a follow-up.
We will need a new release of Keras for this. Currently, I have built the PR based on the master branch of Keras. The implementation is simple and clean after introducing `DTypePolicyMap` and some other fixes. Thanks to @fchollet and @mattdangerw for their help.
It is worth noting that float8 training & inference are also supported in this PR. You can check `test_quantize` for this.
Some numbers:
Script: `int8_gemma.py`
Usage:
Outputs:
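(The script, exact usage, and output numbers are not reproduced above. As a purely illustrative sketch, with a placeholder preset name, int8 quantization of a Gemma model with this support looks roughly like the following.)

```python
import keras_nlp

# Illustrative only: the preset name is a placeholder, not necessarily the
# one used in the benchmark script referenced above.
causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

# `quantize` is the Keras `Model.quantize` API; "int8" converts supported
# layers to int8 for lower-memory inference.
causal_lm.quantize("int8")

print(causal_lm.generate("What is Keras?", max_length=64))
```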