Merge branch 'main' into quant-then-finetune
dsikka authored Jan 22, 2025
2 parents 6e7bfa6 + d5984db commit 6884b78
Showing 12 changed files with 671 additions and 310 deletions.
13 changes: 12 additions & 1 deletion examples/sparse_2of4_quantization_fp8/README.md
@@ -93,7 +93,7 @@ oneshot(
)
```

3. **Save the Compressed Model**
### Saving the Compressed Model

The compressed model and tokenizer are saved to the output directory:

@@ -106,6 +106,17 @@ Output Directories:
- Without FP8: `Meta-Llama-3-8B-Instruct-2of4-sparse`
- With FP8: `Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token`

#### Saving Without Sparse Compression

To save the model on disk without sparse compression:

```python
model.save_pretrained(save_dir, save_compressed=True, disable_sparse_compression=True)
tokenizer.save_pretrained(save_dir)
```

> **Note:** Saving a model with both `save_compressed` and `disable_sparse_compression` set still compresses it with the quantization compressor; however, the `dense` sparsity compressor is used instead of the more disk-efficient sparsity compressor(s). The `dense` sparsity compressor stores model parameters as-is and does not exploit sparsity for disk-efficient storage. These options only affect how the model is saved on disk; they do not change the actual pruning or quantization processes.
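
For comparison, here is a minimal sketch (an editorial illustration, not part of the original example) of the default save path, which leaves sparse compression enabled so a 2:4 sparse model is stored with the disk-efficient sparsity compressor:

```python
# Default save path: sparse compression stays enabled, so the 2:4 sparsity
# compressor (rather than the dense one) is used on disk alongside the
# quantization compressor.
model.save_pretrained(save_dir, save_compressed=True)
tokenizer.save_pretrained(save_dir)
```
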
### Validation

After compression, the script validates the model by generating a sample output:
28 changes: 24 additions & 4 deletions src/llmcompressor/transformers/compression/quantization_format.py
@@ -1,7 +1,7 @@
from typing import Optional

from compressed_tensors import CompressionFormat
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.config import SparsityStructure
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
from compressed_tensors.quantization.utils import (
is_model_quantized,
@@ -16,10 +16,30 @@ def infer_quantization_format(
model,
quantization_format: Optional[str] = None,
save_compressed: bool = False,
sparsity_config: Optional[SparsityCompressionConfig] = None,
sparsity_structure: Optional[str] = None,
) -> str:
"""
Infers a quantization format based on model state and compression args
Infers the quantization format for a model based on its state and provided
compression arguments.
The following table outlines the possible quantization and sparsity formats
along with their corresponding compressor formats:
+---------------+----------+----------------------+---------------------+
| Quantization | Sparsity | Quant Compressor | Sparsity Compressor |
| | | Format | Format |
+---------------+----------+----------------------+---------------------+
| W8A8 - int | None | int_quantized | Dense |
| W8A8 - float | None | float_quantized | Dense |
| W4A16 - int | None | pack_quantized | Dense |
| W8A16 - int | None | pack_quantized | Dense |
| W8A16 - float | None | naive_quantized | Dense |
| W8A8 - int | 2:4 | int_quantized | Sparse24 |
| W8A8 - float | 2:4 | float_quantized | Sparse24 |
| W4A16 - int | 2:4 | marlin_24 | Dense |
| W8A16 - int | 2:4 | marlin_24 | Dense |
| W8A16 - float | 2:4 | naive_quantized | Dense |
+---------------+----------+----------------------+---------------------+
:param model: model to check for quantization; if the model is not quantized no
quantization format is returned
@@ -37,7 +57,7 @@ def infer_quantization_format(
if save_compressed:
weight_args, input_args = _get_unique_quant_args(model)
is_24_structure = (
sparsity_config and sparsity_config.sparsity_structure == "2:4"
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
)
is_weight_only = len(input_args) == 0 and len(weight_args) > 0

102 changes: 94 additions & 8 deletions src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,7 +1,14 @@
from typing import Dict, Optional
from typing import Dict, List, Optional

from compressed_tensors import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization.utils import is_model_quantized
from compressed_tensors.config import SparsityStructure
from compressed_tensors.quantization import QuantizationType
from compressed_tensors.quantization.utils import (
is_model_quantized,
is_module_quantized,
iter_named_leaf_modules,
)
from loguru import logger
from torch import Tensor
from torch.nn import Module

@@ -20,7 +27,7 @@ class SparsityConfigMetadata:
metadata from the model
"""

SPARSITY_THRESHOLD: float = 0.4
SPARSITY_THRESHOLD: float = 0.49

@staticmethod
def infer_global_sparsity(
Expand Down Expand Up @@ -67,13 +74,15 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
if model and sparsity_structure is None:
sparsity_structure = infer_sparsity_structure_from_model(model)

return sparsity_structure or "unstructured"
return SparsityStructure(sparsity_structure).value

@staticmethod
def from_pretrained(
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
compress: bool = False,
quantization_format: Optional[CompressionFormat] = None,
disable_sparse_compression: bool = False,
) -> Optional["SparsityCompressionConfig"]:
"""
Determines compression type and informational parameters for a given model
@@ -82,6 +91,11 @@ def from_pretrained(
:param state_dict: optional state_dict to replace that in model, used for
gathering global FSDP model info
:param compress: whether or not to compress the model on disk
:param quantization_format: the quantization compression format being used
for the model
:param disable_sparse_compression: whether or not to compress the model with
sparse compressors. If True, the sparse compression format will
be dense; default is False.
:return: compression config inferred from the model
"""

@@ -95,11 +109,18 @@ def from_pretrained(
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
model=model
)
if is_model_quantized(model):
# compressing a sparse quantized model is not supported yet
if (
disable_sparse_compression
or quantization_format == CompressionFormat.marlin_24
):
# sparse compressor should be dense
# when disable_sparse_compression is True
# or when marlin_24 is used
format = CompressionFormat.dense.value
elif compress:
format = CompressionFormat.sparse_bitmask.value
elif compress and SparsityConfigMetadata.is_sparse24_bitmask_supported(
model, sparsity_structure
):
format = CompressionFormat.sparse_24_bitmask.value
else:
format = CompressionFormat.dense.value

@@ -135,3 +156,68 @@ def fill_config_details(
model, state_dict=state_dict
)
config.sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure()

@staticmethod
def is_sparse24_bitmask_supported(
model: Module,
sparsity_structure: Optional[str] = None,
) -> bool:
"""
Determines if the sparse 24 bitmask compressor is supported for a given model
and its sparsity structure in vLLM
:param model: pytorch model to check for sparse 24 bitmask support
:param sparsity_structure: sparsity structure of the model, if
not supplied it will be inferred
:return: whether or not sparse 24 bitmask compression is supported
in vLLM for the given model
"""

if sparsity_structure is None:
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)

if sparsity_structure != SparsityStructure.TWO_FOUR.value:
# only supported for 2:4 sparsity
return False

if not is_model_quantized(model):
# non-quantized 2:4 sparse models are supported
return True

# when model is quantized, and has 2:4 sparsity

supported_scheme_types: List[str] = [
QuantizationType.INT.value,
QuantizationType.FLOAT.value,
]

for _, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
weight_scheme = submodule.quantization_scheme.weights
input_scheme = submodule.quantization_scheme.input_activations

if weight_scheme and input_scheme:
# weight and activation quantization
# check schemes are supported
for scheme in [weight_scheme, input_scheme]:
scheme_supported = (
scheme.num_bits == 8
and scheme.type in supported_scheme_types
)
if not scheme_supported:
logger.info(
"Quantization scheme not supported,"
" turning off sparse 24 compression."
f" Invalid Scheme: {scheme}"
)
return False

elif weight_scheme or input_scheme:
# weight only quantization
logger.info(
"Weight only quantization detected, "
"turning off sparse 24 compression."
)
return False

return True
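
As a hedged usage sketch (editorial, based only on the signatures shown in this diff; `model` is assumed to be an already-loaded, possibly pruned and quantized module), the new `quantization_format` and `disable_sparse_compression` arguments are expected to flow into sparsity-config inference roughly like this:

```python
# Editorial sketch mirroring the new ordering in get_model_compressor below;
# not part of the committed code. `model` is assumed to already be loaded.
from llmcompressor.transformers.compression.quantization_format import (
    infer_quantization_format,
)
from llmcompressor.transformers.compression.sparsity_config import (
    SparsityConfigMetadata,
)

# Infer the quantization compressor format first ...
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)
quant_format = infer_quantization_format(
    model=model,
    save_compressed=True,
    sparsity_structure=sparsity_structure,
)

# ... then pass it, along with disable_sparse_compression, when inferring
# the sparsity compression config used at save time.
sparsity_config = SparsityConfigMetadata.from_pretrained(
    model,
    compress=True,
    quantization_format=quant_format,
    disable_sparse_compression=False,
)
```
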
@@ -8,6 +8,7 @@
import transformers
from accelerate.accelerator import get_state_dict_offloaded_model
from compressed_tensors import (
CompressionFormat,
ModelCompressor,
SparsityCompressionConfig,
is_module_offloaded,
@@ -124,6 +125,7 @@ def save_pretrained_wrapper(
quantization_format: Optional[str] = None,
save_compressed: bool = True,
skip_compression_stats: bool = False,
disable_sparse_compression: bool = False,
**kwargs,
):
"""
@@ -133,13 +135,15 @@ def save_pretrained_wrapper(
:param save_directory: output directory to save model to
:param sparsity_config: optional sparsity config to compress model with,
if no config is provided it will be inferred from the model
if no config is provided it will be inferred from the model
:param quantization_format: optional compression format for quantized
models. If none is provided it will be inferred from the model
models. If none is provided it will be inferred from the model
:param save_compressed: whether or not to compress the model on disk
:param skip_compression_stats: whether to skip the calculation of
compression statistics (such as global sparsity and sparsity structure) when
saving a model in dense format
compression statistics (such as global sparsity and sparsity structure)
when saving a model in dense format
:param disable_sparse_compression: whether to skip sparse compression
during save, default is False
:param kwargs: additional kwargs to pass on to model.save_pretrained
"""

@@ -169,6 +173,7 @@ def skip(*args, **kwargs):
save_compressed=save_compressed,
skip_compression_stats=skip_compression_stats,
state_dict=state_dict,
disable_sparse_compression=disable_sparse_compression,
)

if compressor is None:
@@ -260,6 +265,7 @@ def get_model_compressor(
save_compressed: bool = True,
skip_compression_stats: bool = False,
state_dict: Optional[Dict] = None,
disable_sparse_compression: bool = False,
):
"""
Obtain the compressor based on the config and the
@@ -273,19 +279,26 @@ def get_model_compressor(
format
:param skip_compression_stats: bool allowing compression stats on std out
:param state_dict: state_dict of the model
:param disable_sparse_compression: bool to skip sparse compression
"""

# find offloaded state dict if none is provided
if state_dict is None:
state_dict = get_state_dict_offloaded_model(model)

sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)
quantization_format: Optional[CompressionFormat] = infer_quantization_format(
model=model,
quantization_format=quantization_format,
save_compressed=save_compressed,
sparsity_structure=sparsity_structure,
)

if sparsity_config is not None:
sparsity_config.global_sparsity = SparsityConfigMetadata.infer_global_sparsity(
model, state_dict=state_dict
)
sparsity_config.sparsity_structure = (
SparsityConfigMetadata.infer_sparsity_structure()
)
sparsity_config.sparsity_structure = sparsity_structure
elif not skip_compression_stats:
# try to infer a sparsity config from the model if none is provided
logger.info(
Expand All @@ -295,15 +308,13 @@ def get_model_compressor(
"skip_compression_stats=True"
)
sparsity_config = SparsityConfigMetadata.from_pretrained(
model, state_dict=state_dict, compress=save_compressed
model,
state_dict=state_dict,
compress=save_compressed,
quantization_format=quantization_format,
disable_sparse_compression=disable_sparse_compression,
)

quantization_format = infer_quantization_format(
model=model,
quantization_format=quantization_format,
save_compressed=save_compressed,
sparsity_config=sparsity_config,
)
return ModelCompressor.from_pretrained_model(
model,
sparsity_config=sparsity_config,
2 changes: 0 additions & 2 deletions src/llmcompressor/transformers/tracing/__init__.py
@@ -1,7 +1,6 @@
from .llava import (
LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
)
from .mistral import MistralForCausalLM as TraceableMistralForCausalLM
from .mllama import (
MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
)
@@ -12,6 +11,5 @@
__all__ = [
"TraceableLlavaForConditionalGeneration",
"TraceableMllamaForConditionalGeneration",
"TraceableMistralForCausalLM",
"TraceableQwen2VLForConditionalGeneration",
]
24 changes: 1 addition & 23 deletions src/llmcompressor/transformers/tracing/llava.py
@@ -19,20 +19,14 @@
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoModel, AutoModelForCausalLM, LlavaForConditionalGeneration
from transformers import LlavaForConditionalGeneration
from transformers.models.llava.configuration_llava import LlavaConfig
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaMultiModalProjector,
LlavaPreTrainedModel,
logger,
)
from transformers.models.mistral.configuration_mistral import MistralConfig
from transformers.utils.fx import HFProxy

# TRACING: Reuse traceable subclass
from .mistral import MistralForCausalLM as TraceableMistralForCausalLM


# TRACING: The shape of image_features is known and documented by
# LlavaForConditionalGeneration.get_image_features
@@ -75,22 +69,6 @@ def maybe_install_metadata_inputs_embeds_masked(

# TRACING: override `__init__` and `forward`
class LlavaForConditionalGeneration(LlavaForConditionalGeneration):
def __init__(self, config: LlavaConfig):
super(LlavaPreTrainedModel, self).__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)

self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size

# TRACING: Must use TraceableMistralForCausalLM which wraps an untraceable function
if isinstance(config.text_config, MistralConfig):
self.language_model = TraceableMistralForCausalLM(config.text_config)
else:
self.language_model = AutoModelForCausalLM.from_config(config.text_config)

self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()

def forward(
self,
input_ids: torch.LongTensor = None,