Dynamic int8 quantization for Llama2 and Llama3 #1720
Conversation
Do you see noticeably worse results on CPU? Is it backend-specific? Not that it would block this, but good to know.
@james77777778 can you say what you mean here? Do you just mean check them in? If so, agreed, it will be nice to have.
Yes, the results are noticeably worse when using CPU. I have tested using JAX and TensorFlow and verified that the bad results are backend-agnostic. However, with CUDA, the results are more reasonable:

JAX:

```
What is Keras?
nobody knows.
It's a neural network framework written in python.
### Installation
Install TensorFlow:
```

TensorFlow:

```
What is Keras?
sierp 18, 2018
Keras is a powerful and easy-to-use library for building
```

Env:
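For anyone reproducing the backend comparison, a minimal sketch (it assumes the standard `KERAS_BACKEND` environment variable and Keras 3's `Model.quantize("int8")` API that this PR wires up for Llama; the `llama2_7b_en` preset is reused from this thread, and these are not the exact commands used above):

```python
# Repro sketch: the backend must be chosen via KERAS_BACKEND before
# keras/keras_nlp are imported for the first time.
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow"

import keras_nlp

# Load the float model, apply dynamic int8 quantization, and generate.
causal_lm = keras_nlp.models.CausalLM.from_preset("llama2_7b_en")
causal_lm.quantize("int8")
print(causal_lm.generate(["What is Keras?"], max_length=32)[0])
```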
Yes! That's what I mean.
Force-pushed from badc1de to 68dc1fc.
@james77777778 did the output of llama2_7b_en look random pre-quantization, or just after quantization? Probably worth looking into this; we might have a bug somewhere.
Anyway, this PR looks good to me, will pull it in!
I only checked the quantized model. Will take a look at the outputs using bfloat16 on CPU and GPU.
I ran:

```python
import keras
import keras_nlp

keras.config.set_floatx("bfloat16")

input_str = "What is Keras?"
length = 32

# llama2_7b_en, llama3_8b_en
causal_lm = keras_nlp.models.CausalLM.from_preset("llama3_8b_en")
keras_output = causal_lm.generate([input_str], max_length=length)
keras_output = keras_output[0]
print("🔶 KerasNLP output:", keras_output)
```

Outputs:
I think it would be better to add the quantization scripts to KerasNLP (a sketch of such a script follows below).

I have created these int8 presets:
- llama2_7b_en_int8
- llama2_instruct_7b_en_int8
- llama3_8b_en_int8
- llama3_instruct_8b_en_int8

llama2_7b_en_int8 produced some unusual outputs, but based on my past experience, it should generate meaningful results using CUDA and XLA. Please let me know if these models are ready to be pushed to Kaggle.

cc @mattdangerw
EDITED:
Note that I have only verified these results using the master branch of Keras and KerasNLP.
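A sketch of what such a quantization script could look like (assumptions: Keras 3's `Model.quantize("int8")` API that this PR enables for Llama, KerasNLP's `save_to_preset` helper, and hypothetical local output directories named after the presets; this is not the exact script used):

```python
# Sketch of a preset-quantization script (see assumptions above).
import keras
import keras_nlp

keras.config.set_floatx("bfloat16")

for preset in (
    "llama2_7b_en",
    "llama2_instruct_7b_en",
    "llama3_8b_en",
    "llama3_instruct_8b_en",
):
    causal_lm = keras_nlp.models.CausalLM.from_preset(preset)
    # Apply dynamic int8 quantization to the loaded float weights.
    causal_lm.quantize("int8")
    # Write the quantized task out as a new local preset directory,
    # e.g. ./llama2_7b_en_int8, ready for upload to Kaggle.
    causal_lm.save_to_preset(f"{preset}_int8")
```

Checking a script like this in would make it easy to regenerate the int8 presets whenever the quantization implementation changes.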