From d57704c50c34aabef2458260636278e331383d66 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Mon, 13 Jan 2025 10:54:12 -0500
Subject: [PATCH] Update QAT READMEs using new APIs (#1541)

* Add convert path for quantize_ QAT API

Summary: https://github.com/pytorch/ao/pull/1415 added a quantize_ QAT API for the
prepare path. This commit adds the remaining convert path for users to actually
perform end-to-end QAT using the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)
quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]

* Update QAT READMEs using new APIs

Add references to new QAT APIs including `quantize_`, `FakeQuantizedX`, and the new
embedding Quantizers and ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

[ghstack-poisoned]
---
 README.md                          |  35 ++++--
 torchao/quantization/qat/README.md | 189 ++++++++++++++++++++++-------
 2 files changed, 167 insertions(+), 57 deletions(-)

diff --git a/README.md b/README.md
index 6ba0e3be4c..0da273f91c 100644
--- a/README.md
+++ b/README.md
@@ -54,27 +54,38 @@ We've added kv cache quantization and other features in order to enable long con
 In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
 
+## Training
+
 ### Quantization Aware Training
 
-Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
+Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md).
 
 ```python
-from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
-
-qat_quantizer = Int8DynActInt4WeightQATQuantizer()
+from torchao.quantization import (
+    quantize_,
+    int8_dynamic_activation_int4_weight,
+)
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
+    intx_quantization_aware_training,
+)
 
-# Insert "fake quantize" operations into linear layers.
-# These operations simulate quantization numerics
-model = qat_quantizer.prepare(model)
+# Insert fake quantization
+activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+quantize_(
+    my_model,
+    intx_quantization_aware_training(activation_config, weight_config),
+)
 
-# Run Training...
+# Run training... (not shown)
 
-# Convert fake quantize to actual quantize operations
-model = qat_quantizer.convert(model)
+# Convert fake quantization to actual quantized operations
+quantize_(my_model, from_intx_quantization_aware_training())
+quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
 ```
 
-## Training
-
 ### Float8
 
 [torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md
index 6ecccd2b18..813b628af7 100644
--- a/torchao/quantization/qat/README.md
+++ b/torchao/quantization/qat/README.md
@@ -19,12 +19,6 @@ x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
 x_fq = (x_fq - zp) * scale
 ```
 
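+As a concrete illustration, here is the same math on example numbers. This is only a
+sketch of one common way to choose `scale` and `zp` for asymmetric int8 quantization,
+not a torchao API:
+
+```python
+import torch
+
+x_float = torch.randn(4, 8)
+qmin, qmax = -128, 127  # int8 range
+
+# one common asymmetric choice of scale and zero point
+scale = (x_float.max() - x_float.min()) / (qmax - qmin)
+zp = (qmin - x_float.min() / scale).round()
+
+# fake quantize: values are snapped to the int8 grid but stay in float
+x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
+x_fq = (x_fq - zp) * scale
+```
+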
-## API
-
-torchao currently supports two QAT schemes for linear layers:
-- int8 per token dynamic activations + int4 per group weights
-- int4 per group weights (using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training)
-
 QAT typically involves applying a transformation to your model before and after training. In torchao,
 these are represented as the prepare and convert steps: (1) prepare inserts fake quantize
 operations into linear layers, and (2) convert transforms the fake quantize
@@ -34,64 +28,169 @@ Between these two steps, training can proceed exactly as before.
 
 ![qat](images/qat_diagram.png)
 
-To use QAT in torchao, apply the prepare step using the appropriate Quantizer before
-training, then apply the convert step after training for inference or generation.
-For example, on a single GPU:
+
+## torchao APIs
+
+torchao currently supports two QAT APIs, one through the [`quantize_`](https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_)
+API (recommended) and one through the Quantizer classes (legacy). The `quantize_` API
+allows flexible configuration of quantization settings for both activations and weights,
+while the Quantizer classes each hardcode a specific quantization setting.
+
+For example, running QAT on a single GPU:
 
 ```python
 import torch
 from torchtune.models.llama3 import llama3
+
+# Set up a smaller version of llama3 to fit in a single GPU
+def get_model():
+    return llama3(
+        vocab_size=4096,
+        num_layers=16,
+        num_heads=16,
+        num_kv_heads=4,
+        embed_dim=2048,
+        max_seq_len=2048,
+    ).cuda()
+
+# Example training loop
+def train_loop(m: torch.nn.Module):
+    optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
+    loss_fn = torch.nn.CrossEntropyLoss()
+    for i in range(10):
+        example = torch.randint(0, 4096, (2, 16)).cuda()
+        target = torch.randn((2, 16, 4096)).cuda()
+        output = m(example)
+        loss = loss_fn(output, target)
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+```
+
+### quantize_ API (recommended)
+
+The recommended way to run QAT in torchao is through the `quantize_` API:
+1. **Prepare:** specify how weights and/or activations are to be quantized through
+[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and pass these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
+2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
+functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
+
+For example:
+
+```python
+from torchao.quantization import (
+    quantize_,
+    int8_dynamic_activation_int4_weight,
+)
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
+    intx_quantization_aware_training,
+)
+model = get_model()
+
+# prepare: insert fake quantization ops
+# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
+activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+quantize_(
+    model,
+    intx_quantization_aware_training(activation_config, weight_config),
+)
+
+# train
+train_loop(model)
+
+# convert: transform fake quantization ops into actual quantized ops
+# swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
+# quantized activation and weight tensor subclasses
+quantize_(model, from_intx_quantization_aware_training())
+quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
+
+# inference or generate
+```
+
+To fake quantize embeddings in addition to linear layers, you can additionally call
+the following with a filter function during the prepare step:
+
+```
+quantize_(
+    model,
+    intx_quantization_aware_training(weight_config=weight_config),
+    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+)
+```
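+
+If you want to sanity check the prepare step, you can inspect the swapped modules before
+training. This is only an illustrative sketch; the fake quantized module classes live
+under `torchao/quantization/qat/`, and the exact import paths may differ across torchao
+versions:
+
+```python
+from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
+from torchao.quantization.qat.linear import FakeQuantizedLinear
+
+# after the prepare step, linear (and optionally embedding) modules should have been
+# swapped to their fake quantized counterparts
+num_fq_linear = sum(isinstance(mod, FakeQuantizedLinear) for mod in model.modules())
+num_fq_embedding = sum(isinstance(mod, FakeQuantizedEmbedding) for mod in model.modules())
+print(f"fake quantized linears: {num_fq_linear}, fake quantized embeddings: {num_fq_embedding}")
+```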
+
+
+### Quantizer API (legacy)
+
+Alternatively, torchao provides a few hardcoded quantization settings through
+the following Quantizers:
+- [Int8DynActInt4WeightQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L126) (linear), targeting int8 per-token dynamic asymmetric activation + int4 per-group symmetric weight
+- [Int4WeightOnlyQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/linear.py#L308) (linear), targeting int4 per-group asymmetric weight using the efficient [int4 tinygemm kernel](https://github.com/pytorch/pytorch/blob/a672f6c84e318bbf455f13dfdd3fd7c68a388bf5/aten/src/ATen/native/cuda/int4mm.cu#L1097) after training
+- [Int4WeightOnlyEmbeddingQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/embedding.py#L94) (embedding), targeting int4 per-group symmetric weight
+
+For example:
+```python
 from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
+model = get_model()
 
-# Smaller version of llama3 to fit in a single GPU
-model = llama3(
-    vocab_size=4096,
-    num_layers=16,
-    num_heads=16,
-    num_kv_heads=4,
-    embed_dim=2048,
-    max_seq_len=2048,
-).cuda()
-
-# Quantizer for int8 dynamic per token activations +
-# int4 grouped per channel weights, only for linear layers
-qat_quantizer = Int8DynActInt4WeightQATQuantizer()
-
-# Insert "fake quantize" operations into linear layers.
-# These operations simulate quantization numerics during
-# training without performing any dtype casting
+# prepare: insert fake quantization ops
+# swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
 model = qat_quantizer.prepare(model)
 
-# Standard training loop
-optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
-loss_fn = torch.nn.CrossEntropyLoss()
-for i in range(10):
-    example = torch.randint(0, 4096, (2, 16)).cuda()
-    target = torch.randn((2, 16, 4096)).cuda()
-    output = model(example)
-    loss = loss_fn(output, target)
-    loss.backward()
-    optimizer.step()
-    optimizer.zero_grad()
-
-# Convert fake quantize to actual quantize operations
-# The quantized model has the exact same structure as the
-# quantized model produced in the corresponding PTQ flow
-# through `Int8DynActInt4WeightQuantizer`
+# train
+train_loop(model)
+
+# convert: transform fake quantization ops into actual quantized ops
+# swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
 model = qat_quantizer.convert(model)
 
 # inference or generate
 ```
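+
+The other Quantizers follow the same prepare, train, convert flow. For instance, here is
+a sketch of the int4 weight-only path. The `groupsize` argument shown below follows the
+torchao 0.7 signature, so double check it against the class linked above for your
+installed version:
+
+```python
+from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
+
+# targets int4 per-group weight-only quantization; after convert, inference is
+# backed by the int4 tinygemm kernel, which typically runs with bf16 models
+qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=32)
+model = get_model()
+
+model = qat_quantizer.prepare(model)
+train_loop(model)
+model = qat_quantizer.convert(model)
+
+# inference or generate
+```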
 
-Users can also leverage our integration with [torchtune](https://github.com/pytorch/torchtune)
-and apply quantized-aware fine-tuning as follows:
+To use multiple Quantizers in the same model for different layer types,
+users can also leverage the [ComposableQATQuantizer](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L242)
+as follows:
+
+```python
+from torchao.quantization.qat import (
+    ComposableQATQuantizer,
+    Int4WeightOnlyEmbeddingQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
+)
+
+quantizer = ComposableQATQuantizer([
+    Int8DynActInt4WeightQATQuantizer(groupsize=32),
+    Int4WeightOnlyEmbeddingQATQuantizer(group_size=32),
+])
+
+# prepare + train + convert as before
+model = quantizer.prepare(model)
+train_loop(model)
+model = quantizer.convert(model)
+```
+
+## torchtune integration
+
+torchao QAT is integrated with [torchtune](https://github.com/pytorch/torchtune)
+to allow users to run quantization-aware fine-tuning as follows:
 
 ```
 tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full
 ```
 
-For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
+torchtune also supports a [QAT + LoRA distributed training recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py)
+that is 1.89x faster and uses 36.1% less memory compared to vanilla QAT in our early experiments.
+You can read more about it [here](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700):
+```
+tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora
+```
+
+For more detail, please refer to [this QAT tutorial](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html).
 
 ## Evaluation Results