[ Misc ] Support Act Order in Compressed Tensors #6358
@@ -3,12 +3,13 @@
 import torch

 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_marlin_supported,
-    verify_marlin_supports_shape)
+    apply_gptq_marlin_linear, marlin_is_k_full, marlin_make_empty_g_idx,
+    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
+    replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
@@ -22,13 +23,16 @@
 }
 WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())

+logger = init_logger(__name__)
+

 class CompressedTensorsWNA16(CompressedTensorsScheme):

     def __init__(self,
                  strategy: str,
                  num_bits: int,
-                 group_size: Optional[int] = None):
+                 group_size: Optional[int] = None,
+                 actorder: bool = False):

         self.pack_factor = 32 // num_bits
         self.strategy = strategy
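For reference, a usage sketch of the extended constructor; the strategy and group_size values below are illustrative placeholders, not taken from this PR:

```python
# Grouped 4-bit weight-only scheme with activation ordering enabled
# (hypothetical argument values, for illustration only).
scheme = CompressedTensorsWNA16(strategy="group",
                                num_bits=4,
                                group_size=128,
                                actorder=True)
```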
@@ -46,6 +50,15 @@ def __init__(self,

         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

+        if actorder and self.group_size == -1:
+            # In this case, actorder == True is the same as actorder == False
+            # (since we have only one group per output channel)
+            logger.warning(
+                "Model must be quantized with group_size > 0 in order to use "
+                "activation ordering")
+            actorder = False
+        self.actorder = actorder
+
         # Verify supported on platform.
         verify_marlin_supported(quant_type=self.quant_type,
                                 group_size=self.group_size)
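For intuition on why this fallback is harmless, a small standalone sketch (not part of the PR): with channelwise quantization a single group spans all input channels of a column, so permuting the reduction dimension cannot change which scale applies, whereas with grouped quantization it can:

```python
import torch

k, group_size = 8, 4
perm = torch.randperm(k)  # activation-order permutation of input channels

# Grouped case: channel i belongs to group i // group_size, so a
# permutation changes which group each position reads its scale from.
grouped_ids = torch.arange(k) // group_size
print(grouped_ids.tolist(), grouped_ids[perm].tolist())  # generally differ

# Channelwise case (group_size == -1): every channel maps to group 0,
# so any permutation leaves the scale assignment unchanged.
channelwise_ids = torch.zeros(k, dtype=torch.long)
print(channelwise_ids.tolist(), channelwise_ids[perm].tolist())  # identical
```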
@@ -62,6 +75,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        **kwargs):

         output_size_per_partition = sum(output_partition_sizes)
+        is_row_parallel = input_size != input_size_per_partition

         # If group_size is -1, we are in channelwise case.
         channelwise = (self.group_size == -1)
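The is_row_parallel flag added here later feeds marlin_is_k_full. A hedged sketch of the rule that call is assumed to implement (mirroring vLLM's GPTQ-Marlin handling, not copied from this PR): the kernel sees the full reduction dimension unless act-order is combined with a row-parallel shard of k:

```python
def is_k_full_sketch(act_order: bool, is_row_parallel: bool) -> bool:
    # With act-order, a row-parallel rank only holds a slice of the
    # permuted k dimension, so k is not "full" from the kernel's view.
    return (not act_order) or (not is_row_parallel)
```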
@@ -119,14 +133,21 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                                           dtype=torch.int64),
                                          weight_loader=weight_loader)

+        # G_IDX (for activation reordering)
+        g_idx = BasevLLMParameter(data=torch.empty(input_size_per_partition,
+                                                   dtype=torch.int32),
+                                  weight_loader=weight_loader)
+
Inline review thread on the new g_idx parameter (a sketch of the suggested change appears after this hunk):

- Is it okay to make this parameter in every case? What about older checkpoints that don't have this parameter?
- gptq_marlin_gemm supports passing an empty tensor for g_idx; I'd prefer that or a nullptr to avoid excess memory usage.
- I think my question was worded weirdly, sorry. I am just concerned about the weight loader trying to find this parameter in the checkpoint and it not being present.
- I regression tested using …
- I'd just update to only create the parameter if …
- This is because the g_idx passed to the kernel is conditional on the …
- If self.actorder is True, it'll use the created parameter. Otherwise, it'll create an empty one. So I don't think you need to initialize it here if …
- Yeah, that seems to be the case from this else-case later, so no need to make the parameter:
      layer.weight_g_idx = marlin_make_empty_g_idx(device)
      layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
         layer.register_parameter("weight_packed", weight)
         layer.register_parameter("weight_scale", weight_scale)
         layer.register_parameter("weight_shape", weight_shape)
+        layer.register_parameter("weight_g_idx", g_idx)

         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
         layer.group_size = group_size
+        layer.is_k_full = marlin_is_k_full(self.actorder, is_row_parallel)

         # Checkpoints are serialized in compressed-tensors format, which is
         # different from marlin format. Handle repacking here.
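A sketch of the change the review thread above is suggesting, under the assumption that the empty-tensor else branch below covers the non-act-order path: only create and register weight_g_idx when actorder is set, so older checkpoints without a g_idx entry are never asked for one:

```python
# Sketch only (reviewer suggestion), not the code in this commit.
if self.actorder:
    g_idx = BasevLLMParameter(data=torch.empty(input_size_per_partition,
                                               dtype=torch.int32),
                              weight_loader=weight_loader)
    layer.register_parameter("weight_g_idx", g_idx)
# When actorder is False, process_weights_after_loading assigns
# layer.weight_g_idx = marlin_make_empty_g_idx(device), so no checkpoint
# lookup for g_idx ever happens.
```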
@@ -137,9 +158,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)

-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+        # Handle sorting for activation reordering if needed.
+        if self.actorder:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "weight_g_idx", g_idx)
+        else:
+            layer.weight_g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

         # No zero-point
         layer.weight_zp = marlin_make_empty_g_idx(device)
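For reference, a minimal sketch of what the sort step is assumed to produce (inferred from the call site, not from the helper's source): g_idx reordered into ascending group order plus the permutation that achieves it, which the kernel uses to reorder the k dimension:

```python
import torch

def sort_g_idx_sketch(g_idx: torch.Tensor):
    # Assumed behavior of marlin_sort_g_idx: argsort the per-channel group
    # indices and return both the sorted g_idx and the sort permutation.
    sort_indices = torch.argsort(g_idx).to(torch.int32)
    return g_idx[sort_indices], sort_indices
```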
@@ -161,7 +187,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Permute scales from compressed-tensors format to marlin format.
         marlin_scales = marlin_permute_scales(
             layer.weight_scale,
-            size_k=layer.input_size_per_partition,
+            size_k=(layer.input_size
+                    if self.actorder else layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
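The size_k switch follows from how act-order scales are stored (hedged reasoning, analogous to the GPTQ-Marlin path): g_idx indexes groups of the full, unsharded k dimension, so each partition loads scales for every group rather than only its local slice:

```python
def expected_scale_rows(input_size: int, input_size_per_partition: int,
                        group_size: int, actorder: bool) -> int:
    # Hypothetical helper for illustration: number of scale rows a
    # partition holds before permutation into marlin format.
    size_k = input_size if actorder else input_size_per_partition
    return size_k // group_size
```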
@@ -174,7 +201,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
             weight=layer.weight_packed,
             weight_scale=layer.weight_scale,
             weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            g_idx=layer.weight_g_idx,
             g_idx_sort_indices=layer.g_idx_sort_indices,
             workspace=layer.workspace,
             wtype=self.quant_type,
Review comment: Can you add a test case for this case?
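A possible shape for such a test, hedged: the model name below is a placeholder for an act-order checkpoint quantized via compressed-tensors with group_size > 0; it is not an existing artifact referenced by this PR:

```python
from vllm import LLM, SamplingParams

def test_wna16_actorder_generates():
    # Placeholder checkpoint: 4-bit, grouped, quantized with actorder=True.
    llm = LLM(model="<org>/<wna16-actorder-model>", max_model_len=1024)
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=8, temperature=0.0))
    assert outputs[0].outputs[0].text.strip()
```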