From e531daad8c999dcc42eb386d45dd6ebeb0a3ad97 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Fri, 2 Aug 2024 09:08:17 -0400 Subject: [PATCH] adding kv_cache quantization (#532) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adding support for kv_cache quantization, we are using simple symmetric quantization, though using the full precision of the k and v values of the current token. we see tok/s reduction of 3-5 tok/s depending on context length. image and a reduction in peak memory image We expect this reduction to scale to large context lengths, in the model memory trace we can see the point where we replace the bf16 cache with the int8 cache which visually saves about half of the used memory Screenshot 2024-08-02 at 2 45 14 AM at longer context lengths both quantized and non-quantized kv_cache models start outputing weird stuff but otherwise accuracy of the kv_cache quant looks reasonable though e.g. for 2048 context length: <|begin_of_text|>Hello, my name is Richard Brown and I have been a professional musician for over 25 years. I have played in a number of bands, doing a wide variety of genres (soul/funk, rock, jazz, blues, latin, world). I have played on over a hundred albums so far. I have played with many different singers, as well as instrumentalists (guitarists, sax players, brass players, etc.). I love to play and try to learn as much as I can from others. I have become an all-round musician - playing keyboards, drums, programming, arranging; as well as writing songs myself. I have my own studio, and I can do sessions online. I also have my own website, where you can find out more about me and my music. I hope that you will find the music that you are looking for here. Otherwise there are some fixes in generate.py to get things working for large context lengths without overflowing beyond the model limit. test plan: sh benchmarks.sh (specifically the last 6 rows of benchmark_results.txt) --- torchao/_models/llama/benchmark_results.txt | 10 +++++ torchao/_models/llama/benchmarks.sh | 8 ++++ torchao/_models/llama/generate.py | 43 ++++++++++++++++----- torchao/_models/llama/model.py | 38 ++++++++++++++++++ 4 files changed, 89 insertions(+), 10 deletions(-) diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index b02d4c2441..e07a3e6799 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -1,3 +1,4 @@ +llama 2 20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 @@ -8,6 +9,7 @@ 20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +llama 3 20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 @@ -17,3 +19,11 @@ 20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +kv cache quantization: +20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 +20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 +20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 +20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 20a7bf1103..6dd9c10d94 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -22,3 +22,11 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt + +export MODEL_REPO=meta-llama/Meta-Llama-3-8B +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192 diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 34ff9abb12..6e6db90571 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc next_token, next_prob = decode_one_token( model, cur_token, input_pos, **sampling_kwargs ) + next_token, next_prob = next_token.clone(), next_prob.clone() input_pos += 1 - new_tokens.append(next_token.clone()) + new_tokens.append(next_token) callback(new_tokens[-1]) - new_probs.append(next_prob.clone()) + new_probs.append(next_prob) cur_token = next_token.view(1, -1) return new_tokens, new_probs @@ -88,6 +89,7 @@ def generate( *, interactive: bool, callback = lambda x: x, + kv_cache_quantization: bool = False, **sampling_kwargs ) -> torch.Tensor: """ @@ -97,14 +99,27 @@ def generate( # create an empty tensor of the expected final shape and fill in the current tokens device = prompt.device T = prompt.numel() - T_new = T + max_new_tokens - seq = torch.empty(T_new, dtype=prompt.dtype, device=device) + + # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) + max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + new_tokens = max_seq_length - T + + # full prompt+output will be stored in seq + seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device) seq[:T] = prompt.view(-1) - # setup model cache - max_seq_length = min(T_new, model.config.block_size) if not interactive else 350 + # setup model caches with torch.device(device): model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if kv_cache_quantization: + from model import AffineQuantizedKVCache + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + _replace_with_custom_fn_if_matches_filter( + model, + AffineQuantizedKVCache.from_float, + lambda x, y: isinstance(x, torchao._models.llama.model.KVCache), + ) + # format model input x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens) @@ -113,8 +128,9 @@ def generate( next_token = prefill(model, x, input_pos, **sampling_kwargs).clone() seq[T] = next_token + # execute token generation input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) seq[T + 1:] = torch.cat(generated_tokens) return seq @@ -147,6 +163,7 @@ def main( temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), quantization: Optional[str] = None, + kv_cache_quantization: bool = False, compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, @@ -276,6 +293,7 @@ def callback(x): callback=callback, temperature=temperature, top_k=top_k, + kv_cache_quantization=kv_cache_quantization, ) if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") @@ -286,7 +304,10 @@ def callback(x): t = time.perf_counter() - t0 if not interactive: - print(tokenizer.decode(y.tolist())) + tok_list = y.tolist() + # truncate text after end of string token + tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())] + print(tokenizer.decode(tokens)) else: print() tokens_generated = y.size(0) - prompt_length @@ -305,12 +326,13 @@ def callback(x): print(f"Model Size: {model_size:.02f} GB") if write_result: result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " + result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " result_txt += f"repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " + result_txt += f"--kv_cache_quantization " if kv_cache_quantization else "" result_txt += f"--compile " if compile else "" result_txt += f"--compile_prefill " if compile_prefill else "" result_txt += f"--profile {profile} " if profile else "" @@ -337,6 +359,7 @@ def callback(x): parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant') + parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') parser.add_argument('--profile', type=Path, default=None, help='Profile path.') @@ -347,5 +370,5 @@ def callback(x): args = parser.parse_args() main( args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result + args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result ) diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 7204b0e387..58a1709642 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -12,6 +12,7 @@ from torch.nn import functional as F from torchao.utils import find_multiple +# TODO remove suplerfluous arg def prepare_inputs_for_model(inps, max_new_tokens=1): # this is because input from lm-eval is 2d if inps.dim() > 2: @@ -97,6 +98,43 @@ def update(self, input_pos, k_val, v_val): return k_out, v_out + +from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine +from torchao.quantization.utils import quantize_activation_per_token_absmax + +class AffineQuantizedKVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + scale_shape = (max_batch_size, n_heads, max_seq_length, 1) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8)) + self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) + self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype)) + + def update(self, input_pos, k_val, v_val): + # quantize current k_val and store it in the cache + q_k_val, k_scale = quantize_activation_per_token_absmax(k_val) + self.k_cache[:, :, input_pos] = q_k_val + self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1) + k_out = self.k_cache*self.k_cache_scale + k_out[:, :, input_pos] = k_val + + q_v_val, v_scale = quantize_activation_per_token_absmax(v_val) + self.v_cache[:, :, input_pos] = q_v_val + self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1) + v_out = self.v_cache*self.v_cache_scale + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + @classmethod + def from_float(cls, kv_cache): + cache_shape = kv_cache.k_cache.shape + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape + scale_dtype = kv_cache.k_cache.dtype + return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype) + class Transformer(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__()