
Add quantization support for Gemma, Gemma2 and PaliGemma #1670

Merged
merged 7 commits into keras-team:master on Jul 3, 2024

Conversation

@james77777778 (Collaborator) commented on Jun 22, 2024

We will need a new release of Keras for this. Currently, I have built the PR based on the master branch of Keras.

The implementation is simple and clean after introducing DTypePolicyMap and some other fixes.
Thanks to @fchollet and @mattdangerw for their help.

It is worth noting that float8 training & inference are also supported in this PR. You can check test_quantize for this.
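For illustration, a minimal float8 sketch (not taken from this PR's tests; it just reuses the preset and prompt from the script below):

import keras

import keras_nlp

keras.config.set_dtype_policy("bfloat16")
model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
# `quantize` accepts "int8" or "float8"; float8 keeps training supported,
# so the model can still be fine-tuned afterwards.
model.quantize("float8")
print(model.generate("What is Keras3?", max_length=128))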

Some numbers:

| Model | Memory Usage (bfloat16) | Memory Usage (int8) | Weights (kagglehub) | Weights (int8) | Note |
|---|---|---|---|---|---|
| "gemma_1.1_instruct_2b_en" | 5.69GB | 2.82GB | 4.7GB | 2.4GB | |
| "gemma2_instruct_9b_en" | 20.93GB | 10.14GB | 18GB | 8.7GB | Measured on CPU |
| "pali_gemma_3b_mix_224" | 6.52GB | 3.22GB | 5.5GB | 2.8GB | |

Script:

int8_gemma.py
import argparse
import os
import pathlib
import time
import typing

import keras
import psutil
import tensorflow as tf

import keras_nlp

# Setup kaggle information
os.environ["KAGGLE_USERNAME"] = "xxx"
os.environ["KAGGLE_KEY"] = "xxx"


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="pali_gemma_3b_mix_224",
        choices=[
            "gemma_1.1_instruct_2b_en",
            "pali_gemma_3b_mix_224",
            "gemma2_instruct_9b_en",
        ],
        help="Which model to demonstrate",
    )
    parser.add_argument(
        "--path",
        default=".",
        help="Path to save and load the model",
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="Quantize and save the model",
    )
    args = parser.parse_args()
    return args


def get_memory_usage():
    # From CPU or GPU:0
    try:
        memory_stats = tf.config.experimental.get_memory_info("GPU:0")
        peak_usage = memory_stats["peak"] / (2**30)
    except Exception:
        memory_usage = psutil.Process().memory_info().rss
        peak_usage = memory_usage / (2**30)
    return peak_usage


def benchmark_pali_gemma(
    model: keras_nlp.models.PaliGemmaCausalLM, image, prompt: str
):
    # Warmup
    model.generate({"images": image, "prompts": prompt}, max_length=128)

    # Benchmark
    st = time.time()
    result = model.generate(
        {"images": image, "prompts": prompt}, max_length=128
    )
    ed = time.time()
    return result, ed - st


def benchmark_gemma(model: keras_nlp.models.GemmaCausalLM, prompt: str):
    # Warmup
    model.generate(prompt, max_length=128)

    # Benchmark
    st = time.time()
    result = model.generate(prompt, max_length=128)
    ed = time.time()
    return result, ed - st


def save_int8_model(
    preset: str,
    model: typing.Union[
        keras_nlp.models.GemmaCausalLM,
        keras_nlp.models.PaliGemmaCausalLM,
    ],
):
    model.quantize("int8")
    model.summary()
    model.save(f"{preset}_int8.keras")


def load(model_path: pathlib.Path):
    model = keras.saving.load_model(model_path)
    return model


if __name__ == "__main__":
    keras.config.set_dtype_policy("bfloat16")
    x = keras.ops.ones([1]) * keras.ops.ones([1])  # Trigger TF dummy logs

    args = get_args()
    path = pathlib.Path(args.path)
    is_pali_gemma = "pali_gemma" in str(args.model)
    print(f"Peak memory usage (init): {get_memory_usage():.3f} GB")

    # Save
    if args.save:
        if is_pali_gemma:
            model = keras_nlp.models.PaliGemmaCausalLM.from_preset(args.model)
        else:
            model = keras_nlp.models.GemmaCausalLM.from_preset(args.model)
        model.summary()
        print(
            "Peak memory usage (loaded float model): "
            f"{get_memory_usage():.3f} GB"
        )
        save_int8_model(args.model, model)
    # Load
    else:
        model_path = path / f"{args.model}_int8.keras"
        model = load(model_path)
        print(
            "Peak memory usage (loaded int8 model): "
            f"{get_memory_usage():.3f} GB"
        )

        if is_pali_gemma:
            image_path = keras.utils.get_file(
                "cow_beach_1.png",
                "https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png",
            )
            image = keras.utils.load_img(image_path)
            image = keras.utils.img_to_array(image, "channels_last")
            prompt = "describe en\n"
            result, elapsed_time = benchmark_pali_gemma(model, image, prompt)
        else:
            prompt = "What is Keras3?"
            result, elapsed_time = benchmark_gemma(model, prompt)
        print(result)
        print(
            f"The elapsed time for model inference: {elapsed_time:.3f} seconds"
        )

Usage:

# Get quantized model
python int8_gemma.py --model "gemma_1.1_instruct_2b_en" --save
python int8_gemma.py --model "gemma2_instruct_9b_en" --save
python int8_gemma.py --model "pali_gemma_3b_mix_224" --save
# Run
python int8_gemma.py --model "gemma_1.1_instruct_2b_en"
python int8_gemma.py --model "gemma2_instruct_9b_en"
python int8_gemma.py --model "pali_gemma_3b_mix_224"

Outputs:

# Gemma
What is Keras3?

Keras3 is a high-level neural network library built on top of Keras 2. It provides a simplified and more efficient way to build and train deep learning models.

**Key features of Keras3:**

- Simplified API with Keras 2 compatibility
- High-level abstractions for common tasks
- Improved performance and efficiency
- Support for modern neural network architectures


**Benefits of using Keras3:**

- Easier to learn and use
- Faster and more accurate models
- Reduced development time
- Improved portability across different hardware platforms


**How to use Keras3:**

- Import

# PaliGemma
describe en
In this image I can see a cow which affor is in brown color and white color. I can see the sand. In the background I can see the water and the sky.

@github-actions bot added the Gemma (Gemma model specific issues) label on Jun 22, 2024
@james77777778 james77777778 changed the title [WIP] Add quantization support for Gemma Add quantization support for Gemma and PaliGemma Jun 25, 2024
@james77777778 james77777778 marked this pull request as ready for review June 25, 2024 06:50
@james77777778 (Collaborator, Author)

This PR should be ready for review.
Both Gemma and PaliGemma now support quantization (int8 and float8).

@james77777778 james77777778 changed the title Add quantization support for Gemma and PaliGemma Add quantization support for Gemma, Gemma2 and PaliGemma Jun 28, 2024
@james77777778 (Collaborator, Author)

Hi @fchollet @mattdangerw
I have added quantization support for Gemma2 (actually, adding tests is sufficient :) )
Please let me know if any updates are needed.

@mattdangerw (Member)

@james77777778 thanks so much! Sorry for the delay, I was out last week, but just got back in town. Will take a look tomorrow!

@james77777778 (Collaborator, Author)

No hurry. Please take your time.

@mattdangerw mattdangerw self-requested a review July 1, 2024 22:24
@mattdangerw (Member) left a comment:

Looks good!

General comments.

  • Let's try to make the contract between the ReversibleEmbedding layer and the Embedding layer as minimal as possible. Any private functionality might change in core Keras, and we are using a lot of it here (which is fine, let's just reduce it if we can).
  • Let's test this on all models if we can.

Review thread on keras_nlp/src/layers/modeling/reversible_embedding.py (outdated, resolved):

return super()._int8_call(inputs)

def quantize(self, mode):
@mattdangerw (Member):

Could we chain to super here to keep most of the logic, and just handle the if mode == "int8" and not self.tie_weights case below? It would be great to keep as much logic on the super class as we can.

@james77777778 (Collaborator, Author) commented on Jul 3, 2024:

I'm afraid not.
The raising of NotImplementedError in keras.layers.Embedding is intentional and unavoidable. The idea is to prevent undefined behavior when users call Model.quantize.

I can introduce an argument like type_check=True in keras.layers.Embedding to support calling super() in the future.
However, for now, we can only implement quantize from scratch.

EDITED:
keras-team/keras#19949
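A rough sketch of that type_check idea (hypothetical code, not the actual ReversibleEmbedding implementation; it assumes an Embedding.quantize that accepts a type_check argument, which is what the linked Keras PR proposes):

import keras

class TiedEmbeddingSketch(keras.layers.Embedding):
    """Hypothetical subclass showing how quantize() could chain to super."""

    def __init__(self, input_dim, output_dim, tie_weights=True, **kwargs):
        super().__init__(input_dim, output_dim, **kwargs)
        self.tie_weights = tie_weights

    def quantize(self, mode, type_check=False):
        if mode == "int8" and not self.tie_weights:
            # Only the untied "reverse" kernel would need custom handling
            # here; everything else can live on the parent class.
            pass
        # With type_check disabled, keras.layers.Embedding would no longer
        # reject subclasses with NotImplementedError, so the shared int8
        # logic could be reused instead of re-implemented.
        super().quantize(mode, type_check=type_check)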

@mattdangerw (Member):

I see, thanks for the explainer.

Not to solve in this PR, but I wonder if we can make the contract between Keras and downstream here more public and minimal. I see _int8_call(), _int8_build(), _quantization_mode_error(), _tracker, and _untrack_variable() all used here. That's a pretty significant level of private usage, which could easily break.

Separate question: will this work with older versions of Keras 3? Or are there small changes we could make so we don't break older versions?

@james77777778 (Collaborator, Author) commented on Jul 4, 2024:

I agree that these methods are too verbose for downstream projects. I will try to simplify the contract in the future, but I don't currently have a good idea for it.

will this work with older versions of Keras 3?

I haven't checked the compatibility. My rough guess is that users will need keras>=3.4.0 due to the introduction of DTypePolicyMap.
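A small guard that downstream code could use (illustrative only; the keras>=3.4.0 floor is just the guess above, not a verified minimum):

import keras

# Illustrative compatibility check; assumes DTypePolicyMap is the only
# new dependency and that it ships with keras>=3.4.0 (a guess).
if not hasattr(keras.dtype_policies, "DTypePolicyMap"):
    raise ImportError(
        "Quantized presets require a Keras release that provides "
        "keras.dtype_policies.DTypePolicyMap (expected in keras>=3.4.0)."
    )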

"name": self.name,
"trainable": self.trainable,
}

# Add quantization support by utilizing `DTypePolicyMap`
@mattdangerw (Member):

This is great! This should buy us support for all models, right? If possible, we should consider extending our common backbone tests for this...

https://github.com/keras-team/keras-nlp/blob/e4f09b24c699857edae27c8054aab44078e9cbd5/keras_nlp/src/tests/test_case.py#L359-L367

https://github.com/keras-team/keras-nlp/blob/e4f09b24c699857edae27c8054aab44078e9cbd5/keras_nlp/src/models/gemma/gemma_backbone_test.py#L39-L45

Doing so would test quantization for the whole library. It seems like it should be doable: call quantize, assert on the output. WDYT?

If we run into failures for certain models, we could add an option to run_backbone_test, called run_quantization_check=True, and set the option to false if the model fails, with a TODO to investigate.
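For reference, a minimal sketch of what such a check could look like (the name and assertion here are illustrative, not the actual test harness code; see the reply below for what was actually added):

import keras

def run_quantization_check(backbone_cls, init_kwargs, input_data):
    # Illustrative only: build the backbone, quantize it in place,
    # and make sure a forward pass still produces finite outputs.
    model = backbone_cls(**init_kwargs)
    model.quantize("int8")
    output = model(input_data)
    assert bool(keras.ops.all(keras.ops.isfinite(output)))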

@james77777778 (Collaborator, Author):

Yeah, it is doable.
I have added run_quantization_test to run_backbone_test. Only Bloom and OPT failed the test.
However, there is a significant speed regression after adding this test. The CI time increased from ~19mins to ~27mins. Is this acceptable?

@mattdangerw (Member):

I think having the coverage is important. Let's pull this in, and see if we can improve the runtime efficiency as a follow up.

Saving is slow, so maybe we can do something like:

  • Basic quantization tests do not hit saving: just test get_config() and from_config(), maybe assigning the weights over (see the sketch below).
  • Separate quantization testing in our saving test harness, which is marked as large and only run on larger/faster hardware.
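A rough sketch of the first bullet (illustrative; it assumes model is an already-quantized backbone and that from_config() rebuilds the same variable structure):

# Round-trip through config instead of saving to disk, then copy the
# weights over; no .keras file is written.
config = model.get_config()
revived = model.__class__.from_config(config)
revived.set_weights(model.get_weights())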

@james77777778 (Collaborator, Author):

Will try this in another PR.

Review thread on keras_nlp/src/layers/modeling/reversible_embedding.py (outdated, resolved).
@james77777778 force-pushed the quantization-support branch from 6f276c7 to a71f83b on July 3, 2024 06:58
@mattdangerw added the kokoro:force-run (Runs Tests on GPU) label on Jul 3, 2024
@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Jul 3, 2024
@mattdangerw (Member)

@james77777778 thanks for the changes!

As soon as testing is all green I will pull this in, especially since the US is about to go on holiday until next Monday.

I think the coverage is worth it, but let's keep looking for ways to speed up this testing while keeping decent coverage as a follow-up.

@mattdangerw mattdangerw merged commit bb423c8 into keras-team:master Jul 3, 2024
8 checks passed
@james77777778 james77777778 deleted the quantization-support branch July 4, 2024 01:54