Skip to content

Commit

Permalink
[Misc][Kernel]: Add GPTQAllSpark Quantization
Browse files Browse the repository at this point in the history
Signed-off-by: wyj371990 <[email protected]>
  • Loading branch information
wyajieha committed Feb 23, 2025
1 parent eb24dc4 commit c8eae49
Show file tree
Hide file tree
Showing 15 changed files with 2,046 additions and 8 deletions.
16 changes: 16 additions & 0 deletions CMakeLists.txt
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (ALLSPARK_ARCHS)
set(ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_reorder.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
set_gencode_flags_for_srcs(
SRCS "${ALLSPARK_SRCS}"
CUDA_ARCHS "${ALLSPARK_ARCHS}")
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures")
endif()

# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
Expand Down
49 changes: 47 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
Expand All @@ -18,12 +20,12 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, gptq_quantize_weights, sort_weights)
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import ScalarType
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
Expand Down Expand Up @@ -81,6 +83,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
GPTQ_MARLIN_24_MAX_PARALLEL)
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

# AllSpark W8A16 quant
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
and group_size == -1 and not act_order and is_k_full)
if as_supported_case:
properties = torch.cuda.get_device_properties(b.device.index)
sm_count = properties.multi_processor_count
sm_version = properties.major * 10 + properties.minor

supported_arch = (sm_version >= 80 and sm_version < 90)
as_supported_case = as_supported_case and supported_arch
if supported_arch:
has_zp = False
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
has_zp)
qw = qw.to(torch.uint8)

qw_reorder, s_reorder, zp_reorder = \
ops.allspark_repack_weight(
qw, s, zp, has_zp)
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD

globals = {
# Gen params
"quant_type": quant_type,
Expand Down Expand Up @@ -109,10 +132,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# AllSpark W8A16 params
"qw_reorder": qw_reorder if as_supported_case else None,
"s_reorder": s_reorder if as_supported_case else None,
"zp_reorder": zp_reorder if as_supported_case else None,
"sm_count": sm_count if as_supported_case else None,
"sm_version": sm_version if as_supported_case else None,
"CUBLAS_M_THRESHOLD":
CUBLAS_M_THRESHOLD if as_supported_case else None,
"weight_name_pattern":
f'model.layers.k{size_k}.m{size_m}.n{size_n}.qweight',
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
}

min_run_time = 1
Expand Down Expand Up @@ -172,6 +206,17 @@ def bench_run(results: List[benchmark.Measurement], model: str,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))

if as_supported_case:
results.append(
benchmark.Timer(
stmt=
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True, weight_name_pattern)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="allspark_w8a16_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))


def main(args):
print("Benchmarking models:")
Expand Down
Loading

0 comments on commit c8eae49

Please sign in to comment.