From af83deb4dd02600fd4d2a66096906f81881769f0 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Fri, 19 Jul 2024 13:56:36 -0700 Subject: [PATCH 01/12] milestone1: naive_intNwo + eval/benchmark --- .../run_layer_sensitive_study.sh | 22 + .../prototype/mixed_precision/run_mp_quant.sh | 30 ++ .../mixed_precision/run_mp_quant_benchmark.sh | 9 + .../mixed_precision/run_sensi_linear_type.sh | 26 ++ .../mixed_precision/run_uni_quant.sh | 39 ++ .../run_uni_quant_benchmark.sh | 10 + .../mixed_precision/scripts/generate.py | 390 ++++++++++++++++++ .../mixed_precision/scripts/mp_quant_eval.py | 125 ++++++ .../mixed_precision/scripts/naive_intNwo.py | 40 ++ .../scripts/quant_model_size.py | 35 ++ .../scripts/sensitivity_study.py | 95 +++++ .../scripts/test_naive_intNwo.py | 27 ++ 12 files changed, 848 insertions(+) create mode 100755 torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh create mode 100755 torchao/quantization/prototype/mixed_precision/run_mp_quant.sh create mode 100755 torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh create mode 100755 torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh create mode 100755 torchao/quantization/prototype/mixed_precision/run_uni_quant.sh create mode 100755 torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/generate.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py diff --git a/torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh b/torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh new file mode 100755 index 0000000000..b5bc0f62d3 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +GPUS=(0 1 2 3 4 5) + +CONFIGS1=("2" "3" "4" "5" "6" "8") + +PYTHON_SCRIPT="scripts/sensitivity_study.py" + +for LAYER in {0..31}; do + for i in "${!GPUS[@]}"; do + GPU="${GPUS[$i]}" + CONFIG1="${CONFIGS1[$i]}" + + LOG_FILE="Sensi_${LAYER}_${CONFIG1}.txt" + + CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG1 --layer=$LAYER &>"$LOG_FILE" & + done + + wait +done + +echo "All processes are complete." 
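The sweep above quantizes one transformer block at a time at each candidate bit width and logs a separate eval per (layer, bit) pair. The per-layer selection that sensitivity_study.py applies for a given --layer boils down to a module filter like the following minimal sketch (assumptions: Hugging Face Llama module naming of the form "model.layers.<idx>....", and the N-bit weight-only config (intN_weight_only) defined in the scripts added below; this is an illustration, not the exact code in sensitivity_study.py):

import torch.nn as nn
from torchao.quantization import quantize_
# intN_weight_only: the N-bit weight-only config defined in these scripts

def single_layer_filter(layer_idx: int):
    # Match only nn.Linear modules whose fully qualified name contains ".<layer_idx>."
    tag = f".{layer_idx}."
    def filter_fn(child: nn.Module, cur_fqn: str) -> bool:
        return isinstance(child, nn.Linear) and tag in cur_fqn
    return filter_fn

# e.g. quantize only block 7 to 3 bits, then evaluate the perplexity delta:
# quantize_(model, intN_weight_only(n=3, group_size=32), single_layer_filter(7))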
diff --git a/torchao/quantization/prototype/mixed_precision/run_mp_quant.sh b/torchao/quantization/prototype/mixed_precision/run_mp_quant.sh new file mode 100755 index 0000000000..263e0697d2 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/run_mp_quant.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +# List of GPUs to use +GPUS=(0 1 2 3 4 5) + +# List of configuration files +CONFIGS1=(8 8 8 8 8 16) +CONFIGS2=(6 5 4 3 2 8) + +#CONFIGS1=(16 16 16 16 16) +#CONFIGS2=(6 5 4 3 2) + +#CONFIGS1=(5 5 5 6 6 6 6 3) +#CONFIGS2=(4 3 2 5 4 3 2 2) + +PYTHON_SCRIPT="scripts/mp_quant_eval.py" + +for i in "${!GPUS[@]}"; do + GPU="${GPUS[$i]}" + CONFIG1="${CONFIGS1[$i]}" + CONFIG2="${CONFIGS2[$i]}" + + LOG_FILE="MP_${CONFIG1}_${CONFIG2}.txt" + + CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=MP_llama3 --sensi_bit=$CONFIG1 --non_sensi_bit=$CONFIG2 &>"$LOG_FILE" & +done + +wait + +echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh b/torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh new file mode 100755 index 0000000000..653eac112f --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +PYTHON_SCRIPT="scripts/generate.py" +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=MP_llama3 --sensi_bit=5 --non_sensi_bit=4 --write_result mp_quant_benchmark_results.txt +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=MP_llama3 --sensi_bit=4 --non_sensi_bit=3 --write_result mp_quant_benchmark_results.txt +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=MP_llama3 --sensi_bit=5 --non_sensi_bit=3 --write_result mp_quant_benchmark_results.txt + + +echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh b/torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh new file mode 100755 index 0000000000..e7a6845aa4 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +CONFIGS1=("q_proj" "k_proj" "v_proj" "o_proj" "gate_proj" "up_proj" "down_proj") +CONFIGS2=("2" "3" "4" "5" "6" "8") + +PYTHON_SCRIPT="scripts/sensitivity_study.py" + +GPUS=(0 1 2 3 4 5) + +for i in "${!CONFIGS1[@]}"; do + CONFIG1="${CONFIGS1[$i]}" + + for j in "${!CONFIGS2[@]}"; do + CONFIG2="${CONFIGS2[$j]}" + GPU="${GPUS[$j]}" + + LOG_FILE="Sensi_skipsensi_${CONFIG1}_${CONFIG2}.txt" + + CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG2 --linear_type=$CONFIG1 &>"$LOG_FILE" & + done + + wait +done + + +echo "All processes are complete." 
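run_mp_quant.sh above sweeps (sensi_bit, non_sensi_bit) pairs: the sensitive layers (the first three and last two transformer blocks of Llama3-8B) keep a higher bit width while the remaining blocks are pushed lower. Conceptually this is two quantize_ passes with complementary filters, roughly as in the following minimal sketch (assumptions: Llama-style module names and the N-bit weight-only configs defined in scripts/naive_intNwo.py; scripts/mp_quant_eval.py and scripts/generate.py contain the authoritative version, including the int8/int4 special cases):

import torch.nn as nn
from torchao.quantization import quantize_
# intN_weight_only_asym: N-bit weight-only config defined in scripts/naive_intNwo.py

SENSITIVE_TAGS = ['.0.', '.1.', '.2.', '.30.', '.31.']  # first 3 and last 2 blocks

def filter_sensitive(child: nn.Module, cur_fqn: str) -> bool:
    return isinstance(child, nn.Linear) and any(tag in cur_fqn for tag in SENSITIVE_TAGS)

def filter_non_sensitive(child: nn.Module, cur_fqn: str) -> bool:
    return isinstance(child, nn.Linear) and not any(tag in cur_fqn for tag in SENSITIVE_TAGS)

# e.g. 8-bit for the sensitive blocks, 3-bit for everything else (one of the swept pairs):
# quantize_(model, intN_weight_only_asym(n=8, group_size=32), filter_sensitive)
# quantize_(model, intN_weight_only_asym(n=3, group_size=32), filter_non_sensitive)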
diff --git a/torchao/quantization/prototype/mixed_precision/run_uni_quant.sh b/torchao/quantization/prototype/mixed_precision/run_uni_quant.sh new file mode 100755 index 0000000000..13ac196ce8 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/run_uni_quant.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# List of GPUs to use +GPUS=(2 3 4 5 6 7) + +# List of configuration files +#CONFIGS1=("int6wo" "int5wo" "int2wo" "int3wo" "int8wo" "int4wo" "None" "autoquant") +CONFIGS1=("2" "3" "4" "5" "6" "8") +#CONFIGS1=(8 8 8 8 8 4 4 16) +#CONFIGS2=(6 5 4 3 2 3 2 8) + +#CONFIGS1=(16 16 16 16 16) +#CONFIGS2=(6 5 4 3 2) + +PYTHON_SCRIPT="scripts/mx_eval.py" + +for i in "${!GPUS[@]}"; do + GPU="${GPUS[$i]}" + CONFIG1="${CONFIGS1[$i]}" + + LOG_FILE="UNI_${CONFIG1}_SYM.txt" + + CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG1 --quant_sym=sym &>"$LOG_FILE" & +done + +wait + +for i in "${!GPUS[@]}"; do + GPU="${GPUS[$i]}" + CONFIG1="${CONFIGS1[$i]}" + + LOG_FILE="UNI_${CONFIG1}_ASYM.txt" + + CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG1 --quant_sym=asym &>"$LOG_FILE" & +done + +wait + +echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh b/torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh new file mode 100755 index 0000000000..6786857076 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh @@ -0,0 +1,10 @@ +#!/bin/bash +PYTHON_SCRIPT="scripts/generate.py" +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=int4wo --write_result uni_quant_benchmark_results.txt +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=int8wo --write_result uni_quant_benchmark_results.txt +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=2 --write_result uni_quant_benchmark_results.txt +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=3 --write_result uni_quant_benchmark_results.txt +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=5 --write_result uni_quant_benchmark_results.txt +python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=6 --write_result uni_quant_benchmark_results.txt + +echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/scripts/generate.py b/torchao/quantization/prototype/mixed_precision/scripts/generate.py new file mode 100644 index 0000000000..09f0b9613c --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/generate.py @@ -0,0 +1,390 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+import sys +import time +from pathlib import Path +from typing import Optional, Tuple +from datetime import datetime +import torch +import torchao +import torch._dynamo.config +import torch._inductor.config +from torchao.utils import get_model_size_in_bytes + +import torch.nn as nn + +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) + +from naive_intNwo import intN_weight_only_asym, intN_weight_only_sym + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet suppported") + +default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.tokenizer import get_tokenizer + +def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + +def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + interactive: bool, + callback = lambda x: x, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
+ """ + + # create an empty tensor of the expected final shape and fill in the current tokens + device = prompt.device + T = prompt.numel() + T_new = T + max_new_tokens + seq = torch.empty(T_new, dtype=prompt.dtype, device=device) + seq[:T] = prompt.view(-1) + + # setup model cache + max_seq_length = min(T_new, model.config.block_size) if not interactive else 350 + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # format model input + x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens) + + # execute prefill + next_token = prefill(model, x, input_pos, **sampling_kwargs).clone() + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) + + return seq + +def encode_tokens(tokenizer, string, bos=True, device=default_device): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + +def _load_model(checkpoint_path, device, precision): + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + + model = Transformer.from_name(checkpoint_path.parent.name) + model.load_state_dict(checkpoint, assign=True) + model = model.to(device=device, dtype=precision) + + return model.eval() + +B_INST, E_INST = "[INST]", "[/INST]" + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"), + quantization: Optional[str] = None, + compile: bool = True, + compile_prefill: bool = False, + profile: Optional[Path] = None, + device=default_device, + precision=torch.bfloat16, + write_result: Optional[Path] = None, + sensi_bit:Optional[int] = None, + non_sensi_bit:Optional[int] = None, + compile_mode:Optional[str] = "reduce-overhead", + group_size:Optional[int] = 32, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. 
+    """
+
+    torchao.quantization.utils.recommended_inductor_config_setter()
+
+    assert checkpoint_path.is_file(), checkpoint_path
+    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+    assert tokenizer_path.is_file(), str(tokenizer_path)
+
+    print(f"Using device={device}")
+    is_chat = "chat" in str(checkpoint_path)
+
+    print("Loading model ...")
+    t0 = time.time()
+    model = _load_model(checkpoint_path, device, precision)
+
+
+    device_sync(device=device) # MKG
+    print(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
+
+    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    prompt_length = encoded.size(0)
+
+    torch.manual_seed(1234)
+
+    if quantization:
+        from torchao.quantization.quant_api import (
+            quantize_,
+            int8_weight_only,
+            int8_dynamic_activation_int8_weight,
+            int4_weight_only,
+            autoquant,
+            unwrap_tensor_subclass
+        )
+        if "int8wo" in quantization:
+            quantize_(model, int8_weight_only())
+        if "int8dq" in quantization:
+            quantize_(model, int8_dynamic_activation_int8_weight())
+        if "int4wo" in quantization:
+            groupsize = int(quantization.split("-")[-1]) if "-" in quantization else group_size
+            assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
+            quantize_(model, int4_weight_only(group_size=groupsize))
+        if "autoquant" == quantization:
+            model = autoquant(model, manual=True)
+
+        if quantization in ["2","3","4","5","6","8"]:
+            quantize_(model.to(device=device), intN_weight_only_sym(n=int(quantization), group_size=group_size))
+
+        elif quantization == "MP_llama3":
+            # filter for sensitive layers
+            def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool:
+                return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])
+
+            # filter for non-sensitive layers
+            def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool:
+                return isinstance(child, nn.Linear) and not(any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.']))
+
+            #quantize the sensitive layers
+            if sensi_bit == 8:
+                quantize_(model.to(device=device), int8_weight_only(), filter_fn_sen)
+            elif sensi_bit == 4:
+                quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_sen)
+            elif sensi_bit in [6,5,3,2]:
+                quantize_(model.to(device=device), intN_weight_only_asym(n=sensi_bit, group_size=group_size), filter_fn_sen)
+
+            #quantize the less-sensitive layers
+            if non_sensi_bit == 8:
+                quantize_(model.to(device=device), int8_weight_only(), filter_fn_nonsen)
+            elif non_sensi_bit == 4:
+                quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_nonsen)
+            elif non_sensi_bit in [6,5,3,2]:
+                if sensi_bit == 4:
+                    quantize_(model, intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
+                else:
+                    quantize_(model.to(device=device), intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
+
+        else:
+            unwrap_tensor_subclass(model)
+
+    model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
+
+    if compile:
+        print("Compiling Model")
+        global decode_one_token, prefill
+        decode_one_token = torch.compile(decode_one_token, mode=compile_mode, fullgraph=True)
+
+        if compile_prefill:
+            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
+
+
+    aggregate_metrics = {
+        'tokens_per_sec': [],
+    }
+    start = -1 if compile else 0
+
+    for i in range(start, num_samples):
+        if i==0:
+            torch.cuda.reset_peak_memory_stats()
+
device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? ") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y = generate( + model, + encoded, + max_new_tokens, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + + if i == -1: + compile_time = time.perf_counter() - t0 + print(f"Compilation time: {compile_time} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + print(tokenizer.decode(y.tolist())) + else: + print() + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") + print("==========") + + tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + bandwidth = model_size * tokpersec + mem = torch.cuda.max_memory_reserved() /1e9 + print(f"Average tokens/sec: {tokpersec:.2f}") + print(f"Average Bandwidth: {bandwidth:.02f} GB/s") + print(f"Peak Memory Usage: {mem:.02f} GB") + print(f"Model Size: {model_size:.02f} GB") + if write_result: + result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " + result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device}, compilation time: {compile_time} " + result_txt += f"repro: python generate.py " + result_txt += f"--quantization {quantization} " if quantization else "" + result_txt += f"--sensi_bit {sensi_bit} " if sensi_bit else "" + result_txt += f"--non_sensi_bit {non_sensi_bit} " if non_sensi_bit else "" + result_txt += f"--checkpoint_path {checkpoint_path} " + result_txt += f"--device {device} " + result_txt += f"--precision {precision} " + result_txt += f"--compile " if compile else "" + result_txt += f"--compile_prefill " if compile_prefill else "" + result_txt += f"--profile {profile} " if profile else "" + result_txt += f"--interactive " if interactive else "" + result_txt += f"--num_samples {num_samples} " + result_txt += f"--max_new_tokens {max_new_tokens} " + result_txt += f"--top_k {top_k} " + result_txt += f"--temperature {temperature} " + f=open(write_result, "a") + f.write(result_txt) + f.close() + + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + 
parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') + parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') + parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"), help='Model checkpoint path.') + #parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') + parser.add_argument('-q', '--quantization', default = "None", help='Which quantization technique to apply') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + parser.add_argument('--profile', type=Path, default=None, help='Profile path.') + parser.add_argument('--device', type=str, default=default_device, help='Device to use') + parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') + parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result') + parser.add_argument('--sensi_bit', type=int, default=16, help='Bit setting for sensitive layers') + parser.add_argument('--non_sensi_bit', type=int, default=16, help='Bit setting for non-sensitive layers') + parser.add_argument('--compile_mode', type=str, default="max-autotune", help='max-autotune or reduce-overhead mode for torch.compile()') + parser.add_argument('--group_size', type=int, default=32, help='group size to perform quantization on') + args = parser.parse_args() + main( + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, + args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result, args.sensi_bit, args.non_sensi_bit, args.compile_mode, args.group_size + ) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py new file mode 100644 index 0000000000..0caee20dae --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py @@ -0,0 +1,125 @@ +import torch +import torch.nn as nn + +from naive_intNwo import intN_weight_only_asym, intN_weight_only_sym +from transformers import AutoModelForCausalLM, AutoTokenizer + +from lm_eval.models.huggingface import HFLM +from lm_eval.evaluator import evaluate +from lm_eval.tasks import get_task_dict + +from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight +from torchao._models._eval import TransformerEvalWrapper + +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) + +from torchao.quantization.quant_api import autoquant + + +torch._inductor.config.force_fuse_int_mm_with_mul = True +torch._inductor.config.fx_graph_cache = True + + +def run_evaluation(repo_id, tasks, limit, device, 
precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size): + + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision) + + if quantization == "int8dq": + quantize_(model.to(device=device), int8_dynamic_activation_int4_weight()) + + elif quantization == "int8wo": + quantize_(model.to(device=device), int8_weight_only()) + + elif quantization == "int4wo": + quantize_(model.to(device=device), int4_weight_only(group_size=group_size)) + + elif quantization == "autoquant": + model = autoquant(model.to(device=device)) + + # naive implementation of uniform precision quantization all layers + elif quantization in ["2","3","4","5","6","8"]: + if quant_sym == "asym": + quantize_(model.to(device=device), intN_weight_only_asym(n=int(quantization), group_size=group_size)) + elif quant_sym == "sym": + quantize_(model.to(device=device), intN_weight_only_sym(n=int(quantization), group_size=group_size)) + + elif quantization == "MP_llama3": + + # filter for sensitive layers + def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool: + return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.']) + + # filter for non-sensitive layers + def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool: + return isinstance(child, nn.Linear) and not(any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])) + + if sensi_bit != 16: + # quantize the sensitive layers + if sensi_bit == 8: + quantize_(model.to(device=device), int8_weight_only(), filter_fn_sen) + elif sensi_bit == 4: + quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_sen) + elif sensi_bit in [6,5,3,2]: + if quant_sym == "asym": + quantize_(model.to(device=device), intN_weight_only_asym(n=sensi_bit, group_size=group_size), filter_fn_sen) + elif quant_sym == "sym": + quantize_(model.to(device=device), intN_weight_only_sym(n=sensi_bit, group_size=group_size), filter_fn_sen) + + # quantize the less-sensitive layers + if non_sensi_bit == 8: + quantize_(model.to(device=device), int8_weight_only(), filter_fn_nonsen) + elif non_sensi_bit == 4: + quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_nonsen) + elif non_sensi_bit in [6,5,3,2]: + if sensi_bit == 4: + if quant_sym == "asym": + quantize_(model, intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen) + elif quant_sym == "sym": + quantize_(model, intN_weight_only_sym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen) + else: + if quant_sym == "asym": + quantize_(model.to(device=device), intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen) + elif quant_sym == "sym": + quantize_(model.to(device=device), intN_weight_only_sym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen) + + if compile: + model = torch.compile(model, mode="max-autotune", fullgraph=True) + + with torch.no_grad(): + + result = evaluate( + HFLM( + pretrained=model, + tokenizer=tokenizer, + batch_size=batch_size, + max_length=max_length), + get_task_dict(tasks), + limit = limit, + ) + + for task, res in result["results"].items(): + print(f"{task}: {res}") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Run HF Model Evaluation') + parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", 
help='Repository ID to download from HF.') + parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2') + parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') + parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') + parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') + parser.add_argument('-q', '--quantization', default = "None", help='Which quantization technique to apply') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes') + parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') + parser.add_argument('--sensi_bit', type=int, default=16, help='Bit setting for sensitive layers') + parser.add_argument('--non_sensi_bit', type=int, default=16, help='Bit setting for non-sensitive layers') + parser.add_argument('--quant_sym', type=str, default="asym", help='symmetric or asymmetric quantization') + parser.add_argument('--group_size', type=int, default=32, help='group size to perform quantization on') + args = parser.parse_args() + run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py new file mode 100644 index 0000000000..2e854b2b59 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py @@ -0,0 +1,40 @@ +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, +) + +def intN_weight_only_asym(group_size=32, n=8): + def apply_intN_weight_only_quant_asym(weight): + # avoid circular dep + from torchao.dtypes import to_affine_quantized + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int8 + quant_min = 0 + quant_max = 2**n-1 + eps = 1e-6 + preserve_zero = False + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.FLOAT + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) + + return apply_intN_weight_only_quant_asym + +def intN_weight_only_sym(group_size=32, n=8): + def apply_intN_weight_only_quant_sym(weight): + # avoid circular dep + from torchao.dtypes import to_affine_quantized + mapping_type = MappingType.SYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int8 + quant_min = -2**(n-1) + quant_max = 2**(n-1)-1 + eps = 1e-6 + preserve_zero = True + zero_point_dtype = torch.bfloat16 + zero_point_domain = ZeroPointDomain.INT + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) + + return apply_intN_weight_only_quant_sym diff --git 
a/torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py b/torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py
new file mode 100644
index 0000000000..1458019f38
--- /dev/null
+++ b/torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py
@@ -0,0 +1,35 @@
+def quantized_model_size_in_bytes(num_elements, group_size, bit_zeropoint, bit_scale, x, y, A, B):
+    # Size for A-bit quantization layers
+    size_A_bit = x * (num_elements * A + num_elements // group_size * (bit_zeropoint + bit_scale))
+
+    # Size for B-bit quantization layers
+    size_B_bit = y * (num_elements * B + num_elements // group_size * (bit_zeropoint + bit_scale))
+
+    # Total quantized model size in bits
+    total_size_bits = size_A_bit + size_B_bit
+
+    # Convert to bytes
+    total_size_bytes = total_size_bits / 8
+
+    # Convert to gigabytes
+    total_size_gb = total_size_bytes / (1024 ** 3)
+
+    return total_size_gb
+
+# Example usage
+num_elements = 250945664 # approximate number of weight elements per Llama3-8B transformer layer
+group_size = 32 # Example value, please adjust as needed
+bit_zeropoint = 2 # Example value, please adjust as needed
+bit_scale = 2 # Example value, please adjust as needed
+x = 32 # Example number of layers for A-bit quantization, adjust as needed
+y = 0 # Example number of layers for B-bit quantization, adjust as needed
+#A = 4 # Example bit width for A-bit quantization, adjust as needed
+#B = 0 # Example bit width for B-bit quantization, adjust as needed
+
+#for b in [8]:
+#    model_size_bytes = quantized_model_size_in_bytes(num_elements, group_size, bit_zeropoint, bit_scale, 32, 0, b, 0)
+#    print(f"The quantized model size for {b} bits is {model_size_bytes} GB")
+
+for (A, B) in [(16,8),(16,6),(16,5),(16,4),(16,3),(16,2),(8,6),(8,5),(8,4),(8,3),(8,2),(6,5),(6,4),(6,3),(6,2),(5,4),(5,3),(5,2), (4,3),(4,2),(3,2)]:
+    model_size_bytes = quantized_model_size_in_bytes(num_elements, group_size, bit_zeropoint, bit_scale, 5, 27, A, B)
+    print(f"The quantized model size for {A}-bit (5 layers) + {B}-bit (27 layers) is {model_size_bytes} GB")
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py b/torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py
new file mode 100644
index 0000000000..ca904041a5
--- /dev/null
+++ b/torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from lm_eval.models.huggingface import HFLM
+from lm_eval.evaluator import evaluate
+from lm_eval.tasks import get_task_dict
+
+from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
+from torchao._models._eval import TransformerEvalWrapper
+
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+
+from torchao.quantization.quant_api import (
+    change_linear_weights_to_int4_woqtensors,
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    autoquant,
+)
+
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.fx_graph_cache = True
+
+def intN_weight_only(group_size=32, n=8):
+    def apply_intN_weight_only_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, group_size)
+        target_dtype = torch.uint8
+        quant_min = 0
+        quant_max = 2**n-1
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+
zero_point_domain = ZeroPointDomain.FLOAT + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) + + return apply_intN_weight_only_quant + + + +def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, layer, linear_type): + + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision) + + def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool: + #return isinstance(child, nn.Linear) and "."+str(layer)+"." in cur_fqn + return isinstance(child, nn.Linear) and linear_type in cur_fqn and (".0." not in cur_fqn) and (".1." not in cur_fqn) and (".2." not in cur_fqn) and (".30." not in cur_fqn) and (".31." not in cur_fqn) + + if quantization in ["2","3","4","5","6","8"]: + quantize_(model.to(device=device), intN_weight_only(n=int(quantization)), filter_fn_sen) + + if compile: + model = torch.compile(model, mode="max-autotune", fullgraph=True) + + with torch.no_grad(): + + result = evaluate( + HFLM( + pretrained=model,#.to(device), + tokenizer=tokenizer, + batch_size=batch_size, + max_length=max_length), + get_task_dict(tasks), + limit = limit, + ) + + for task, res in result["results"].items(): + print(f"{task}: {res}") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Run HF Model Evaluation') + parser.add_argument('--repo_id', type=str, default="meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.') + parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2') + parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') + parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') + parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') + parser.add_argument('-q', '--quantization', default = "None", choices=["2","3","4","5","6","8","MP_llama3", "None"], help='Which quantization technique to apply') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes') + parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') + parser.add_argument('--layer', type=int, default=0, help='The layer to quantize') + parser.add_argument('--linear_type', type=str, default=0, choices=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], help='The linear type to quantize') + + args = parser.parse_args() + run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.layer, args.linear_type) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py new file mode 100644 index 0000000000..35ce288cff --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py @@ -0,0 +1,27 @@ +import torch +import 
torch.nn as nn + +from naive_intNwo import intN_weight_only_asym, intN_weight_only_sym + +from torchao.quantization import quantize_ + +from torchao.quantization.utils import ( + _apply_logging_hook, + compute_error, + compute_error as SQNR, + _fqn_to_op_to_shape_to_count, + LoggingTensorMode, +) + +def test_weight_only_quant(quantization_bit=2): + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(4, 5)) + y_ref = m(x) + quantize_(m, intN_weight_only_asym(n=int(quantization_bit),group_size=2)) + y_wo = m(x) + sqnr = compute_error(y_ref, y_wo) + assert(sqnr > 44.0),"sqnr: {} is too low".format(sqnr) + +for i in [2,3,5,6]: + test_weight_only_quant(i) From 02ef81beb7ad12a51e1a5a24a06a1111dca28069 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Tue, 23 Jul 2024 14:24:01 -0700 Subject: [PATCH 02/12] remove experiment scripts --- test/quantization/test_naive_intNwo.py | 46 +++++++++++++ .../mixed_precision/scripts/mp_quant_eval.py | 66 +++++-------------- .../mixed_precision/scripts/naive_intNwo.py | 48 ++++++++++---- 3 files changed, 98 insertions(+), 62 deletions(-) create mode 100644 test/quantization/test_naive_intNwo.py diff --git a/test/quantization/test_naive_intNwo.py b/test/quantization/test_naive_intNwo.py new file mode 100644 index 0000000000..c25ad2b00c --- /dev/null +++ b/test/quantization/test_naive_intNwo.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + +import os +import sys +# append the path to the naive_intNwo.py file +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "torchao/quantization/prototype/mixed_precision/scripts")) +from naive_intNwo import intN_weight_only + +from torchao.quantization import quantize_, int8_weight_only, int4_weight_only + +from torchao.quantization.utils import ( + _apply_logging_hook, + compute_error, + compute_error as SQNR, + _fqn_to_op_to_shape_to_count, + LoggingTensorMode, +) + +def test_weight_only_quant(quantization_bit=2, symmetric=False): + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(4, 5)) + y_ref = m(x) + quantize_(m, intN_weight_only(n=quantization_bit, group_size=2, symmetric=symmetric)) + y_wo = m(x) + sqnr = compute_error(y_ref, y_wo) + print(sqnr) + assert sqnr > 44.0, "sqnr: {} is too low".format(sqnr) + + +# test if the asymmetric and symmetric quantization API works with different bit widths +for i in range(2, 9): + #test for asymmetric quantization + try: + test_weight_only_quant(i, False) + print(f"Test passed for {i}-bit using naive intNwo asymmetric quantization implementation") + except Exception as e: + print(f"Exception handled in test loop for {i}-bit asymmetric quantization. Details: {e}") + + #test for symmetric quantization + try: + test_weight_only_quant(i, True) + print(f"Test passed for {i}-bit using naive intNwo symmetric quantization implementation") + except Exception as e: + print(f"Exception handled in test loop for {i}-bit symmetric quantization. 
Details: {e}") diff --git a/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py index 0caee20dae..d17b76159e 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from naive_intNwo import intN_weight_only_asym, intN_weight_only_sym +from naive_intNwo import intN_weight_only from transformers import AutoModelForCausalLM, AutoTokenizer from lm_eval.models.huggingface import HFLM @@ -28,63 +28,33 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi tokenizer = AutoTokenizer.from_pretrained(repo_id) model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision) - if quantization == "int8dq": - quantize_(model.to(device=device), int8_dynamic_activation_int4_weight()) - - elif quantization == "int8wo": - quantize_(model.to(device=device), int8_weight_only()) - - elif quantization == "int4wo": - quantize_(model.to(device=device), int4_weight_only(group_size=group_size)) - - elif quantization == "autoquant": + if quantization == "autoquant": model = autoquant(model.to(device=device)) # naive implementation of uniform precision quantization all layers elif quantization in ["2","3","4","5","6","8"]: - if quant_sym == "asym": - quantize_(model.to(device=device), intN_weight_only_asym(n=int(quantization), group_size=group_size)) - elif quant_sym == "sym": - quantize_(model.to(device=device), intN_weight_only_sym(n=int(quantization), group_size=group_size)) - + quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym)) + + # mix precision quantization for Llama3 elif quantization == "MP_llama3": - # filter for sensitive layers + # filter for sensitive layers (the first 3 and last 2 layers for Llama3) def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool: return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.']) - # filter for non-sensitive layers + # filter for non-sensitive layers (other 27 layers for Llama3) def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool: return isinstance(child, nn.Linear) and not(any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])) + # quantize the sensitive layers if sensi_bit != 16: - # quantize the sensitive layers - if sensi_bit == 8: - quantize_(model.to(device=device), int8_weight_only(), filter_fn_sen) - elif sensi_bit == 4: - quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_sen) - elif sensi_bit in [6,5,3,2]: - if quant_sym == "asym": - quantize_(model.to(device=device), intN_weight_only_asym(n=sensi_bit, group_size=group_size), filter_fn_sen) - elif quant_sym == "sym": - quantize_(model.to(device=device), intN_weight_only_sym(n=sensi_bit, group_size=group_size), filter_fn_sen) + quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen) # quantize the less-sensitive layers - if non_sensi_bit == 8: - quantize_(model.to(device=device), int8_weight_only(), filter_fn_nonsen) - elif non_sensi_bit == 4: - quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_nonsen) - elif non_sensi_bit in [6,5,3,2]: - if sensi_bit == 4: - if quant_sym == "asym": - quantize_(model, 
intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
-                elif quant_sym == "sym":
-                    quantize_(model, intN_weight_only_sym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
-            else:
-                if quant_sym == "asym":
-                    quantize_(model.to(device=device), intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
-                elif quant_sym == "sym":
-                    quantize_(model.to(device=device), intN_weight_only_sym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
+        if sensi_bit == 4:
+            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
+        else:
+            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)

     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -113,13 +83,13 @@ def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool:
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices = ["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply: "2"-"8" for uniform N-bit weight-only quantization, "MP_llama3" for mixed precision on Llama3 (set sensi_bit and non_sensi_bit accordingly), or "None" for no quantization')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
-    parser.add_argument('--sensi_bit', type=int, default=16, help='Bit setting for sensitive layers')
-    parser.add_argument('--non_sensi_bit', type=int, default=16, help='Bit setting for non-sensitive layers')
-    parser.add_argument('--quant_sym', type=str, default="asym", help='symmetric or asymmetric quantization')
-    parser.add_argument('--group_size', type=int, default=32, help='group size to perform quantization on')
+    parser.add_argument('--sensi_bit', type=int, default=16, choices = [16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
+    parser.add_argument('--non_sensi_bit', type=int, default=8, choices = [8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
+    parser.add_argument('--quant_sym', action='store_true', help='Use symmetric quantization (asymmetric by default)')
+    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
     args = parser.parse_args()
     run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
index 2e854b2b59..7095b2c0ce 100644
---
a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py @@ -5,13 +5,26 @@ ZeroPointDomain, ) -def intN_weight_only_asym(group_size=32, n=8): +from torchao.quantization import int8_weight_only, int4_weight_only + + +def intN_weight_only(group_size=32, n=8, symmetric=False): + ''' + Apply int N-bit weight only quantization to a linear layer. + Args: + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32] + `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2] + Usage: + from torchao.quantization import quantize_ + quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize) + ''' + # for asymmetric quantization def apply_intN_weight_only_quant_asym(weight): - # avoid circular dep + # avoid circular dependency from torchao.dtypes import to_affine_quantized mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) - target_dtype = torch.int8 + target_dtype = torch.uint8 quant_min = 0 quant_max = 2**n-1 eps = 1e-6 @@ -20,21 +33,28 @@ def apply_intN_weight_only_quant_asym(weight): zero_point_domain = ZeroPointDomain.FLOAT return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) - return apply_intN_weight_only_quant_asym - -def intN_weight_only_sym(group_size=32, n=8): + # for symmetric quantization def apply_intN_weight_only_quant_sym(weight): - # avoid circular dep + # avoid circular dependency from torchao.dtypes import to_affine_quantized mapping_type = MappingType.SYMMETRIC block_size = (1, group_size) target_dtype = torch.int8 - quant_min = -2**(n-1) - quant_max = 2**(n-1)-1 eps = 1e-6 - preserve_zero = True - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.INT - return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) + zero_point_dtype = torch.int64 + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - return apply_intN_weight_only_quant_sym + try: + assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]" + if n == 8: + return int8_weight_only() + elif n == 4: + return int4_weight_only(group_size=group_size) + else: + if symmetric: + return apply_intN_weight_only_quant_sym + else: + return apply_intN_weight_only_quant_asym + except Exception as e: + raise + From cf2c134a57ef00ae999c34c326ea5c0b8cbc5a48 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Wed, 24 Jul 2024 09:17:05 -0700 Subject: [PATCH 03/12] remove exp files --- .../run_layer_sensitive_study.sh | 22 - .../prototype/mixed_precision/run_mp_quant.sh | 30 -- .../mixed_precision/run_mp_quant_benchmark.sh | 9 - .../mixed_precision/run_sensi_linear_type.sh | 26 -- .../mixed_precision/run_uni_quant.sh | 39 -- .../run_uni_quant_benchmark.sh | 10 - .../mixed_precision/scripts/generate.py | 390 ------------------ .../scripts/quant_model_size.py | 35 -- .../scripts/sensitivity_study.py | 95 ----- .../scripts/test_naive_intNwo.py | 27 -- 10 files changed, 683 deletions(-) delete mode 100755 torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh delete 
mode 100755 torchao/quantization/prototype/mixed_precision/run_mp_quant.sh delete mode 100755 torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh delete mode 100755 torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh delete mode 100755 torchao/quantization/prototype/mixed_precision/run_uni_quant.sh delete mode 100755 torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh delete mode 100644 torchao/quantization/prototype/mixed_precision/scripts/generate.py delete mode 100644 torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py delete mode 100644 torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py delete mode 100644 torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py diff --git a/torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh b/torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh deleted file mode 100755 index b5bc0f62d3..0000000000 --- a/torchao/quantization/prototype/mixed_precision/run_layer_sensitive_study.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -GPUS=(0 1 2 3 4 5) - -CONFIGS1=("2" "3" "4" "5" "6" "8") - -PYTHON_SCRIPT="scripts/sensitivity_study.py" - -for LAYER in {0..31}; do - for i in "${!GPUS[@]}"; do - GPU="${GPUS[$i]}" - CONFIG1="${CONFIGS1[$i]}" - - LOG_FILE="Sensi_${LAYER}_${CONFIG1}.txt" - - CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG1 --layer=$LAYER &>"$LOG_FILE" & - done - - wait -done - -echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/run_mp_quant.sh b/torchao/quantization/prototype/mixed_precision/run_mp_quant.sh deleted file mode 100755 index 263e0697d2..0000000000 --- a/torchao/quantization/prototype/mixed_precision/run_mp_quant.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# List of GPUs to use -GPUS=(0 1 2 3 4 5) - -# List of configuration files -CONFIGS1=(8 8 8 8 8 16) -CONFIGS2=(6 5 4 3 2 8) - -#CONFIGS1=(16 16 16 16 16) -#CONFIGS2=(6 5 4 3 2) - -#CONFIGS1=(5 5 5 6 6 6 6 3) -#CONFIGS2=(4 3 2 5 4 3 2 2) - -PYTHON_SCRIPT="scripts/mp_quant_eval.py" - -for i in "${!GPUS[@]}"; do - GPU="${GPUS[$i]}" - CONFIG1="${CONFIGS1[$i]}" - CONFIG2="${CONFIGS2[$i]}" - - LOG_FILE="MP_${CONFIG1}_${CONFIG2}.txt" - - CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=MP_llama3 --sensi_bit=$CONFIG1 --non_sensi_bit=$CONFIG2 &>"$LOG_FILE" & -done - -wait - -echo "All processes are complete." 
diff --git a/torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh b/torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh deleted file mode 100755 index 653eac112f..0000000000 --- a/torchao/quantization/prototype/mixed_precision/run_mp_quant_benchmark.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -PYTHON_SCRIPT="scripts/generate.py" -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=MP_llama3 --sensi_bit=5 --non_sensi_bit=4 --write_result mp_quant_benchmark_results.txt -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=MP_llama3 --sensi_bit=4 --non_sensi_bit=3 --write_result mp_quant_benchmark_results.txt -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=MP_llama3 --sensi_bit=5 --non_sensi_bit=3 --write_result mp_quant_benchmark_results.txt - - -echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh b/torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh deleted file mode 100755 index e7a6845aa4..0000000000 --- a/torchao/quantization/prototype/mixed_precision/run_sensi_linear_type.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -CONFIGS1=("q_proj" "k_proj" "v_proj" "o_proj" "gate_proj" "up_proj" "down_proj") -CONFIGS2=("2" "3" "4" "5" "6" "8") - -PYTHON_SCRIPT="scripts/sensitivity_study.py" - -GPUS=(0 1 2 3 4 5) - -for i in "${!CONFIGS1[@]}"; do - CONFIG1="${CONFIGS1[$i]}" - - for j in "${!CONFIGS2[@]}"; do - CONFIG2="${CONFIGS2[$j]}" - GPU="${GPUS[$j]}" - - LOG_FILE="Sensi_skipsensi_${CONFIG1}_${CONFIG2}.txt" - - CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG2 --linear_type=$CONFIG1 &>"$LOG_FILE" & - done - - wait -done - - -echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/run_uni_quant.sh b/torchao/quantization/prototype/mixed_precision/run_uni_quant.sh deleted file mode 100755 index 13ac196ce8..0000000000 --- a/torchao/quantization/prototype/mixed_precision/run_uni_quant.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash - -# List of GPUs to use -GPUS=(2 3 4 5 6 7) - -# List of configuration files -#CONFIGS1=("int6wo" "int5wo" "int2wo" "int3wo" "int8wo" "int4wo" "None" "autoquant") -CONFIGS1=("2" "3" "4" "5" "6" "8") -#CONFIGS1=(8 8 8 8 8 4 4 16) -#CONFIGS2=(6 5 4 3 2 3 2 8) - -#CONFIGS1=(16 16 16 16 16) -#CONFIGS2=(6 5 4 3 2) - -PYTHON_SCRIPT="scripts/mx_eval.py" - -for i in "${!GPUS[@]}"; do - GPU="${GPUS[$i]}" - CONFIG1="${CONFIGS1[$i]}" - - LOG_FILE="UNI_${CONFIG1}_SYM.txt" - - CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG1 --quant_sym=sym &>"$LOG_FILE" & -done - -wait - -for i in "${!GPUS[@]}"; do - GPU="${GPUS[$i]}" - CONFIG1="${CONFIGS1[$i]}" - - LOG_FILE="UNI_${CONFIG1}_ASYM.txt" - - CUDA_VISIBLE_DEVICES=$GPU python $PYTHON_SCRIPT --repo_id=checkpoints/meta-llama/Meta-Llama-3-8B --quantization=$CONFIG1 --quant_sym=asym &>"$LOG_FILE" & -done - -wait - -echo "All processes are complete." 
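Before these deletions, run_uni_quant.sh swept each uniform bit width as a separate process; after the consolidation in the previous patch, a uniform-precision run reduces to one intN_weight_only config applied through quantize_. A minimal, self-contained sketch on a toy module (assumptions: the consolidated intN_weight_only(n, group_size, symmetric) helper from scripts/naive_intNwo.py; the surviving scripts/mp_quant_eval.py wraps the same call with the lm-eval harness on Llama3-8B):

import torch
import torch.nn as nn
from torchao.quantization import quantize_
# intN_weight_only: consolidated N-bit weight-only config from scripts/naive_intNwo.py

# Toy stand-in for a real checkpoint, mirroring test_naive_intNwo.py
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
x = torch.randn(1, 64)
y_ref = model(x)

# Uniform 5-bit asymmetric weight-only quantization, group size 32, on every nn.Linear:
# quantize_(model, intN_weight_only(n=5, group_size=32, symmetric=False))
# y_quant = model(x)  # compare against y_ref, e.g. with torchao's compute_error (SQNR)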
diff --git a/torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh b/torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh deleted file mode 100755 index 6786857076..0000000000 --- a/torchao/quantization/prototype/mixed_precision/run_uni_quant_benchmark.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -PYTHON_SCRIPT="scripts/generate.py" -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=int4wo --write_result uni_quant_benchmark_results.txt -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=int8wo --write_result uni_quant_benchmark_results.txt -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=2 --write_result uni_quant_benchmark_results.txt -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=3 --write_result uni_quant_benchmark_results.txt -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=5 --write_result uni_quant_benchmark_results.txt -python $PYTHON_SCRIPT --checkpoint_path=checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --compile --quantization=6 --write_result uni_quant_benchmark_results.txt - -echo "All processes are complete." diff --git a/torchao/quantization/prototype/mixed_precision/scripts/generate.py b/torchao/quantization/prototype/mixed_precision/scripts/generate.py deleted file mode 100644 index 09f0b9613c..0000000000 --- a/torchao/quantization/prototype/mixed_precision/scripts/generate.py +++ /dev/null @@ -1,390 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-import sys -import time -from pathlib import Path -from typing import Optional, Tuple -from datetime import datetime -import torch -import torchao -import torch._dynamo.config -import torch._inductor.config -from torchao.utils import get_model_size_in_bytes - -import torch.nn as nn - -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) - -from naive_intNwo import intN_weight_only_asym, intN_weight_only_sym - -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif ("cpu" in device) or ("mps" in device): - pass - else: - print(f"device={device} is not yet suppported") - -default_device = 'cuda' if torch.cuda.is_available() else 'cpu' - -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao._models.llama.tokenizer import get_tokenizer - -def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs_sort).exponential_(1) - return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - -def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - -def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[0, -1], temperature, top_k) - idx_next = multinomial_sample_one_no_sync(probs) - return idx_next, probs - -def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: - # input_pos: [B, S] - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs)[0] - -def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] - assert input_pos.shape[-1] == 1 - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs) - -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): - new_tokens, new_probs = [], [] - for i in range(num_new_tokens): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here - next_token, next_prob = decode_one_token( - model, cur_token, input_pos, **sampling_kwargs - ) - input_pos += 1 - new_tokens.append(next_token.clone()) - callback(new_tokens[-1]) - new_probs.append(next_prob.clone()) - cur_token = next_token.view(1, -1) - - return new_tokens, new_probs - - -def model_forward(model, x, input_pos): - return model(x, input_pos) - -@torch.no_grad() -def generate( - model: Transformer, - prompt: torch.Tensor, - max_new_tokens: int, - *, - interactive: bool, - callback = lambda x: x, - **sampling_kwargs -) -> torch.Tensor: - """ - Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
- """ - - # create an empty tensor of the expected final shape and fill in the current tokens - device = prompt.device - T = prompt.numel() - T_new = T + max_new_tokens - seq = torch.empty(T_new, dtype=prompt.dtype, device=device) - seq[:T] = prompt.view(-1) - - # setup model cache - max_seq_length = min(T_new, model.config.block_size) if not interactive else 350 - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - # format model input - x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens) - - # execute prefill - next_token = prefill(model, x, input_pos, **sampling_kwargs).clone() - seq[T] = next_token - - input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) - seq[T + 1:] = torch.cat(generated_tokens) - - return seq - -def encode_tokens(tokenizer, string, bos=True, device=default_device): - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - return torch.tensor(tokens, dtype=torch.int, device=device) - -def _load_model(checkpoint_path, device, precision): - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - if "model" in checkpoint and "stories" in str(checkpoint_path): - checkpoint = checkpoint["model"] - - model = Transformer.from_name(checkpoint_path.parent.name) - model.load_state_dict(checkpoint, assign=True) - model = model.to(device=device, dtype=precision) - - return model.eval() - -B_INST, E_INST = "[INST]", "[/INST]" - -def main( - prompt: str = "Hello, my name is", - interactive: bool = False, - num_samples: int = 5, - max_new_tokens: int = 100, - top_k: int = 200, - temperature: float = 0.8, - checkpoint_path: Path = Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"), - quantization: Optional[str] = None, - compile: bool = True, - compile_prefill: bool = False, - profile: Optional[Path] = None, - device=default_device, - precision=torch.bfloat16, - write_result: Optional[Path] = None, - sensi_bit:Optional[int] = None, - non_sensi_bit:Optional[int] = None, - compile_mode:Optional[str] = "reduce-overhead", - group_size:Optional[int] = 32, -) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer. 
- """ - - torchao.quantization.utils.recommended_inductor_config_setter() - - assert checkpoint_path.is_file(), checkpoint_path - tokenizer_path = checkpoint_path.parent / "tokenizer.model" - assert tokenizer_path.is_file(), str(tokenizer_path) - - print(f"Using device={device}") - is_chat = "chat" in str(checkpoint_path) - - print("Loading model ...") - t0 = time.time() - model = _load_model(checkpoint_path, device, precision) - - - device_sync(device=device) # MKG - print(f"Time to load model: {time.time() - t0:.02f} seconds") - - tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) - - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - prompt_length = encoded.size(0) - - torch.manual_seed(1234) - - if quantization: - from torchao.quantization.quant_api import ( - quantize_, - int8_weight_only, - int8_dynamic_activation_int8_weight, - int4_weight_only, - autoquant, - unwrap_tensor_subclass - ) - if "int8wo" in quantization: - quantize_(model, int8_weight_only_my()) - if "int8dq" in quantization: - quantize_(model, int8_dynamic_activation_int8_weight()) - if "int4wo" in quantization: - groupsize=int(quantization.split("-")[-1]) - assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - quantize_(model, int4_weight_only(group_size=groupsize)) - if "autoquant" == quantization: - model = autoquant(model, manual=True) - - if quantization in ["2","3","4","5","6","8"]: - quantize_(model.to(device=device), intN_weight_only_sym(n=int(quantization), group_size=group_size)) - - elif quantization == "MP_llama3": - # filter for sensitive layers - def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool: - return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.']) - - # filter for non-sensitive layers - def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool: - return isinstance(child, nn.Linear) and not(any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])) - - #quantize the sensitive layers - if sensi_bit == 8: - quantize_(model.to(device=device), int8_weight_only(), filter_fn_sen) - elif sensi_bit == 4: - quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_sen) - elif sensi_bit in [6,5,3,2]: - quantize_(model.to(device=device), intN_weight_only_asym(n=sensi_bit, group_size=group_size), filter_fn_sen) - - #quantize the less-sensitive layers - if non_sensi_bit == 8: - quantize_(model.to(device=device), int8_weight_only(), filter_fn_nonsen) - elif non_sensi_bit == 4: - quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_nonsen) - elif non_sensi_bit in [6,5,3,2]: - if sensi_bit == 4: - quantize_(model, intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen) - else: - quantize_(model.to(device=device), intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen) - - else: - unwrap_tensor_subclass(model) - - model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 - - if compile: - print("Compiling Model") - global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode=compile_mode, fullgraph=True) - - if compile_prefill: - prefill = torch.compile(prefill, fullgraph=True, dynamic=True) - - - aggregate_metrics = { - 'tokens_per_sec': [], - } - start = -1 if compile else 0 - - for i in range(start, num_samples): - if i==0: - torch.cuda.reset_peak_memory_stats() - 
device_sync(device=device) # MKG - if i >= 0 and interactive: - prompt = input("What is your prompt? ") - if is_chat: - prompt = f"{B_INST} {prompt.strip()} {E_INST}" - encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - - if interactive and i >= 0: - buffer = [] - period_id = tokenizer.encode('.')[0] - done_generating = False - def callback(x): - nonlocal done_generating - if done_generating: - return - buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) - if x.item() == tokenizer.eos_id(): - done_generating = True - if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) - buffer.clear() - # print(, end='', flush=True) - else: - callback = lambda x : x - t0 = time.perf_counter() - import contextlib - if (i != num_samples - 1 or not profile): - prof = contextlib.nullcontext() - else: - torch.profiler._utils._init_for_cuda_graphs() - prof = torch.profiler.profile() - with prof: - y = generate( - model, - encoded, - max_new_tokens, - interactive=interactive, - callback=callback, - temperature=temperature, - top_k=top_k, - ) - - if i == -1: - compile_time = time.perf_counter() - t0 - print(f"Compilation time: {compile_time} seconds") - continue - if hasattr(prof, "export_chrome_trace"): - prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG - t = time.perf_counter() - t0 - - if not interactive: - print(tokenizer.decode(y.tolist())) - else: - print() - tokens_generated = y.size(0) - prompt_length - tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") - print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") - print("==========") - - tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() - bandwidth = model_size * tokpersec - mem = torch.cuda.max_memory_reserved() /1e9 - print(f"Average tokens/sec: {tokpersec:.2f}") - print(f"Average Bandwidth: {bandwidth:.02f} GB/s") - print(f"Peak Memory Usage: {mem:.02f} GB") - print(f"Model Size: {model_size:.02f} GB") - if write_result: - result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device}, compilation time: {compile_time} " - result_txt += f"repro: python generate.py " - result_txt += f"--quantization {quantization} " if quantization else "" - result_txt += f"--sensi_bit {sensi_bit} " if sensi_bit else "" - result_txt += f"--non_sensi_bit {non_sensi_bit} " if non_sensi_bit else "" - result_txt += f"--checkpoint_path {checkpoint_path} " - result_txt += f"--device {device} " - result_txt += f"--precision {precision} " - result_txt += f"--compile " if compile else "" - result_txt += f"--compile_prefill " if compile_prefill else "" - result_txt += f"--profile {profile} " if profile else "" - result_txt += f"--interactive " if interactive else "" - result_txt += f"--num_samples {num_samples} " - result_txt += f"--max_new_tokens {max_new_tokens} " - result_txt += f"--top_k {top_k} " - result_txt += f"--temperature {temperature} " - f=open(write_result, "a") - f.write(result_txt) - f.close() - - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - - 
parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') - parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') - parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') - parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth"), help='Model checkpoint path.') - #parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') - parser.add_argument('-q', '--quantization', default = "None", help='Which quantization technique to apply') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', type=Path, default=None, help='Profile path.') - parser.add_argument('--device', type=str, default=default_device, help='Device to use') - parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') - parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result') - parser.add_argument('--sensi_bit', type=int, default=16, help='Bit setting for sensitive layers') - parser.add_argument('--non_sensi_bit', type=int, default=16, help='Bit setting for non-sensitive layers') - parser.add_argument('--compile_mode', type=str, default="max-autotune", help='max-autotune or reduce-overhead mode for torch.compile()') - parser.add_argument('--group_size', type=int, default=32, help='group size to perform quantization on') - args = parser.parse_args() - main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result, args.sensi_bit, args.non_sensi_bit, args.compile_mode, args.group_size - ) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py b/torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py deleted file mode 100644 index 1458019f38..0000000000 --- a/torchao/quantization/prototype/mixed_precision/scripts/quant_model_size.py +++ /dev/null @@ -1,35 +0,0 @@ -def quantized_model_size_in_bytes(num_elements, group_size, bit_zeropoint, bit_scale, x, y, A, B): - # Size for A-bit quantization layers - size_A_bit = x * (num_elements * A + num_elements // group_size * (bit_zeropoint + bit_scale)) - - # Size for B-bit quantization layers - size_B_bit = y * (num_elements * B + num_elements // group_size * (bit_zeropoint + bit_scale)) - - # Total quantized model size in bits - total_size_bits = size_A_bit + size_B_bit - - # Convert to bytes - total_size_bytes = total_size_bits / 8 - - # Convert to gigabytes - total_size_gb = total_size_bytes / (1024 ** 3) - - return total_size_gb - -# Example usage -num_elements = 250945664 #number of elements per Llama3 linear layer -group_size = 32 # Example value, please adjust as needed 
-bit_zeropoint = 2 # Example value, please adjust as needed -bit_scale = 2 # Example value, please adjust as needed -x = 32 # Example number of layers for A-bit quantization, adjust as needed -y = 0 # Example number of layers for B-bit quantization, adjust as needed -#A = 4 # Example bit width for A-bit quantization, adjust as needed -#B = 0 # Example bit width for B-bit quantization, adjust as needed - -#for b in [8]: -# model_size_bytes = quantized_model_size_in_bytes(num_elements, group_size, bit_zeropoint, bit_scale, 32, 0, b, 0) -# print(f"The quantized model size for {b} bits is {model_size_bytes} GB") - -for (x,y) in [(16,8),(16,6),(16,5),(16,4),(16,3),(16,2),(8,6),(8,5),(8,4),(8,3),(8,2),(6,5),(6,4),(6,3),(6,2),(5,4),(5,3),(5,2), (4,3),(4,2),(3,2)]: - model_size_bytes = quantized_model_size_in_bytes(num_elements, group_size, bit_zeropoint, bit_scale, 5, 27, x, y) - print(f"The quantized model size for {b} bits is {model_size_bytes} GB") diff --git a/torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py b/torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py deleted file mode 100644 index ca904041a5..0000000000 --- a/torchao/quantization/prototype/mixed_precision/scripts/sensitivity_study.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch -import torch.nn as nn - -from transformers import AutoModelForCausalLM, AutoTokenizer - -from lm_eval.models.huggingface import HFLM -from lm_eval.evaluator import evaluate -from lm_eval.tasks import get_task_dict - -from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight -from torchao._models._eval import TransformerEvalWrapper - -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) - -from torchao.quantization.quant_api import ( - change_linear_weights_to_int4_woqtensors, - change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, - autoquant, -) - -torch._inductor.config.force_fuse_int_mm_with_mul = True -torch._inductor.config.fx_graph_cache = True - -def intN_weight_only(group_size=32, n=8): - def apply_intN_weight_only_quant(weight): - # avoid circular dep - from torchao.dtypes import to_affine_quantized - - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.uint8 - quant_min = 0 - quant_max = 2**n-1 - eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) - - return apply_intN_weight_only_quant - - - -def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, layer, linear_type): - - tokenizer = AutoTokenizer.from_pretrained(repo_id) - model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision) - - def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool: - #return isinstance(child, nn.Linear) and "."+str(layer)+"." in cur_fqn - return isinstance(child, nn.Linear) and linear_type in cur_fqn and (".0." not in cur_fqn) and (".1." not in cur_fqn) and (".2." not in cur_fqn) and (".30." not in cur_fqn) and (".31." 
not in cur_fqn) - - if quantization in ["2","3","4","5","6","8"]: - quantize_(model.to(device=device), intN_weight_only(n=int(quantization)), filter_fn_sen) - - if compile: - model = torch.compile(model, mode="max-autotune", fullgraph=True) - - with torch.no_grad(): - - result = evaluate( - HFLM( - pretrained=model,#.to(device), - tokenizer=tokenizer, - batch_size=batch_size, - max_length=max_length), - get_task_dict(tasks), - limit = limit, - ) - - for task, res in result["results"].items(): - print(f"{task}: {res}") - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='Run HF Model Evaluation') - parser.add_argument('--repo_id', type=str, default="meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.') - parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2') - parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') - parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') - parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument('-q', '--quantization', default = "None", choices=["2","3","4","5","6","8","MP_llama3", "None"], help='Which quantization technique to apply') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes') - parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') - parser.add_argument('--layer', type=int, default=0, help='The layer to quantize') - parser.add_argument('--linear_type', type=str, default=0, choices=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], help='The linear type to quantize') - - args = parser.parse_args() - run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.layer, args.linear_type) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py deleted file mode 100644 index 35ce288cff..0000000000 --- a/torchao/quantization/prototype/mixed_precision/scripts/test_naive_intNwo.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -import torch.nn as nn - -from naive_intNwo import intN_weight_only_asym, intN_weight_only_sym - -from torchao.quantization import quantize_ - -from torchao.quantization.utils import ( - _apply_logging_hook, - compute_error, - compute_error as SQNR, - _fqn_to_op_to_shape_to_count, - LoggingTensorMode, -) - -def test_weight_only_quant(quantization_bit=2): - for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: - x = torch.randn(*x_shape) - m = nn.Sequential(nn.Linear(4, 5)) - y_ref = m(x) - quantize_(m, intN_weight_only_asym(n=int(quantization_bit),group_size=2)) - y_wo = m(x) - sqnr = compute_error(y_ref, y_wo) - assert(sqnr > 44.0),"sqnr: {} is too low".format(sqnr) - -for i in [2,3,5,6]: - test_weight_only_quant(i) From 1055f1489154475716d87d811d7dd6e183f74926 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Wed, 24 Jul 2024 10:56:10 -0700 Subject: [PATCH 04/12] use default 
ZeroPointDomain.INT for int2/3/5/6 --- .../prototype/mixed_precision/scripts/naive_intNwo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py index 7095b2c0ce..7a68b70184 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py @@ -28,10 +28,10 @@ def apply_intN_weight_only_quant_asym(weight): quant_min = 0 quant_max = 2**n-1 eps = 1e-6 - preserve_zero = False - zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT - return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) + preserve_zero = True + zero_point_dtype = torch.int64 + zero_point_domain = ZeroPointDomain.INT + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)#, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain) # for symmetric quantization def apply_intN_weight_only_quant_sym(weight): From c00b16d5a967c62b54b7f7751ce6ba4799b2db71 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Wed, 24 Jul 2024 11:03:48 -0700 Subject: [PATCH 05/12] renamed test_naive_intNwo.py to test_mixed_precision.py --- .../{test_naive_intNwo.py => test_mixed_precision.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/quantization/{test_naive_intNwo.py => test_mixed_precision.py} (100%) diff --git a/test/quantization/test_naive_intNwo.py b/test/quantization/test_mixed_precision.py similarity index 100% rename from test/quantization/test_naive_intNwo.py rename to test/quantization/test_mixed_precision.py From f765eef162be6da4c47100e1b1cd3f71732b824e Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Thu, 25 Jul 2024 11:37:07 -0700 Subject: [PATCH 06/12] updated intNwo with _get_linear_subclass_inserter --- .../prototype/mixed_precision/scripts/naive_intNwo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py index 7a68b70184..6ebe458a46 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py @@ -6,7 +6,7 @@ ) from torchao.quantization import int8_weight_only, int4_weight_only - +from torchao.quantization.quant_api import _get_linear_subclass_inserter def intN_weight_only(group_size=32, n=8, symmetric=False): ''' @@ -52,9 +52,9 @@ def apply_intN_weight_only_quant_sym(weight): return int4_weight_only(group_size=group_size) else: if symmetric: - return apply_intN_weight_only_quant_sym + return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym) else: - return apply_intN_weight_only_quant_asym + return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym) except Exception as e: raise From 9a343a452ee1945492f28715885f39c0ab275fa5 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Thu, 25 Jul 2024 13:32:10 -0700 Subject: [PATCH 07/12] adjust sqnr threshold according to bit width --- test/quantization/test_mixed_precision.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git 
a/test/quantization/test_mixed_precision.py b/test/quantization/test_mixed_precision.py index c25ad2b00c..79bb0db253 100644 --- a/test/quantization/test_mixed_precision.py +++ b/test/quantization/test_mixed_precision.py @@ -18,19 +18,20 @@ ) def test_weight_only_quant(quantization_bit=2, symmetric=False): - for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + for x_shape in [[32, 64], [80, 80, 80, 64], [16, 64, 64]]: x = torch.randn(*x_shape) - m = nn.Sequential(nn.Linear(4, 5)) + m = nn.Sequential(nn.Linear(64, 80)) y_ref = m(x) - quantize_(m, intN_weight_only(n=quantization_bit, group_size=2, symmetric=symmetric)) + quantize_(m, intN_weight_only(n=quantization_bit, group_size=16, symmetric=symmetric)) y_wo = m(x) sqnr = compute_error(y_ref, y_wo) - print(sqnr) - assert sqnr > 44.0, "sqnr: {} is too low".format(sqnr) + #SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization + #e.g., we set sqnr threshold = 44 for 8-bit, so that 6.02 * 8= 48.16 fullfills + assert sqnr > 44.0-(8-quantization_bit)*6.02, "sqnr: {} is too low".format(sqnr) # test if the asymmetric and symmetric quantization API works with different bit widths -for i in range(2, 9): +for i in [2,3,5,6,8]: #test for asymmetric quantization try: test_weight_only_quant(i, False) From aafe38ede22bbaa7116e1aa46ddacd20ea6148a8 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Thu, 25 Jul 2024 15:37:12 -0700 Subject: [PATCH 08/12] fixed test for int4wo and add __init__.py --- test/quantization/test_mixed_precision.py | 16 ++++++---------- .../prototype/mixed_precision/__init__.py | 0 .../mixed_precision/scripts/__init__.py | 1 + 3 files changed, 7 insertions(+), 10 deletions(-) create mode 100644 torchao/quantization/prototype/mixed_precision/__init__.py create mode 100644 torchao/quantization/prototype/mixed_precision/scripts/__init__.py diff --git a/test/quantization/test_mixed_precision.py b/test/quantization/test_mixed_precision.py index 79bb0db253..d3a1633a25 100644 --- a/test/quantization/test_mixed_precision.py +++ b/test/quantization/test_mixed_precision.py @@ -1,11 +1,7 @@ import torch import torch.nn as nn -import os -import sys -# append the path to the naive_intNwo.py file -sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "torchao/quantization/prototype/mixed_precision/scripts")) -from naive_intNwo import intN_weight_only +from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only from torchao.quantization import quantize_, int8_weight_only, int4_weight_only @@ -18,11 +14,11 @@ ) def test_weight_only_quant(quantization_bit=2, symmetric=False): - for x_shape in [[32, 64], [80, 80, 80, 64], [16, 64, 64]]: - x = torch.randn(*x_shape) - m = nn.Sequential(nn.Linear(64, 80)) + for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]: + x = torch.randn(*x_shape, dtype=torch.bfloat16) + m = nn.Sequential(nn.Linear(32, 80)).bfloat16() y_ref = m(x) - quantize_(m, intN_weight_only(n=quantization_bit, group_size=16, symmetric=symmetric)) + quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric)) y_wo = m(x) sqnr = compute_error(y_ref, y_wo) #SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization @@ -31,7 +27,7 @@ def test_weight_only_quant(quantization_bit=2, symmetric=False): # test if the asymmetric and symmetric quantization API works with different bit widths -for i in [2,3,5,6,8]: +for i in [2, 3, 4, 5, 6, 8]: #test for asymmetric quantization try: 
test_weight_only_quant(i, False) diff --git a/torchao/quantization/prototype/mixed_precision/__init__.py b/torchao/quantization/prototype/mixed_precision/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/prototype/mixed_precision/scripts/__init__.py b/torchao/quantization/prototype/mixed_precision/scripts/__init__.py new file mode 100644 index 0000000000..1b0cae6ab3 --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/__init__.py @@ -0,0 +1 @@ +from .naive_intNwo import intN_weight_only From 1bfa370d50ae96b9e46a17fbc156d917e8c0827e Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Tue, 30 Jul 2024 09:03:31 -0700 Subject: [PATCH 09/12] skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on nightly --- test/integration/test_integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index d8b6d71a51..f0a20ed51f 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -730,6 +730,7 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skip("skipping for now due to seg fault on nightly") def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype From f4fccf3dc70bde806f68569faf3e35821604e773 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Tue, 30 Jul 2024 09:41:28 -0700 Subject: [PATCH 10/12] edit the sqnr threshold --- test/quantization/test_mixed_precision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_mixed_precision.py b/test/quantization/test_mixed_precision.py index d3a1633a25..725307a93a 100644 --- a/test/quantization/test_mixed_precision.py +++ b/test/quantization/test_mixed_precision.py @@ -22,8 +22,8 @@ def test_weight_only_quant(quantization_bit=2, symmetric=False): y_wo = m(x) sqnr = compute_error(y_ref, y_wo) #SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization - #e.g., we set sqnr threshold = 44 for 8-bit, so that 6.02 * 8= 48.16 fullfills - assert sqnr > 44.0-(8-quantization_bit)*6.02, "sqnr: {} is too low".format(sqnr) + #e.g., we set sqnr threshold = 42 for 8-bit, so that 6.02 * 8= 48.16 fullfills + assert sqnr > 42.0-(8-quantization_bit)*6.02, "sqnr: {} is too low".format(sqnr) # test if the asymmetric and symmetric quantization API works with different bit widths From 8e787b64c326d1e89811973b9e8112c086385881 Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Wed, 31 Jul 2024 22:05:49 -0700 Subject: [PATCH 11/12] add unittest --- test/integration/test_integration.py | 1 - test/quantization/test_mixed_precision.py | 59 +++++++++-------------- 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index f0a20ed51f..d8b6d71a51 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -730,7 +730,6 @@ def test_aq_int8_weight_only_quant_2_subclass(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skip("skipping for now due to seg fault on nightly") def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQWeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype diff --git 
a/test/quantization/test_mixed_precision.py b/test/quantization/test_mixed_precision.py index 725307a93a..ba60be224b 100644 --- a/test/quantization/test_mixed_precision.py +++ b/test/quantization/test_mixed_precision.py @@ -1,43 +1,32 @@ +import unittest + import torch import torch.nn as nn - -from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only - from torchao.quantization import quantize_, int8_weight_only, int4_weight_only +from torchao.quantization.utils import compute_error +from torchao.quantization.prototype.mixed_precision.naive_intNwo import intN_weight_only -from torchao.quantization.utils import ( - _apply_logging_hook, - compute_error, - compute_error as SQNR, - _fqn_to_op_to_shape_to_count, - LoggingTensorMode, -) +_CUDA_IS_AVAILABLE = torch.cuda.is_available() -def test_weight_only_quant(quantization_bit=2, symmetric=False): - for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]: - x = torch.randn(*x_shape, dtype=torch.bfloat16) - m = nn.Sequential(nn.Linear(32, 80)).bfloat16() - y_ref = m(x) - quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric)) - y_wo = m(x) - sqnr = compute_error(y_ref, y_wo) - #SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization - #e.g., we set sqnr threshold = 42 for 8-bit, so that 6.02 * 8= 48.16 fullfills - assert sqnr > 42.0-(8-quantization_bit)*6.02, "sqnr: {} is too low".format(sqnr) +class TestWeightOnlyQuantNaive(unittest.TestCase): + def test_quantization_intNwo(self): + #skip test int4wo for now since it is under development in torchao + for quantization_bit in [2, 3, 5, 6, 8]: + for symmetric in [False, True]: + with self.subTest(quantization_bit=quantization_bit, symmetric=symmetric): + for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]: + x = torch.randn(*x_shape, dtype=torch.bfloat16) + m = nn.Sequential(nn.Linear(32, 80)).bfloat16() + y_ref = m(x) + quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric)) + y_wo = m(x) + sqnr = compute_error(y_ref, y_wo) + # SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization + # e.g., we set sqnr threshold = 44 for 8-bit, so that 6.02 * 8= 48.16 fullfills + expected_sqnr_threshold = 44.0 - (8 - quantization_bit) * 6.02 + self.assertGreater(sqnr, expected_sqnr_threshold, f"sqnr: {sqnr} is too low") -# test if the asymmetric and symmetric quantization API works with different bit widths -for i in [2, 3, 4, 5, 6, 8]: - #test for asymmetric quantization - try: - test_weight_only_quant(i, False) - print(f"Test passed for {i}-bit using naive intNwo asymmetric quantization implementation") - except Exception as e: - print(f"Exception handled in test loop for {i}-bit asymmetric quantization. Details: {e}") - #test for symmetric quantization - try: - test_weight_only_quant(i, True) - print(f"Test passed for {i}-bit using naive intNwo symmetric quantization implementation") - except Exception as e: - print(f"Exception handled in test loop for {i}-bit symmetric quantization. 
Details: {e}") +if __name__ == '__main__': + unittest.main() From e516f0b27c53df2d8540784a67d1a69817e17b3c Mon Sep 17 00:00:00 2001 From: Hanxian97 Date: Wed, 31 Jul 2024 22:13:11 -0700 Subject: [PATCH 12/12] correct import path --- test/quantization/test_mixed_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/test_mixed_precision.py b/test/quantization/test_mixed_precision.py index ba60be224b..8afd022d3c 100644 --- a/test/quantization/test_mixed_precision.py +++ b/test/quantization/test_mixed_precision.py @@ -4,7 +4,7 @@ import torch.nn as nn from torchao.quantization import quantize_, int8_weight_only, int4_weight_only from torchao.quantization.utils import compute_error -from torchao.quantization.prototype.mixed_precision.naive_intNwo import intN_weight_only +from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only _CUDA_IS_AVAILABLE = torch.cuda.is_available()