diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 7dd20636c892f..f6f123fcaa82b 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -159,5 +159,19 @@ def test_compressed_tensors_fp8(vllm_runner):
 def test_compressed_tensors_kv_cache(vllm_runner):
     model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
     with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
+        output = llm.generate_greedy("Hello world!", max_tokens=20)
+        assert output
+
+
+def test_compressed_tensors_actorder_weight(vllm_runner):
+    model_path = "kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-weight-e2e"
+    with vllm_runner(model_path) as llm:
+        output = llm.generate_greedy("Hello world!", max_tokens=20)
+        assert output
+
+
+def test_compressed_tensors_actorder_group(vllm_runner):
+    model_path = "kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-group-e2e"
+    with vllm_runner(model_path) as llm:
         output = llm.generate_greedy("Hello world!", max_tokens=20)
         assert output
\ No newline at end of file
diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt
index cbe30305c14f6..b9b9cdd30d853 100644
--- a/tests/weight_loading/models.txt
+++ b/tests/weight_loading/models.txt
@@ -15,6 +15,8 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+compressed-tensors, kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-weight-e2e, main
+compressed-tensors, kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-group-e2e, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 7ca8eecb9283e..8dc5a50169561 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -3,12 +3,13 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_marlin_supported,
-    verify_marlin_supports_shape)
+    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
@@ -22,6 +23,8 @@
 }
 WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
 
+logger = init_logger(__name__)
+
 
 class CompressedTensorsWNA16(CompressedTensorsScheme):
 
@@ -119,9 +122,15 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                                           dtype=torch.int64),
                                          weight_loader=weight_loader)
 
+        # group index (for activation reordering)
+        weight_g_idx = BasevLLMParameter(data=torch.full(
+            (input_size_per_partition, ), -1, dtype=torch.int32),
+                                         weight_loader=weight_loader)
+
         layer.register_parameter("weight_packed", weight)
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
+        layer.register_parameter("weight_g_idx", weight_g_idx)
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
@@ -137,9 +146,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)
 
-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+        # Handle sorting for activation reordering if needed.
+        has_g_idx = -1 not in layer.weight_g_idx
+        if has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "weight_g_idx", g_idx)
+        else:
+            layer.weight_g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
 
         # No zero-point
         layer.weight_zp = marlin_make_empty_g_idx(device)
@@ -161,7 +176,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Permute scales from compressed-tensors format to marlin format.
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
-            size_k=layer.input_size_per_partition,
+            size_k=(layer.input_size
+                    if has_g_idx else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
@@ -174,7 +190,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
             weight=layer.weight_packed,
             weight_scale=layer.weight_scale,
             weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            g_idx=layer.weight_g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
             wtype=self.quant_type,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
index 7912cbde5721f..3dce1e9e87cf3 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+    accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel):
     """
     User facing arguments used to define a quantization config
@@ -58,6 +71,8 @@
         observed with every sample. Defaults to False for static quantization.
         Note that enabling dynamic quantization will change the default
         observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
     """
 
     num_bits: int = 8
@@ -67,6 +82,7 @@
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
    dynamic: bool = False
+    actorder: Optional[ActivationOrdering] = None
     observer: str = Field(
         default="minmax",
         description=("The class to use to compute the quantization param - "
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 0ec68ac5b0f21..2ad6df24dd1d9 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -129,16 +129,16 @@ def marlin_make_workspace(output_size_per_partition: int,
                           requires_grad=False)
 
 
-def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
-    return (not act_order) or (act_order and not is_row_parallel)
+def marlin_is_k_full(has_g_idx: bool, is_row_parallel: bool) -> bool:
+    return (not has_g_idx) or (not is_row_parallel)
 
 
-def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
+def marlin_repeat_scales_on_all_ranks(has_g_idx: bool, group_size: int,
                                       is_row_parallel: bool) -> bool:
-    # Need to repeat scales on every rank if act_ordering or
+    # Need to repeat scales on every rank if actorder or
     # channelwise and RowParallelLinear
     is_channelwise = group_size == -1
-    return act_order or (is_channelwise and is_row_parallel)
+    return has_g_idx or (is_channelwise and is_row_parallel)
 
 
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
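
On the config side, `actorder` is a plain pydantic field on `QuantizationArgs`, so a checkpoint config carrying `"actorder": "group"` or `"actorder": "weight"` is coerced into the `ActivationOrdering` enum, while configs without the key keep the previous behavior. A minimal usage sketch, assuming only the fields visible in this diff (real compressed-tensors configs carry additional keys):

```python
# Minimal sketch: expected behavior of the new `actorder` field.
# Only fields visible in this diff are used; real configs have more keys.
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    ActivationOrdering, QuantizationArgs)

# Pydantic coerces the checkpoint's string value into the str-enum.
group_args = QuantizationArgs(num_bits=4, actorder="group")
assert group_args.actorder == ActivationOrdering.GROUP

# Omitting the key preserves the old behavior: no activation reordering.
default_args = QuantizationArgs(num_bits=4)
assert default_args.actorder is None
```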
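
The scheme-side change hinges on `weight_g_idx`: it is registered as an all-`-1` placeholder in `create_weights`, overwritten by the checkpoint only when the model was quantized with activation ordering, and then sorted in `process_weights_after_loading` so the Marlin path can walk channels in group order. The standalone sketch below illustrates that branch; `make_sorted_g_idx` and `prepare_g_idx` are hypothetical helpers approximating what `marlin_sort_g_idx` and the new `has_g_idx` branch do, not the vLLM implementations themselves.

```python
# Standalone sketch (illustration only): approximates the activation-ordering
# branch added to CompressedTensorsWNA16.process_weights_after_loading.
from typing import Optional, Tuple

import torch


def make_sorted_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Sort channels by their group index; the resulting permutation is what
    # lets the kernel visit rows/activations in group order.
    sort_indices = torch.argsort(g_idx).to(torch.int32)
    return g_idx[sort_indices], sort_indices


def prepare_g_idx(
    weight_g_idx: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
    # weight_g_idx starts as all -1 (the placeholder from create_weights);
    # it only holds real group indices if the checkpoint stored them.
    has_g_idx = bool((weight_g_idx != -1).any())
    if not has_g_idx:
        # No activation reordering: the real code substitutes empty tensors
        # via marlin_make_empty_g_idx.
        return None, None, False
    g_idx, sort_indices = make_sorted_g_idx(weight_g_idx)
    return g_idx, sort_indices, True


if __name__ == "__main__":
    # 8 input channels, 2 groups, channel-to-group mapping from the checkpoint.
    loaded = torch.tensor([1, 0, 0, 1, 0, 1, 0, 1], dtype=torch.int32)
    g_idx, sort_indices, has_g_idx = prepare_g_idx(loaded)
    print(has_g_idx)       # True
    print(g_idx.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]
    print(sort_indices.tolist())
```

This also motivates the `size_k` change when permuting scales: with a real g_idx, each tensor-parallel rank keeps scales for the full K dimension (`layer.input_size`), consistent with `marlin_repeat_scales_on_all_ranks` returning `True` whenever `has_g_idx` is set.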