Run model as compressed/uncompressed mode #34719

Merged
Changes from 23 commits (37 commits in total)
caa9d6b
draft, run model as compressed/uncompressed mode
horheynm Nov 13, 2024
86a649d
draft
horheynm Nov 18, 2024
b28d1d2
run run_compressed=False
horheynm Nov 18, 2024
39afd39
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Nov 18, 2024
bbe0b42
run_compressed as attr
horheynm Nov 19, 2024
99d2d8a
Merge branch 'compressed-tensors/run_compressed' of github.com:neural…
horheynm Nov 19, 2024
5bd706b
set run_compressed=False using quantization_config
horheynm Nov 22, 2024
70aaee0
remove redundant line
horheynm Nov 22, 2024
32e693b
make is_qat_trainable dependent on run_compressed status
horheynm Nov 22, 2024
4f06a78
add tests
horheynm Nov 25, 2024
edc6417
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Nov 25, 2024
668421b
lint
horheynm Nov 25, 2024
d5a8940
Merge branch 'compressed-tensors/run_compressed' of github.com:neural…
horheynm Nov 25, 2024
d44e1c1
full in docstring
horheynm Nov 25, 2024
42cf70d
add decompress
horheynm Nov 26, 2024
068944c
comments
horheynm Dec 2, 2024
1cee2a2
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 2, 2024
0e6e339
decompress if model is compressed and not run_compressed
horheynm Dec 2, 2024
131225b
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 3, 2024
18371bc
apply_quant_config logic fix -- populate statedict properly
horheynm Dec 3, 2024
2370ea6
comments
horheynm Dec 4, 2024
dac41d2
remove non compressed model
horheynm Dec 5, 2024
01e9ca7
make is_compressed as property
horheynm Dec 5, 2024
3599a27
cosmetic
horheynm Dec 6, 2024
4450e2d
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 6, 2024
331832e
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 9, 2024
4391525
run apply_quant_config for non-compressed models -- populate scales a…
horheynm Dec 9, 2024
d267da1
Merge branch 'compressed-tensors/run_compressed' of github.com:neural…
horheynm Dec 9, 2024
941af7e
add pathway for decompressing sparse models
horheynm Dec 10, 2024
d3c418e
typo on is_quantization_compressed
horheynm Dec 10, 2024
3419e4c
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 11, 2024
3ca6ade
lint
horheynm Dec 11, 2024
2e7ef0a
Merge branch 'compressed-tensors/run_compressed' of github.com:neural…
horheynm Dec 11, 2024
d1d28e7
fix typo
horheynm Dec 11, 2024
6759933
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 11, 2024
9d2f2ec
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 12, 2024
c44c513
Merge branch 'main' into compressed-tensors/run_compressed
horheynm Dec 12, 2024
9 changes: 7 additions & 2 deletions src/transformers/modeling_utils.py
@@ -3622,7 +3622,12 @@ def from_pretrained(
)
else:
config.quantization_config = quantization_config
hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)

hf_quantizer = AutoHfQuantizer.from_config(
config.quantization_config,
pre_quantized=pre_quantized,
)

else:
hf_quantizer = None

@@ -4303,7 +4308,7 @@ def from_pretrained(
dispatch_model(model, **device_map_kwargs)

if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
hf_quantizer.postprocess_model(model, config=config)
model.hf_quantizer = hf_quantizer

if _adapter_model_path is not None:
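The second hunk is the hook the rest of the PR relies on: `postprocess_model` now receives the model `config`, and the base `HfQuantizer.postprocess_model` forwards its keyword arguments to `_process_model_after_weight_loading`, where the compressed-tensors quantizer (further below) reads the checkpoint location from it. A minimal sketch of that call chain, assuming the base class simply forwards `**kwargs` (the class below is illustrative, not the real `HfQuantizer`):

# Illustrative sketch only -- mirrors the call chain enabled by the one-line change above.
class SketchQuantizer:
    def postprocess_model(self, model, **kwargs):
        # from_pretrained() now calls: hf_quantizer.postprocess_model(model, config=config)
        return self._process_model_after_weight_loading(model, **kwargs)

    def _process_model_after_weight_loading(self, model, **kwargs):
        config = kwargs.get("config", None)
        if config is not None:
            # the compressed-tensors quantizer uses this to locate the compressed weights
            checkpoint_location = config._name_or_path
        return model
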
3 changes: 2 additions & 1 deletion src/transformers/quantizers/auto.py
@@ -173,13 +173,14 @@ def merge_quantization_configs(
quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

if (
isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config, CompressedTensorsConfig))
and quantization_config_from_args is not None
):
# special case for GPTQ / AWQ / FbgemmFp8 config collision
loading_attr_dict = quantization_config_from_args.get_loading_attributes()
for attr, val in loading_attr_dict.items():
setattr(quantization_config, attr, val)

warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."

if warning_msg != "":
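Adding `CompressedTensorsConfig` to the isinstance check above means that when a checkpoint already carries a compressed-tensors config and the caller also passes one to `from_pretrained`, only the loading-time attributes of the caller's config are copied onto the checkpoint config. A small sketch of that behaviour, assuming `compressed-tensors` is installed (the configs are built directly here instead of being loaded from a checkpoint):

from transformers.utils.quantization_config import CompressedTensorsConfig

# config as it would come from the checkpoint: run_compressed defaults to True
config_from_model = CompressedTensorsConfig()
# config passed by the caller, only to flip a loading-time knob
config_from_args = CompressedTensorsConfig(run_compressed=False)

# mirrors the loop above: get_loading_attributes() (added later in this PR) returns
# {"run_compressed": ...}, and only those attributes are overwritten
for attr, val in config_from_args.get_loading_attributes().items():
    setattr(config_from_model, attr, val)

assert config_from_model.run_compressed is False
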
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_awq.py
@@ -111,7 +111,7 @@ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwarg
" Please double check your model architecture, or submit an issue on github if you think this is a bug."
)

def _process_model_after_weight_loading(self, model):
def _process_model_after_weight_loading(self, model, **kwargs):
if self.quantization_config.do_fuse:
from ..integrations import fuse_awq_modules

58 changes: 48 additions & 10 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import os

from ..utils import is_compressed_tensors_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
from ..utils.quantization_config import CompressedTensorsConfig
from .base import HfQuantizer


@@ -32,12 +35,13 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
requires_calibration = True
required_packages = ["compressed_tensors"]

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
super().__init__(quantization_config, **kwargs)

from compressed_tensors.compressors import ModelCompressor

self.compressor = ModelCompressor.from_compression_config(quantization_config)
self.run_compressed = quantization_config.run_compressed
self.quantization_config = quantization_config

def validate_environment(self, *args, **kwargs):
if not is_compressed_tensors_available():
@@ -63,20 +67,54 @@ def _process_model_before_weight_loading(self, model, **kwargs):
from compressed_tensors.quantization import apply_quantization_config

ct_quantization_config = self.compressor.quantization_config
apply_quantization_config(model, ct_quantization_config, run_compressed=True)

def _process_model_after_weight_loading(self, model, **kwargs) -> None:
pass
if self.run_compressed and self.is_compressed:
apply_quantization_config(model, ct_quantization_config, run_compressed=True)

def _process_model_after_weight_loading(self, model, **kwargs):
"""Decompress loaded model if necessary - need for qat"""

if self.is_compressed and not self.run_compressed:
config = kwargs.get("config", None)
cache_path = config._name_or_path
if not os.path.exists(cache_path):
from huggingface_hub import hf_hub_download

from transformers import TRANSFORMERS_CACHE
from transformers.utils import http_user_agent

user_agent = http_user_agent()
config_file_path = hf_hub_download(
repo_id=cache_path,
filename="config.json",
cache_dir=TRANSFORMERS_CACHE,
force_download=False,
user_agent=user_agent,
)
cache_path = os.path.sep.join(config_file_path.split(os.path.sep)[:-1])

from compressed_tensors.quantization import QuantizationStatus

self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
self.compressor.decompress(model_path=cache_path, model=model)

@property
def is_trainable(self) -> bool:
"""Models quantized using compressed tensors can be finetuned"""
return True
def is_compressed(self):
from compressed_tensors.quantization import QuantizationStatus

return (
self.quantization_config.quantization_config is not None
and self.quantization_config.quantization_config.quantization_status == QuantizationStatus.COMPRESSED
)

@property
def is_trainable(self):
return True

def is_qat_trainable(self) -> bool:
"""Loaded Models can carry out quantization aware training"""
return True
# models need to be decompressed carry out qat
return not self.run_compressed or not self.is_compressed

def is_serializable(self, safe_serialization=None) -> bool:
"""Models quantized using compressed tensors can be saved to disk"""
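Taken together, the quantizer now distinguishes three situations based on the checkpoint's quantization status and the `run_compressed` flag. The standalone helper below is illustrative only (it is not part of the PR) and just restates that branching:

def resolve_loading_mode(is_compressed: bool, run_compressed: bool) -> str:
    """Illustrative helper mirroring the branches in CompressedTensorsHfQuantizer above."""
    if is_compressed and run_compressed:
        # _process_model_before_weight_loading applies the compressed-tensors config with
        # run_compressed=True, so Linear submodules are swapped for CompressedLinear and
        # execution stays in the compressed format
        return "execute-compressed"
    if is_compressed and not run_compressed:
        # _process_model_after_weight_loading sets the quantization status to FROZEN and
        # calls compressor.decompress(), yielding regular modules that can be finetuned
        return "decompress-after-load"
    # checkpoint is not compressed: nothing to swap or decompress
    return "load-as-is"

assert resolve_loading_mode(is_compressed=True, run_compressed=False) == "decompress-after-load"

This is also why `is_qat_trainable` returns `not run_compressed or not is_compressed`: quantization-aware training is possible in every mode except compressed execution.
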
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_quanto.py
@@ -208,7 +208,7 @@ def _process_model_before_weight_loading(
)
model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model):
def _process_model_after_weight_loading(self, model, **kwargs):
return model

@property
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_torchao.py
@@ -195,7 +195,7 @@ def create_quantized_param(
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())

def _process_model_after_weight_loading(self, model):
def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
return

15 changes: 13 additions & 2 deletions src/transformers/utils/quantization_config.py
@@ -1077,7 +1077,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*):
dictionary mapping group name to a quantization scheme definition
format (`str`, *optional*, defaults to `"dense"`):
format the model is represented as
format the model is represented as. Set `run_compressed` True to execute model as the
compressed format if not `dense`
quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen'
kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*):
@@ -1090,6 +1091,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
configuration for sparsity compression
quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
do not override, should be compressed-tensors
run_compressed (`bool`, *optional*, defaults to `True`): alter submodules (usually linear) in order to
emulate compressed model execution if True, otherwise use default submodule
"""

def __init__(
@@ -1102,14 +1105,17 @@ def __init__(
ignore: Optional[List[str]] = None,
sparsity_config: Dict[str, Any] = None,
quant_method: str = "compressed-tensors",
run_compressed: bool = True,
**kwargs,
):
from compressed_tensors import QuantizationConfig
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationConfig

self.quantization_config = None
self.sparsity_config = None

self.run_compressed = run_compressed

# parse from dict to load nested QuantizationScheme objects
if config_groups or kv_cache_scheme:
self.quantization_config = QuantizationConfig.parse_obj(
Expand All @@ -1121,6 +1127,7 @@ def __init__(
"kv_cache_scheme": kv_cache_scheme,
"global_compression_ratio": global_compression_ratio,
"ignore": ignore,
"run_compressed": run_compressed,
**kwargs,
}
)
@@ -1149,6 +1156,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):

Returns:
[`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.

"""

if "quantization_config" in config_dict:
@@ -1200,6 +1208,9 @@ def to_diff_dict(self) -> Dict[str, Any]:

return serializable_config_dict

def get_loading_attributes(self):
return {"run_compressed": self.run_compressed}


@dataclass
class FbgemmFp8Config(QuantizationConfigMixin):
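For reference, here is how the new `run_compressed` flag on `CompressedTensorsConfig` is exercised end to end in the tests that follow. The checkpoint stub is the one used there; any compressed-tensors checkpoint should behave the same way, and `compressed-tensors` must be installed:

from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

stub = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"

# default behaviour: run_compressed=True, the model executes in its compressed format
model_compressed = AutoModelForCausalLM.from_pretrained(stub)

# run_compressed=False: weights are decompressed after loading, so the model can be
# finetuned / used for quantization-aware training
model_decompressed = AutoModelForCausalLM.from_pretrained(
    stub,
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)
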
94 changes: 94 additions & 0 deletions tests/quantization/compressed_tensor/test_run_compressed_model.py
@@ -0,0 +1,94 @@
import gc
import unittest

from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_compressed_tensors, require_torch
from transformers.utils import is_torch_available


if is_torch_available():
import torch


@require_compressed_tensors
@require_torch
class CompressedTensorsTest(unittest.TestCase):
tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed-hf-quantizer"
tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"

prompt = "Paris is the capital of which country?"

stubs = [tinyllama_w4a16, tinyllama_w8a8]

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()

def test_default_run_compressed__True(self):
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules

for stub in self.stubs:
model = AutoModelForCausalLM.from_pretrained(
stub,
)
compressed_linear_counts = 0

for _, submodule in iter_named_leaf_modules(
model,
):
if isinstance(submodule, CompressedLinear):
compressed_linear_counts += 1

# some linear models are not compressed - ex. lm_head
assert compressed_linear_counts > 0

def test_default_run_compressed__False(self):
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules

from transformers.utils.quantization_config import CompressedTensorsConfig

quantization_config = CompressedTensorsConfig(run_compressed=False)

for stub in self.stubs:
model = AutoModelForCausalLM.from_pretrained(
stub,
quantization_config=quantization_config,
)
compressed_linear_counts = 0

for _, submodule in iter_named_leaf_modules(
model,
):
if isinstance(submodule, CompressedLinear):
compressed_linear_counts += 1

# No modules should be CompressedLinear
assert compressed_linear_counts == 0

def test_run_compressed_outputs_match(self):
"""Check that run_compressed=True/False output are the same"""

from transformers import AutoTokenizer
from transformers.utils.quantization_config import CompressedTensorsConfig

quantization_config = CompressedTensorsConfig(run_compressed=False)

for stub in self.stubs:
tokenizer = AutoTokenizer.from_pretrained(stub)
input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids

model_run_compressed__True = AutoModelForCausalLM.from_pretrained(
stub,
)
output_rc_true = model_run_compressed__True.generate(input_ids, max_new_tokens=100)

model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
stub,
quantization_config=quantization_config,
)
output_rc_false = model_run_compressed__False.generate(input_ids, max_new_tokens=100)

assert tokenizer.decode(output_rc_true[0]) == tokenizer.decode(output_rc_false[0])