From 923f65722d2b2897c3e0cf6ca51e884b555fe162 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Tue, 8 Oct 2024 16:14:33 -0700
Subject: [PATCH] custom op working with compile now

Summary: register the gemlite GEMM/GEMV Triton kernels as torch custom ops
(torch.library.custom_op plus register_fake) so they can be called under
torch.compile; pass torch dtypes instead of triton dtypes across the op
boundary, re-enable the autotune config pruner, and refresh the llama
benchmark script and results.

Test Plan: benchmarks.sh (repro commands and numbers are recorded in
benchmark_results.txt)

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/_models/llama/benchmark_results.txt  |  21 +--
 torchao/_models/llama/benchmarks.sh          | 124 +++++++++---------
 torchao/_models/llama/generate.py            |  11 ++
 .../quantization/prototype/gemlite/core.py   |   3 +-
 .../gemm_A16fWnO16f_int32packing.py          |  58 +++++---
 .../gemv_A16fWnO16f_int32packing.py          |  37 +++++-
 6 files changed, 157 insertions(+), 97 deletions(-)

diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt
index 23c2803551..4e79a81a81 100644
--- a/torchao/_models/llama/benchmark_results.txt
+++ b/torchao/_models/llama/benchmark_results.txt
@@ -52,16 +52,19 @@ OTHER BENCHMARKS
 20240910110958, tok/s=223.95, mem/s= 682.88 GB/s, peak_mem= 5.59 GB, model_size= 3.05 GB quant: sparse-marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

 bs1
-20241007221134, tok/s= 13.93, mem/s= 184.03 GB/s, peak_mem=13.64 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
-20241007223402, tok/s= 0.32, mem/s= 1.18 GB/s, peak_mem= 5.55 GB, model_size= 3.72 GB quant: gemlite-4-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
-20241008092353, tok/s= 15.35, mem/s= 57.35 GB/s, peak_mem=15.56 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241008151940, tok/s= 94.38, mem/s=1416.56 GB/s, peak_mem=16.46 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241008152137, tok/s=181.84, mem/s= 767.71 GB/s, peak_mem= 6.57 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241008152538, tok/s= 49.68, mem/s= 211.10 GB/s, peak_mem= 7.40 GB, model_size= 4.25 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241008153006, tok/s= 52.19, mem/s= 221.78 GB/s, peak_mem= 7.65 GB, model_size= 4.25 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8

 bs2
-20241007221256, tok/s= 20.04, mem/s= 264.80 GB/s, peak_mem=13.78 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 1 --max_new_tokens 200 --batch_size 2 --top_k 200 --temperature 0.8
-20241007223928, tok/s= 0.92, mem/s= 3.43 GB/s, peak_mem= 5.57 GB, model_size= 3.72 GB quant: gemlite-4-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --num_samples 1 --max_new_tokens 200 --batch_size 2 --top_k 200 --temperature 0.8
-20241008092519, tok/s= 15.06, mem/s= 56.26 GB/s, peak_mem=15.58 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 1 --max_new_tokens 200 --batch_size 2 --top_k 200 --temperature 0.8
+20241008153347, tok/s= 84.89, mem/s=1274.15 GB/s, peak_mem=16.81 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 2 --top_k 200 --temperature 0.8
+20241008153609, tok/s=173.71, mem/s= 733.37 GB/s, peak_mem= 6.92 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 2 --top_k 200 --temperature 0.8
+20241008154149, tok/s= 49.57, mem/s= 211.96 GB/s, peak_mem= 7.75 GB, model_size= 4.28 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 2 --top_k 200 --temperature 0.8
+20241008154651, tok/s= 52.04, mem/s= 222.53 GB/s, peak_mem= 7.67 GB, model_size= 4.28 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 2 --top_k 200 --temperature 0.8

 bs4
-20241007221421, tok/s= 19.03, mem/s= 251.42 GB/s, peak_mem=14.06 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
-20241007224456, tok/s= 0.91, mem/s= 3.38 GB/s, peak_mem= 5.59 GB, model_size= 3.72 GB quant: gemlite-4-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
-20241008092656, tok/s= 12.32, mem/s= 46.04 GB/s, peak_mem=15.60 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
+20241008155034, tok/s= 83.37, mem/s=1251.36 GB/s, peak_mem=16.97 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
+20241008155257, tok/s=141.60, mem/s= 597.82 GB/s, peak_mem= 6.95 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
+20241008155928, tok/s= 49.45, mem/s= 214.18 GB/s, peak_mem= 7.81 GB, model_size= 4.33 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
+20241008160515, tok/s= 51.74, mem/s= 224.09 GB/s, peak_mem= 7.79 GB, model_size= 4.33 GB quant: gemlite-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization gemlite-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 1 --max_new_tokens 200 --batch_size 4 --top_k 200 --temperature 0.8
diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh
index b523143d8d..15fc7293aa 100644
--- a/torchao/_models/llama/benchmarks.sh
+++ b/torchao/_models/llama/benchmarks.sh
@@ -2,73 +2,73 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder

 # README BENCHMARKS
 export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-64 --num_samples 1 --write_result benchmark_results.txt --compile
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --num_samples 1 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --num_samples 1 --batch_size 2 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --num_samples 1 --batch_size 4 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --num_samples 1 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --batch_size 2 --num_samples 1 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --batch_size 4 --num_samples 1 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-64 --num_samples 1 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-64 --num_samples 1 --batch_size 2 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-64 --num_samples 1 --batch_size 4 --write_result benchmark_results.txt

+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt

+# OTHER BENCHMARKS
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
-
-
-
-# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt

+# kv cache quantization
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization --linear_causal_mask
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384 --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384 --kv_cache_quantization --linear_causal_mask
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768 --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768 --kv_cache_quantization --linear_causal_mask
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536 --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536 --kv_cache_quantization --linear_causal_mask
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask

+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

-# # OTHER BENCHMARKS
+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

-# # kv cache quantization
-# export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization --linear_causal_mask
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384 --kv_cache_quantization
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 16384 --kv_cache_quantization --linear_causal_mask
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768 --kv_cache_quantization
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 32768 --kv_cache_quantization --linear_causal_mask
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536 --kv_cache_quantization
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 65536 --kv_cache_quantization --linear_causal_mask
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask

+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-64 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization gemlite-4-64 --write_result benchmark_results.txt --compile --num_samples 1

-# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 2 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --batch_size 2 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-64 --batch_size 2 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization gemlite-4-64 --batch_size 2 --write_result benchmark_results.txt --compile --num_samples 1

-# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
-# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 4 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --batch_size 4 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-64 --batch_size 4 --write_result benchmark_results.txt --compile --num_samples 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization gemlite-4-64 --batch_size 4 --write_result benchmark_results.txt --compile --num_samples 1
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 3cdbed2898..01bf656430 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -255,6 +255,17 @@ def replace_fn(mod):

             _replace_with_custom_fn_if_matches_filter(model, replace_fn, _is_linear)

+            generate(  # run one generation up front to exercise the new gemlite kernels
+                model,
+                encode_tokens(tokenizer, prompt, bos=True, device=device),
+                max_new_tokens,
+                batch_size,
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
+            )
+
+
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
diff --git a/torchao/quantization/prototype/gemlite/core.py b/torchao/quantization/prototype/gemlite/core.py
index 073f31e33f..f861ccfa08 100644
--- a/torchao/quantization/prototype/gemlite/core.py
+++ b/torchao/quantization/prototype/gemlite/core.py
@@ -20,7 +20,6 @@ class DType(Enum):
     INT32 = "INT32"
     FP16D8 = "FP16D8i" # dynamic quantization

-
 ###################################################################################################################################
 # CUDA backend
 ###################################################################################################################################
@@ -107,7 +106,7 @@ def __init__(
         acc_dtype = (
             DType.FP16 if (self.compute_dtype == torch.float16) else DType.FP32
         )
-        self.acc_dtype = tl.float16 if (acc_dtype == DType.FP16) else tl.float32
+        self.acc_dtype = torch.float16 if (acc_dtype == DType.FP16) else torch.float32

         self.dtype = self.output_dtype
diff --git a/torchao/quantization/prototype/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py b/torchao/quantization/prototype/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py
index 9fcfbcfb95..aa95ad8334 100755
--- a/torchao/quantization/prototype/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py
+++ b/torchao/quantization/prototype/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py
@@ -4,7 +4,7 @@
 import triton
 import triton.language as tl

-
+from torch.library import custom_op, register_fake

 # code based https://github.com/fpgaminer/GPTQ-triton
 def kernel_config_pruner(configs, nargs, **kwargs):
@@ -83,11 +83,11 @@ def get_gemm_config():
 @triton.autotune(
     configs=get_gemm_config(),
     key=["M", "N", "K", "group_size", "W_nbits"],
-    # prune_configs_by={
-    #     'early_config_prune': kernel_config_pruner,
-    # },
-    # warmup=200,
-    # rep=50, #20 for faster tuning
+    prune_configs_by={
+        'early_config_prune': kernel_config_pruner,
+    },
+    warmup=200,
+    rep=50, #20 for faster tuning
 )
 @triton.jit
 def gemm_A16fWnO16f_int32packing_kernel(
@@ -194,22 +194,21 @@

     tl.store(c_ptrs, acc, mask=(offs_am[:, None] < M) & (offs_bn[None, :] < N))

-
+@custom_op("torchao::gemm_A16fWnO16f", mutates_args=(), device_types="cuda")
 def gemm_A16fWnO16f_int32packing_forward(
-    x,
-    W_q,
-    scales,
-    zeros,
-    W_nbits,
-    group_size,
-    unpack_mask,
-    elements_per_sample,
-    acc_dtype=tl.float16,
-):
+    x: torch.Tensor,
+    W_q: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    W_nbits: int,
+    group_size: int,
+    unpack_mask: int,
+    elements_per_sample: int,
+    acc_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
     output = torch.empty(
         (x.shape[0], W_q.shape[1]), device=W_q.device, dtype=scales.dtype
     )
-
     # assert x.shape[1] == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"

     grid = lambda META: (
@@ -217,6 +216,8 @@
         * triton.cdiv(W_q.shape[1], META["BLOCK_SIZE_N"]),
     )

+    triton_acc_dtype = tl.float16 if acc_dtype == torch.float16 else tl.float32
+
     gemm_A16fWnO16f_int32packing_kernel[grid](
         x,
         W_q,
@@ -237,15 +238,32 @@
         output.stride(0),
         output.stride(1),
         scales.stride(0),
-        acc_dtype,
+        triton_acc_dtype,
     )

     return output

+
+@register_fake("torchao::gemm_A16fWnO16f")
+def _(
+    x: torch.Tensor,
+    W_q: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    W_nbits: int,
+    group_size: int,
+    unpack_mask: int,
+    elements_per_sample: int,
+    acc_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    M, K = x.shape
+    K_samples, N = W_q.shape
+    return torch.empty((M, N), device=x.device, dtype=scales.dtype)
+
 class gemm_A16fWnO16f_int32packing:
     kernel = gemm_A16fWnO16f_int32packing_kernel
-    forward = gemm_A16fWnO16f_int32packing_forward
+    forward = torch.ops.torchao.gemm_A16fWnO16f
+    # forward = gemm_A16fWnO16f_int32packing_forward
     matmul_type = "GEMM"
diff --git a/torchao/quantization/prototype/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py b/torchao/quantization/prototype/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py
index 162b548785..3856a57c8c 100755
--- a/torchao/quantization/prototype/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py
+++ b/torchao/quantization/prototype/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py
@@ -3,6 +3,7 @@
 import torch, math
 import triton
 import triton.language as tl
+from torch.library import custom_op, register_fake

 def init_to_zero(name):
     return lambda nargs: nargs[name].zero_()
@@ -112,13 +113,24 @@ def gemv_A16fWnO16f_int32packing_kernel(
     #Output: tl.atomic_add only supports 1D fp16 arrays, bfp16 would crash
     tl.atomic_add(c_ptr + offs_bn + pid_m*N, acc, sem="relaxed", scope="cta") #Force cta scope

-
-def gemv_A16fWnO16f_int32packing_forward(x, W_q, scales, zeros, W_nbits, group_size, unpack_mask, elements_per_sample, acc_dtype=tl.float16):
+@custom_op("torchao::gemv_A16fWnO16f", mutates_args=(), device_types="cuda")
+def gemv_A16fWnO16f_int32packing_forward(
+    x: torch.Tensor,
+    W_q: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    W_nbits: int,
+    group_size: int,
+    unpack_mask: int,
+    elements_per_sample: int,
+    acc_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
     #assert x.shape[1] == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
     M, K, N = x.shape[0], x.shape[1], W_q.shape[1]
     output = torch.empty((M, N), device=W_q.device, dtype=scales.dtype)

     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(K, meta['BLOCK_SIZE_K']), triton.cdiv(N, meta['BLOCK_SIZE_N']))

+    triton_acc_dtype = tl.float16 if acc_dtype == torch.float16 else tl.float32
     gemv_A16fWnO16f_int32packing_kernel[grid](
         x, W_q, output,
@@ -129,14 +141,31 @@ def gemv_A16fWnO16f_int32packing_forward(x, W_q, scales, zeros, W_nbits, group_s
         W_q.stride(0), W_q.stride(1),
         output.stride(0), output.stride(1),
         scales.stride(0),
-        acc_dtype
+        triton_acc_dtype,
     )

     return output

+
+@register_fake("torchao::gemv_A16fWnO16f")
+def _(
+    x: torch.Tensor,
+    W_q: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    W_nbits: int,
+    group_size: int,
+    unpack_mask: int,
+    elements_per_sample: int,
+    acc_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    M, K = x.shape
+    K_samples, N = W_q.shape
+    return torch.empty((M, N), device=x.device, dtype=scales.dtype)
+
 class gemv_A16fWnO16f_int32packing:
     kernel = gemv_A16fWnO16f_int32packing_kernel
-    forward = gemv_A16fWnO16f_int32packing_forward
+    forward = torch.ops.torchao.gemv_A16fWnO16f
+    # forward = gemv_A16fWnO16f_int32packing_forward
     matmul_type = "GEMV"

 __all__ = ["gemv_A16fWnO16f_int32packing"]
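
For reference, a minimal sketch of calling the newly registered op under
torch.compile. The packing constants below (unpack_mask=0xF and
elements_per_sample=8 for 4-bit values packed into int32 words) and the
wrapper function are illustrative assumptions, not values taken from this
patch:

    import torch

    # After this patch the Triton GEMM is exposed as an opaque custom op;
    # torch.compile traces it through the @register_fake shape function
    # instead of stepping into the Triton launch.
    @torch.compile
    def w4_linear(x, W_q, scales, zeros):
        return torch.ops.torchao.gemm_A16fWnO16f(
            x, W_q, scales, zeros,
            4,              # W_nbits: 4-bit weights
            64,             # group_size
            0xF,            # unpack_mask for 4-bit values (assumed)
            8,              # elements_per_sample: 32 // W_nbits (assumed)
            torch.float16,  # acc_dtype, mapped to tl.float16 inside the op
        )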