
Dynamic int8 quantization for Llama2 and Llama3 #1720

Merged: 1 commit merged into keras-team:master from add-quantized-llama on Aug 1, 2024

Conversation

@james77777778 (Collaborator) commented on Jul 31, 2024:

I think it would be better to add the quantization scripts to KerasNLP.

| Model | Validation Outputs | File Size |
| --- | --- | --- |
| llama2_7b_en_int8 | What is Keras?\n Hinweis: Diese Seite wurde zuletzt am 27. April 2017 aktualisiert. | 6.3G |
| llama2_instruct_7b_en_int8 | What is Keras?\n Unterscheidung between Keras und TensorFlow\n Keras is a high-level neural networks API written in Python, capable of | 6.3G |
| llama3_8b_en_int8 | What is Keras? Keras is a high-level neural networks API, written in Python and capable of running on top of either TensorFlow , an industry-st | 7.5G |
| llama3_instruct_8b_en_int8 | What is Keras? - A High-Level Neural Networks API\n\nKeras is a high-level neural networks API that is written in Python and capable of running | 7.5G |

llama2_7b_en_int8 produced some unusual outputs, but based on my past experience, it should generate meaningful results using CUDA and XLA.

Please let me know if these models are ready to be pushed to Kaggle.

cc @mattdangerw

EDITED:
Note that I have only verified these results using the master branch of Keras and KerasNLP.
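
For illustration, a minimal sketch of what such a conversion script might look like, assuming the Keras 3 model.quantize("int8") API and KerasNLP's save_to_preset; the preset names and output path below are placeholders, not necessarily the exact script to be checked in:

# Hypothetical conversion sketch; adjust the preset name per model.
import keras
import keras_nlp

keras.config.set_floatx("bfloat16")  # load the original weights in bfloat16

causal_lm = keras_nlp.models.CausalLM.from_preset("llama2_7b_en")
causal_lm.quantize("int8")  # dynamic int8 quantization via Keras 3
causal_lm.save_to_preset("llama2_7b_en_int8")  # write the quantized preset to disk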

@mattdangerw (Member) commented:

> llama2_7b_en_int8 produced some unusual outputs, but based on my past experience, it should generate meaningful results using CUDA and XLA.

Do you see noticeably worse results on CPU? Is it backend specific? Not that it would block this, but good to know.

@mattdangerw (Member) commented:

> I think it would be better to add the quantization scripts to KerasNLP.

@james77777778 can you say what you mean here? Do you just mean check them in? If so, agreed; it will be nice to have.

@james77777778 (Collaborator, Author) commented:

> Do you see noticeably worse results on CPU?

Yes. When using the CPU with a small max_length, it often produces bad results (I only saw this with llama2_7b_en).
BTW, the outputs of Llama seem quite random.

> Is it backend specific?

I have tested with JAX and TensorFlow and verified that the bad results are backend-agnostic (a reproduction sketch follows the environment list below).

However, with CUDA, the results are more reasonable:

JAX:

What is Keras?
 nobody knows.

It's a neural network framework written in python.

### Installation

Install

TensorFlow:

What is Keras?
 sierp 18, 2018
Keras is a powerful and easy-to-use library for building

Env:

  • AMD R9 7900
  • RTX 4070
  • tensorflow 2.16.1
  • jax 0.4.27
  • bfloat16
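
A minimal reproduction sketch of these CPU/backend checks, assuming the int8 presets are loadable under the names listed above (or from a local preset directory); note that KERAS_BACKEND must be set before importing keras:

# Hypothetical reproduction sketch; not part of this PR.
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow"
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide the GPU to force a CPU run

import keras
import keras_nlp

keras.config.set_floatx("bfloat16")
causal_lm = keras_nlp.models.CausalLM.from_preset("llama2_7b_en_int8")
print(causal_lm.generate(["What is Keras?"], max_length=32)[0])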

> Do you just mean check them in?

Yes! That's what I mean.

@mattdangerw (Member) commented:

@james77777778 did the output of llama2_7b_en look random pre-quantization, or only after quantization?

Probably worth looking into this; we might have a bug somewhere.

@mattdangerw (Member) commented:

Anyway, this PR looks good to me; I'll pull it in!

mattdangerw merged commit 30c480c into keras-team:master on Aug 1, 2024. 7 checks passed.
@james77777778 (Collaborator, Author) commented:

> @james77777778 did the output of llama2_7b_en look random pre-quantization, or only after quantization?

I only checked the quantized model. I will take a look at the outputs using bfloat16 on CPU and GPU.

james77777778 deleted the add-quantized-llama branch on August 2, 2024 at 02:30.
@james77777778 (Collaborator, Author) commented:

@mattdangerw

I ran llama2_7b_en and llama3_8b_en three times each using JAX on the CPU. The results seem meaningful but random.

import keras
import keras_nlp

# Generate from the un-quantized presets in bfloat16.
keras.config.set_floatx("bfloat16")
input_str = "What is Keras?"
length = 32
# llama2_7b_en, llama3_8b_en
causal_lm = keras_nlp.models.CausalLM.from_preset("llama3_8b_en")
keras_output = causal_lm.generate([input_str], max_length=length)
keras_output = keras_output[0]
print("🔶 KerasNLP output:", keras_output)

Outputs:

| Backend | Model | Index | Outputs |
| --- | --- | --- | --- |
| JAX | llama2_7b_en | 0 | Keras is the deep learning library for python. It is built on the top of the TensorFlow library and is very |
| JAX | llama2_7b_en | 1 | Keras is an open source Python deep learning library. It is used for building and training neural networks. The library is |
| JAX | llama2_7b_en | 2 | Keras is an open source neural network API for Python written by François Chollet. Keras is used for deep |
| JAX | llama3_8b_en | 0 | Keras is a high-level neural network API written in Python and capable of running on top of TensorFlow, CNTK, or |
| JAX | llama3_8b_en | 1 | A Beginner’s Introduction Keras is an open-source neural network library that can be run on top of Tensorflow, CNT |
| JAX | llama3_8b_en | 2 | Keras is a powerful and easy-to-use deep learning library for Theano and TensorFlow that provides a high-level neural networks API to |
