Efficient Inference Kernel for SpQR (huggingface#34976)
* Resolve vptq conflict
* Rename spqr package to spqr_quant
* Get rid of aqlm mention
* Start working on tests
* Resolve ruff code checks
* Ruff format
* Isort
* Test updates
* Add gpu tag
* Rename to modules_to_not_convert
* Config update
* Docs and config update
* Docs and config update
* Update to update_torch_dtype
* spqr config parameter validation
* Ruff update
* Apply ruff fixes
* Test fixes
* Ruff update
* Mark tests as @slow again; Ruff; Docstring update
* Ruff
* Remove absolute path
* Resolve typo
* Remove redundant log
* Check accelerate/spqr availability
* Ruff fix
* Check if the config contains proper shapes
* Ruff test
* Documentation update
* overview update
* Ruff checks
* Ruff code quality
* Make style
* Update docs/source/en/quantization/spqr.md
Co-authored-by: Steven Liu <[email protected]>
* Update spqr.md

* Enable gptqmodel (huggingface#35012)
* gptqmodel
Signed-off-by: jiqing-feng <[email protected]>
* fix format
Signed-off-by: jiqing-feng <[email protected]>
* update readme
Signed-off-by: jiqing-feng <[email protected]>
* gptqmodel need use checkpoint_format (#1)
* gptqmodel need use checkpoint_format
* fix quantize
* Update quantization_config.py
* Update quantization_config.py
* Update quantization_config.py
---------
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
* Revert quantizer_gptq.py (#2)
* revert quantizer_gptq.py change
* pass **kwargs
* limit gptqmodel and optimum version
Signed-off-by: jiqing-feng <[email protected]>
* fix format
Signed-off-by: jiqing-feng <[email protected]>
* fix warning
Signed-off-by: jiqing-feng <[email protected]>
* fix version check
Signed-off-by: jiqing-feng <[email protected]>
* revert unrelated changes
Signed-off-by: jiqing-feng <[email protected]>
* enable gptqmodel tests
Signed-off-by: jiqing-feng <[email protected]>
* fix requires gptq
Signed-off-by: jiqing-feng <[email protected]>
* Fix Transformer compat (#3)
* revert quantizer_gptq.py change
* pass **kwargs
* add meta info
* cleanup
* cleanup
* Update quantization_config.py
* hf_select_quant_linear pass checkpoint_format and meta
* fix GPTQTestCUDA
* Update test_gptq.py
* gptqmodel.hf_select_quant_linear() now does not select ExllamaV2
* cleanup
* add backend
* cleanup
* cleanup
* no need check exllama version
* Update quantization_config.py
* lower checkpoint_format and backend
* check none
* cleanup
* Update quantization_config.py
* fix self.use_exllama == False
* spell
* fix unittest
* fix unittest
---------
Co-authored-by: LRL <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
* fix format
Signed-off-by: jiqing-feng <[email protected]>
* fix format again
Signed-off-by: jiqing-feng <[email protected]>
* update gptqmodel version (#6)
* update gptqmodel version
* update gptqmodel version
* fix unit test (#5)
* update gptqmodel version
* update gptqmodel version
* "not self.use_exllama" is not equivalent to "self.use_exllama==False"
* fix unittest
* update gptqmodel version
* backend is loading_attributes (#7)
* fix format and tests
Signed-off-by: jiqing-feng <[email protected]>
* fix memory check
Signed-off-by: jiqing-feng <[email protected]>
* fix device mismatch
Signed-off-by: jiqing-feng <[email protected]>
* fix result check
Signed-off-by: jiqing-feng <[email protected]>
* Update src/transformers/quantizers/quantizer_gptq.py
Co-authored-by: Marc Sun <[email protected]>
* Update src/transformers/quantizers/quantizer_gptq.py
Co-authored-by: Marc Sun <[email protected]>
* Update src/transformers/quantizers/quantizer_gptq.py
Co-authored-by: Marc Sun <[email protected]>
* update tests
Signed-off-by: jiqing-feng <[email protected]>
* review: update docs (#10)
* review: update docs (#12)
* review: update docs
* fix typo
* update tests for gptqmodel
Signed-off-by: jiqing-feng <[email protected]>
* update document (#9)
* update overview.md
* cleanup
* Update overview.md
* Update overview.md
* Update overview.md
* update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
* Update gptq.md
---------
Co-authored-by: Qubitium-ModelCloud <[email protected]>
* typo
* doc note for asymmetric quant
* typo with apple silicon(e)
* typo for marlin
* column name revert: review
* doc rocm support
* Update docs/source/en/quantization/gptq.md
Co-authored-by: Steven Liu <[email protected]>
* Update docs/source/en/quantization/gptq.md
Co-authored-by: Steven Liu <[email protected]>
* Update docs/source/en/quantization/gptq.md
Co-authored-by: Steven Liu <[email protected]>
* Update docs/source/en/quantization/gptq.md
Co-authored-by: Steven Liu <[email protected]>
* Update docs/source/en/quantization/overview.md
Co-authored-by: Steven Liu <[email protected]>
* Update docs/source/en/quantization/overview.md
Co-authored-by: Steven Liu <[email protected]>
---------
Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: LRL-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
Co-authored-by: Steven Liu <[email protected]>

* Fix: Nemotron Processor in GGUF conversion (huggingface#35708)
* fixing nemotron processor
* make style
* Update docs/source/en/quantization/spqr.md
Co-authored-by: Arthur <[email protected]>
* Add missing TOC to doc
---------
Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: jiqing-feng <[email protected]>
Co-authored-by: LRL-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent c5506f4 · commit 845b0a2
Showing 16 changed files with 591 additions and 0 deletions.
docs/source/en/quantization/spqr.md
@@ -0,0 +1,35 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# SpQR

The [SpQR](https://github.com/Vahe1994/SpQR) quantization algorithm uses a 16x16 tiled bi-level group 3-bit quantization structure, with sparse outliers, as detailed in [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression](https://arxiv.org/abs/2306.03078).

To SpQR-quantize a model, refer to the [Vahe1994/SpQR](https://github.com/Vahe1994/SpQR) repository.

Load a pre-SpQR-quantized model with [`~PreTrainedModel.from_pretrained`]:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

quantized_model = AutoModelForCausalLM.from_pretrained(
    "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf",
    torch_dtype=torch.half,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf")
```
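As a quick end-to-end check of the loaded checkpoint, here is a minimal generation sketch; the prompt and generation settings are illustrative and not part of the commit:

```python
# Assumes `quantized_model` and `tokenizer` from the snippet above.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(quantized_model.device)
outputs = quantized_model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```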
src/transformers/integrations/spqr.py
@@ -0,0 +1,122 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"SpQR (Sparse-Quantized Representation) integration file"

from ..utils import is_accelerate_available, is_spqr_available, is_torch_available


if is_torch_available():
    import torch.nn as nn


def replace_with_spqr_linear(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates whether
    the conversion was successful.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`SpQRConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list[str]`, *optional*):
            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`),
            the corresponding module will not be converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    if is_accelerate_available():
        from accelerate import init_empty_weights
    if is_spqr_available():
        from spqr_quant import QuantizedLinear

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear):
            # Check if the current key is not in the `modules_to_not_convert`
            if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
                with init_empty_weights():
                    tensor_name = ".".join(current_key_name)

                    shapes = quantization_config.shapes
                    shapes_keys = shapes.keys()

                    shapes_valid = (
                        f"{tensor_name}.dense_weights.shape" in shapes_keys
                        and f"{tensor_name}.row_offsets.shape" in shapes_keys
                        and f"{tensor_name}.col_vals.shape" in shapes_keys
                        and f"{tensor_name}.in_perm.shape" in shapes_keys
                    )

                    if not shapes_valid:
                        raise ValueError(
                            f"The SpQR quantization config does not contain the shape "
                            f"configuration for {tensor_name}. This indicates that the "
                            f"configuration is either invalid or corrupted."
                        )

                    dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
                    row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
                    col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
                    in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]

                    in_features = module.in_features
                    out_features = module.out_features

                    # NOTE: the method name spelling below matches the upstream `spqr_quant` API.
                    model._modules[name] = QuantizedLinear.create_placehodler(
                        rows=out_features,
                        cols=in_features,
                        bits=quantization_config.bits,
                        beta1=quantization_config.beta1,
                        beta2=quantization_config.beta2,
                        dense_weights_shape=dense_weights_shape,
                        row_offsets_shape=row_offsets_shape,
                        col_vals_shape=col_vals_shape,
                        in_perm_shape=in_perm_shape,
                    )
                    has_been_replaced = True

                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_spqr_linear(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
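For reviewers, the shapes contract validated above can be pictured as follows. This is a hypothetical sketch — the tensor name and shape values are invented for illustration, not taken from the commit:

```python
# For every converted Linear layer, `quantization_config.shapes` must hold
# four per-tensor entries keyed by the module's qualified name.
tensor_name = "model.layers.0.self_attn.q_proj"  # hypothetical module path
shapes = {
    f"{tensor_name}.dense_weights.shape": [4096, 768],  # invented values
    f"{tensor_name}.row_offsets.shape": [4097],
    f"{tensor_name}.col_vals.shape": [200000],
    f"{tensor_name}.in_perm.shape": [4096],
}

# Mirrors the `shapes_valid` check in replace_with_spqr_linear.
required = ("dense_weights", "row_offsets", "col_vals", "in_perm")
assert all(f"{tensor_name}.{field}.shape" in shapes for field in required)
```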
src/transformers/quantizers/quantizer_spqr.py
@@ -0,0 +1,83 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..integrations import replace_with_spqr_linear
from ..utils import is_accelerate_available, is_spqr_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class SpQRHfQuantizer(HfQuantizer):
    """
    Quantizer of the SpQR method. Enables the loading of prequantized models.
    """

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("GPU is required to run SpQR quantized model.")

        if not is_accelerate_available():
            raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")

        if not is_spqr_available():
            raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            torch_dtype = torch.float16
            logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
        elif torch_dtype != torch.float16:
            raise ValueError(
                "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it "
                "to torch.float16 explicitly."
            )
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        replace_with_spqr_linear(
            model,
            quantization_config=self.quantization_config,
            modules_to_not_convert=self.quantization_config.modules_to_not_convert,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        return True
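A practical consequence of `update_torch_dtype` above: any explicit `torch_dtype` other than `torch.float16` is rejected at load time. A hedged sketch of that behavior, reusing the checkpoint name from the docs — the exact error text lives in the quantizer:

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"

# Leaving torch_dtype unset lets the quantizer default to torch.float16.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Any other explicit dtype is rejected by update_torch_dtype.
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
except ValueError as e:
    print(e)  # SpQR kernels only run in float16
```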