gemlite integration in torchao #1034

Merged (2 commits) on Dec 16, 2024
28 changes: 28 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -19,6 +19,13 @@
from torchao.quantization.quant_api import quantize_
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

try:
    import gemlite  # noqa: F401

    has_gemlite = True
except ModuleNotFoundError:
    has_gemlite = False


class TestAffineQuantizedTensorParallel(DTensorTestBase):
    """Basic test case for tensor subclasses"""
@@ -139,8 +146,29 @@ def test_tp(self, dtype):
        return self._test_tp(dtype)


class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
    COMMON_DTYPES = [torch.float16]

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not has_gemlite, "gemlite not available")
    def test_tp_gemlite(self, dtype):
        from torchao.quantization import gemlite_uintx_weight_only

        for packing_bitwidth in [32, 8]:
            for bit_width in [4, 8]:
                for group_size in [64, 32, None] if bit_width == 4 else [None]:
                    api = lambda: gemlite_uintx_weight_only(
                        group_size, bit_width, packing_bitwidth
                    )
                    self.QUANT_METHOD_FN = staticmethod(api)
                    return self._test_tp(dtype)


common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
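For context, here is a minimal sketch (not part of the diff) of the user-facing API these new tests exercise. The toy model and shapes are placeholders, and gemlite currently targets fp16 weights on CUDA:

import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only

# toy fp16 model on CUDA; the gemlite kernels assume float16
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# positional args mirror one of the tested configs:
# group_size=64, bit_width=4, packing_bitwidth=32
quantize_(model, gemlite_uintx_weight_only(64, 4, 32))

out = model(torch.randn(16, 1024, dtype=torch.float16, device="cuda"))
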
34 changes: 34 additions & 0 deletions test/integration/test_integration.py
@@ -96,6 +96,12 @@
)
from torchao.dtypes.utils import is_device

try:
    import gemlite  # noqa: F401
    has_gemlite = True
except ModuleNotFoundError:
    has_gemlite = False

logger = logging.getLogger("INFO")

torch.manual_seed(0)
@@ -870,6 +876,10 @@ def _test_lin_weight_subclass_api_impl(
        ref_f = mod(x)
        api(mod)

        # test get_plain()
        if hasattr(mod[0].weight, "tensor_impl"):
            mod[0].weight.tensor_impl.get_plain()

        test = mod(x)
        self.assertGreater(
            SQNR(ref_f, test),
@@ -930,6 +940,30 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
            test_dtype=dtype
        )

    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests need torch 2.5 or greater")
    @unittest.skipIf(not has_gemlite, "gemlite not available")
    def test_gemlite_layout(self, device, dtype):
        if dtype != torch.float16:
            self.skipTest("gemlite only works for fp16 dtype")
        from torchao.quantization import gemlite_uintx_weight_only
        if device == "cpu":
            self.skipTest(f"gemlite is for cuda, not {device}")
        for packing_bitwidth in [32, 8]:
            for bit_width in [4, 8]:
                for group_size in [64, 32, None] if bit_width == 4 else [None]:
                    api = lambda mod: quantize_(mod, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))
                    for test_shape in [[1, 1024, 512], [16, 256, 1024], [128, 256, 1024]]:
                        print(packing_bitwidth, bit_width, group_size, test_shape, dtype)
                        self._test_lin_weight_subclass_api_impl(
                            api,
                            device,
                            15,
                            test_shape=test_shape,
                            test_dtype=dtype,
                        )


    @parameterized.expand(COMMON_DEVICE_DTYPE)
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
    # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
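As an aside (illustrative, not part of the diff), the get_plain() call added above unpacks a quantized weight back into plain tensors; a rough sketch, assuming the usual AffineQuantizedTensor contract of returning the integer data together with its quantization parameters:

# `mod` is a module already quantized by api(mod) as in the test above
w = mod[0].weight                  # AffineQuantizedTensor after quantize_
plain = w.tensor_impl.get_plain()  # typically (int_data, scale, zero_point)
print([t.shape for t in plain if t is not None])
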
15 changes: 15 additions & 0 deletions torchao/_models/llama/benchmarks.sh
@@ -91,6 +91,21 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured

# gemlite benchmarks
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --batch_size 32

# 2:4 sparse model
export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
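The --quantization strings above follow gemlite-<packing_bitwidth>-<bit_width>-<group_size>, where a trailing None disables grouping and the packing bitwidth may be omitted. A small sketch (not part of the diff) of the decoding that generate.py performs further down:

def parse_gemlite_flag(quantization: str):
    # "gemlite-32-4-64" -> (32, 4, 64); "gemlite-8-8-None" -> (8, 8, None); "gemlite-4-64" -> (32, 4, 64)
    args = quantization.split("-")
    bit_width = int(args[-2])
    group_size = None if args[-1] == "None" else int(args[-1])
    packing_bitwidth = int(args[-3]) if args[-3].isdigit() else 32  # default when omitted
    return packing_bitwidth, bit_width, group_size
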
42 changes: 39 additions & 3 deletions torchao/_models/llama/generate.py
@@ -171,7 +171,8 @@ def decode_n_tokens(
            )
            next_token, next_prob = next_token.clone(), next_prob.clone()
            input_pos += 1
            # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
            new_tokens.append(next_token.clone())
            callback(new_tokens[-1])
            new_probs.append(next_prob)
            cur_token = next_token
@@ -368,6 +369,7 @@ def ffn_or_attn_only(mod, fqn):
            int8_weight_only,
            quantize_,
            uintx_weight_only,
            gemlite_uintx_weight_only,
        )

        from torchao.quantization.granularity import PerRow, PerTensor
@@ -377,6 +379,39 @@ def ffn_or_attn_only(mod, fqn):
            from torchao.prototype.spinquant import apply_spinquant

            apply_spinquant(model)
        if "gemlite" in quantization:
            import os, pwd
            import gemlite
            from gemlite.core import GemLiteLinearTriton, set_autotune
            _quant_args = quantization.split("-")
            bit_width = int(_quant_args[-2])
            group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
            try:
                packing_bitwidth = int(_quant_args[-3])
            except (ValueError, IndexError):
                # if only 2 inputs found, use default value
                packing_bitwidth = 32

            quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))

            # try to load gemlite kernel config
            try:
                GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
                print(f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
            except Exception:
                print(f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")

            print("running gemlite warmup")
            generate(
                model,
                encode_tokens(tokenizer, prompt, bos=True, device=device),
                max_new_tokens,
                batch_size,
                interactive=False,
                temperature=temperature,
                top_k=top_k,
            )
            GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
        if "int8wo" in quantization:
            quantize_(model, int8_weight_only())
        if "int8dq" in quantization:
@@ -959,7 +994,7 @@ def callback(x):

    parser = argparse.ArgumentParser(description="Your CLI description.")
    parser.add_argument(
        "--prefill_size", type=int, default=0, help="Whether to run in ttft mode"
        "--prefill_size", type=int, default=None, help="Whether to run in ttft mode"
    )
    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
@@ -993,7 +1028,7 @@ def callback(x):
        help=(
            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
            + "autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
            + "embed-int8wo, marlin_qqq"
            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>"
        ),
    )
    parser.add_argument(
@@ -1053,6 +1088,7 @@ def callback(x):
    )

    args = parser.parse_args()
    print(args)

Contributor comment on print(args): nit: remove

    main(
        args.prefill_size,
        args.prompt,
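The warmup and kernel-cache handling added above reduces to the following pattern (a sketch, not part of the diff; the path is illustrative, while load_config and cache_config are the gemlite calls used in the change):

from gemlite.core import GemLiteLinearTriton

config_path = "/tmp/example_gemlite.json"  # illustrative; generate.py derives it from the current user
try:
    # reuse previously autotuned Triton kernel configs if they exist
    GemLiteLinearTriton.load_config(config_path)
except Exception:
    pass  # first run: nothing cached yet
# ... run one generation pass so the kernels autotune ...
GemLiteLinearTriton.cache_config(config_path)  # persist the tuned configs for later runs
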
5 changes: 4 additions & 1 deletion torchao/_models/llama/model.py
@@ -170,7 +170,10 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.output.weight.dtype
        dtype = None
        # module swaps can leave self.output without a weight attribute
        if hasattr(self.output, "weight"):
            dtype = self.output.weight.dtype
        # For quantized layers, dtype is encoded in scales
        if hasattr(self.output, "scales"):
            dtype = self.output.scales.dtype
15 changes: 13 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -225,6 +225,8 @@ def from_hp_to_intx(
                else input_float.dtype
            )
            device = input_float.device
            from torchao.dtypes.uintx import TensorCoreTiledLayout

            data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
                input_float,
                nbits=nbits,
@@ -233,7 +235,15 @@
                compute_dtype=compute_dtype,
                device=device,
                verbose=False,
                raw_output=False,
                raw_output=not isinstance(
                    _layout, (TensorCoreTiledLayout, PlainLayout)
                ),
                # raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
                # note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
                # zero is preserved.
                # TODO uncouple preserve_zero and conversion of zero_point to TensorCoreTiledLayout version
                # TODO move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain
                # TODO change PlainLayout to use raw_output.
            )
            data = data.to(target_dtype)
        else:
@@ -251,7 +261,8 @@ def from_hp_to_intx(
                zero_point_domain,
            )
            # choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is None
            if zero_point_domain is None:
            # TODO should probably consolidate ZeroPointDomain.NONE and None
            if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
                zero_point = None
            data = quantize_affine(
                input_float,
8 changes: 8 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -20,6 +20,10 @@
    _linear_int8_act_int8_weight_block_sparse_check,
    _linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.gemlite_layout import (
    _linear_fp_act_int4_weight_gemlite_check,

Contributor comment on the gemlite check import: nit: fp16? seems like gemlite only works with fp16 right now, but can be a follow up PR

    _linear_fp_act_int4_weight_gemlite_impl,
)
from torchao.dtypes.uintx.marlin_qqq_tensor import (
    _linear_int8_act_int4_weight_marlin_qqq_check,
    _linear_int8_act_int4_weight_marlin_qqq_impl,
@@ -135,6 +139,10 @@ def _register_aqt_quantized_linear_dispatches():
            _linear_int8_act_int4_weight_marlin_qqq_check,
            _linear_int8_act_int4_weight_marlin_qqq_impl,
        ),
        (
            _linear_fp_act_int4_weight_gemlite_check,
            _linear_fp_act_int4_weight_gemlite_impl,
        ),
    ]:
        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

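For reference (not part of the diff), the check/impl pairs registered above follow a simple contract: the check inspects (activation, weight, bias) and returns a bool, and the impl computes the linear when the check matches. A hypothetical minimal pair, using an illustrative dequantize-and-matmul fallback rather than the actual gemlite kernel:

import torch

def _linear_example_check(input_tensor, weight_tensor, bias):
    # e.g. match on the weight's layout / tensor_impl type
    return hasattr(weight_tensor, "tensor_impl")

def _linear_example_impl(input_tensor, weight_tensor, bias):
    # reference fallback: dequantize the weight and run a plain linear
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

# registered the same way as the pairs in the list above
register_aqt_quantized_linear_dispatch(_linear_example_check, _linear_example_impl)
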