From 4093df7dab1bb547753068efb2c3e43301a60c9a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 13 Feb 2025 16:52:17 -0500 Subject: [PATCH 1/2] rename Signed-off-by: Kyle Sayers --- examples/automodelforcausallm/README.md | 13 + .../run_automodelforcausallm.py | 11 + .../example_alternating_recipe.yaml | 16 +- .../2of4_w4a16_group-128_recipe.yaml | 3 +- .../2of4_w4a16_recipe.yaml | 3 +- examples/quantizing_moe/deepseek_moe_w4a16.py | 4 - .../quantizing_moe/deepseek_moe_w8a8_fp8.py | 4 - .../quantizing_moe/deepseek_moe_w8a8_int8.py | 4 - .../sparse_2of4_quantization_fp8/README.md | 13 +- .../llama3_8b_2of4.py | 7 +- examples/trl_mixin/ex_trl_constant.py | 2 +- examples/trl_mixin/ex_trl_distillation.py | 9 +- examples/trl_mixin/sft_trainer.py | 2 +- src/llmcompressor/modifiers/README.md | 2 +- src/llmcompressor/modifiers/obcq/base.py | 165 +------- .../modifiers/pruning/__init__.py | 1 + .../modifiers/pruning/sparsegpt/__init__.py | 3 + .../modifiers/pruning/sparsegpt/base.py | 163 ++++++++ .../{obcq => pruning/sparsegpt}/sgpt_mixin.py | 0 .../sparsegpt}/sgpt_sparsify.py | 0 .../modifiers/pruning/wanda/base.py | 2 +- .../modifiers/quantization/gptq/base.py | 2 +- .../pipelines/sequential/README.md | 2 +- .../transformers/finetune/README.md | 9 +- .../transformers/tracing/GUIDE.md | 2 +- .../2of4_w4a16_group-128_recipe.yaml | 3 +- .../recipes/WNA16_2of4/2of4_w4a16_recipe.yaml | 3 +- tests/e2e/vLLM/test_vllm.py | 86 +--- tests/examples/utils.py | 7 +- .../modifiers/calibration/test_cache.py | 2 +- tests/llmcompressor/modifiers/conf.py | 6 +- .../modifiers/pruning/sparsegpt/test_base.py | 2 +- .../gptq/utils/test_gptq_wrapper.py | 41 ++ .../modifiers/smoothquant/test_utils.py | 5 +- .../modifiers/utils/test_hooks.py | 93 ----- tests/llmcompressor/observers/test_min_max.py | 8 +- tests/llmcompressor/observers/test_mse.py | 4 +- tests/llmcompressor/pytorch/helpers.py | 2 - .../pruning/sparsegpt/test_pytorch.py | 50 ++- tests/llmcompressor/recipe/test_recipe.py | 2 +- .../compression/recipes/sparse_24.yaml | 2 +- .../compression/recipes/sparse_24_fp8.yaml | 2 +- .../run_compressed_configs/fp8_dynamic.yaml | 4 +- .../run_compressed_configs/w4a16.yaml | 4 +- .../run_compressed_configs/w8a16_dense.yaml | 4 + .../run_compressed_configs/w8a8.yaml | 4 +- .../compression/test_infer_quant_format.py | 6 +- .../compression/test_quantization.py | 5 +- .../compression/test_run_compressed.py | 140 ++----- .../transformers/finetune/data/conftest.py | 2 +- .../finetune/data/test_dataset_helpers.py | 6 +- .../finetune/data/test_dataset_loading.py | 34 +- .../finetune/data/test_registry.py | 8 +- .../finetune/test_alternate_recipe.yaml | 6 +- .../test_finetune_oneshot_with_modifier.py | 2 +- .../finetune/test_oneshot_then_finetune.py | 100 +---- .../finetune/test_session_mixin.py | 4 - .../transformers/gptq/test_oneshot.py | 2 +- .../obcq/recipes/additional_sparsity.yaml | 3 +- .../additional_sparsity_with_quant.yaml | 3 +- .../transformers/obcq/recipes/quant.yaml | 2 +- .../obcq/recipes/quant_and_sparse.yaml | 3 +- .../transformers/obcq/recipes/sparse.yaml | 3 +- .../recipes/sparse_with_mask_structure.yaml | 3 +- .../transformers/obcq/recipes/test_tiny2.yaml | 3 +- .../obcq/test_consecutive_runs.py | 44 +- .../transformers/obcq/test_obcq_completion.py | 6 +- .../obcq/test_obcq_infer_targets.py | 28 +- .../transformers/obcq/test_obcq_lm_head.py | 39 +- .../transformers/obcq/test_obcq_owl.py | 2 +- .../transformers/obcq/test_sgpt_defaults.py | 23 ++ .../oneshot_configs/recipes/recipe.yaml | 3 +- 
.../oneshot_configs/tiny_stories_conf1.yaml | 3 +- .../oneshot_configs/tiny_stories_conf4.yaml | 3 +- .../transformers/oneshot/test_cli.py | 1 - .../test_compress_tensor_utils.py | 383 +----------------- tests/testing_utils.py | 69 +--- 77 files changed, 544 insertions(+), 1171 deletions(-) create mode 100644 examples/automodelforcausallm/README.md create mode 100644 examples/automodelforcausallm/run_automodelforcausallm.py create mode 100644 src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py create mode 100644 src/llmcompressor/modifiers/pruning/sparsegpt/base.py rename src/llmcompressor/modifiers/{obcq => pruning/sparsegpt}/sgpt_mixin.py (100%) rename src/llmcompressor/modifiers/{obcq => pruning/sparsegpt}/sgpt_sparsify.py (100%) create mode 100644 tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py create mode 100644 tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml create mode 100644 tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py diff --git a/examples/automodelforcausallm/README.md b/examples/automodelforcausallm/README.md new file mode 100644 index 000000000..e40cb5c2a --- /dev/null +++ b/examples/automodelforcausallm/README.md @@ -0,0 +1,13 @@ +# Loading models using `AutoModelForCausalLM` + +Models quantized through `llm-compressor` can be loaded directly through +`AutoModelForCausalLM`. Note: this requires `transformers>=v4.45.0` and +`compressed-tensors>v0.6.0`. + +```python +from transformers import AutoModelForCausalLM + +MODEL_ID = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer" + +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto") +``` diff --git a/examples/automodelforcausallm/run_automodelforcausallm.py b/examples/automodelforcausallm/run_automodelforcausallm.py new file mode 100644 index 000000000..791b4d3d5 --- /dev/null +++ b/examples/automodelforcausallm/run_automodelforcausallm.py @@ -0,0 +1,11 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +MODEL_ID = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer" + +# Use the AutoModelForCausalLM to run the model +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) diff --git a/examples/finetuning/example_alternating_recipe.yaml b/examples/finetuning/example_alternating_recipe.yaml index 134b50866..a3be682a4 100644 --- a/examples/finetuning/example_alternating_recipe.yaml +++ b/examples/finetuning/example_alternating_recipe.yaml @@ -1,13 +1,15 @@ initial_sparsity_stage: run_type: oneshot - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False percdamp: 0.01 mask_structure: "0:0" - targets: ["Linear"] - ignore: ["re:.*lm_head"] + targets: [ + "re:model.layers.\\d+$" + ] initial_training_stage: run_type: train pruning_modifiers: @@ -16,14 +18,16 @@ initial_training_stage: start: 0 next_sparsity_stage: run_type: oneshot - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.7 block_size: 128 + sequential_update: False percdamp: 0.01 mask_structure: "0:0" - targets: ["Linear"] - ignore: ["re:.*lm_head"] + targets: [ + "re:model.layers.\\d+$" + ] next_training_stage: run_type: train pruning_modifiers: diff --git a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml 
b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml index e59cf8a96..166e41a66 100644 --- a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml +++ b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml @@ -4,8 +4,7 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - targets: ["Linear"] - ignore: ["re:.*lm_head"] + sequential_update: false finetuning_stage: run_type: train finetuning_modifiers: diff --git a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml index 4ff5ff26e..2ad00b457 100644 --- a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml +++ b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml @@ -4,8 +4,7 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - targets: ["Linear"] - ignore: ["re:.*lm_head"] + sequential_update: false finetuning_stage: run_type: train finetuning_modifiers: diff --git a/examples/quantizing_moe/deepseek_moe_w4a16.py b/examples/quantizing_moe/deepseek_moe_w4a16.py index 55a7021b4..3d7d33099 100644 --- a/examples/quantizing_moe/deepseek_moe_w4a16.py +++ b/examples/quantizing_moe/deepseek_moe_w4a16.py @@ -5,10 +5,6 @@ from llmcompressor.transformers import oneshot from llmcompressor.transformers.compression.helpers import calculate_offload_device_map -# NOTE: transformers 4.48.0 has an import error with DeepSeek. -# Please consider either downgrading your transformers version to a -# previous version or upgrading to a version where this bug is fixed - # select a Mixture of Experts model for quantization MODEL_ID = "deepseek-ai/DeepSeek-V2.5" diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py index cda202eb9..666da8f9a 100644 --- a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py +++ b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py @@ -4,10 +4,6 @@ from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.transformers import oneshot -# NOTE: transformers 4.48.0 has an import error with DeepSeek. -# Please consider either downgrading your transformers version to a -# previous version or upgrading to a version where this bug is fixed - # select a Mixture of Experts model for quantization MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py index 289f4234f..ba215aa9e 100644 --- a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py +++ b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py @@ -6,10 +6,6 @@ from llmcompressor.transformers import oneshot from llmcompressor.transformers.compression.helpers import calculate_offload_device_map -# NOTE: transformers 4.48.0 has an import error with DeepSeek. -# Please consider either downgrading your transformers version to a -# previous version or upgrading to a version where this bug is fixed - # select a Mixture of Experts model for quantization MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" diff --git a/examples/sparse_2of4_quantization_fp8/README.md b/examples/sparse_2of4_quantization_fp8/README.md index 99fc3c545..97b8e590e 100644 --- a/examples/sparse_2of4_quantization_fp8/README.md +++ b/examples/sparse_2of4_quantization_fp8/README.md @@ -93,7 +93,7 @@ oneshot( ) ``` -### Saving the Compressed Model +3. 
**Save the Compressed Model** The compressed model and tokenizer are saved to the output directory: @@ -106,17 +106,6 @@ Output Directories: - Without FP8: `Meta-Llama-3-8B-Instruct-2of4-sparse` - With FP8: `Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token` -#### Saving Without Sparse Compression - -To save the model on disk without sparse compression: - -```python -model.save_pretrained(save_dir, save_compressed=True, disable_sparse_compression=True) -tokenizer.save_pretrained(save_dir) -``` - -> **Note:** Saving a model with both the `save_compressed` and `disable_sparse_compression` options will compress the model using the quantization compressor; however, instead of using the more disk-efficient sparsity compressor(s), the dense sparsity compressor will be used. The `dense` sparsity compressor saves model params as is, and does not leverage sparsity for disk-efficient storage. These options only affect how the model(s) are saved on disk and do not impact the actual pruning or quantization processes. - ### Validation After compression, the script validates the model by generating a sample output: diff --git a/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py b/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py index 21cec66c8..e8133225f 100644 --- a/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py +++ b/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py @@ -3,8 +3,7 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer -from llmcompressor.modifiers.obcq import SparseGPTModifier -from llmcompressor.modifiers.pruning import ConstantPruningModifier +from llmcompressor.modifiers.pruning import ConstantPruningModifier, SparseGPTModifier from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.transformers import oneshot @@ -116,7 +115,5 @@ def get_recipe(fp8_enabled): print("==========================================\n") # Save compressed model and tokenizer -model.save_pretrained( - save_dir, save_compressed=args.fp8, disable_sparse_compression=True -) +model.save_pretrained(save_dir, save_compressed=args.fp8) tokenizer.save_pretrained(save_dir) diff --git a/examples/trl_mixin/ex_trl_constant.py b/examples/trl_mixin/ex_trl_constant.py index 517d74d71..b2f597ec8 100644 --- a/examples/trl_mixin/ex_trl_constant.py +++ b/examples/trl_mixin/ex_trl_constant.py @@ -3,7 +3,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from trl import DataCollatorForCompletionOnlyLM -from llmcompressor.args import TrainingArguments +from llmcompressor.transformers import TrainingArguments model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data" diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py index 96cc78846..ff3ddf000 100644 --- a/examples/trl_mixin/ex_trl_distillation.py +++ b/examples/trl_mixin/ex_trl_distillation.py @@ -1,8 +1,11 @@ from sft_trainer import SFTTrainer from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator -from llmcompressor.args import DatasetArguments, TrainingArguments -from llmcompressor.transformers import TextGenerationDataset +from llmcompressor.transformers import ( + DataTrainingArguments, + TextGenerationDataset, + TrainingArguments, +) model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" teacher_path = "neuralmagic/Llama-2-7b-gsm8k" @@ -18,7 +21,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_path) # Load gsm8k using 
SparseML dataset tools -data_args = DatasetArguments( +data_args = DataTrainingArguments( dataset="gsm8k", dataset_config_name="main", max_seq_length=512 ) dataset_manager = TextGenerationDataset.load_from_registry( diff --git a/examples/trl_mixin/sft_trainer.py b/examples/trl_mixin/sft_trainer.py index 2577c0cc7..c311cf8dc 100644 --- a/examples/trl_mixin/sft_trainer.py +++ b/examples/trl_mixin/sft_trainer.py @@ -1,7 +1,7 @@ from trl import SFTConfig as TRLSFTConfig from trl import SFTTrainer as TRLSFTTrainer -from llmcompressor.args import TrainingArguments +from llmcompressor.transformers import TrainingArguments from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn __all__ = ["SFTTrainer"] diff --git a/src/llmcompressor/modifiers/README.md b/src/llmcompressor/modifiers/README.md index 77a4cd425..72ff0b058 100644 --- a/src/llmcompressor/modifiers/README.md +++ b/src/llmcompressor/modifiers/README.md @@ -8,7 +8,7 @@ are relevant only during training. Below is a summary of the key modifiers avail Modifiers that introduce sparsity into a model -### [SparseGPT](./obcq/base.py) +### [SparseGPT](./pruning/sparsegpt/base.py) One-shot algorithm that uses calibration data to introduce unstructured or structured sparsity into weights. Implementation based on [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774). A small amount of calibration data is used to calculate a Hessian for each layers input activations, this Hessian is then used to diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py index bcbd610fe..cbd4c0e09 100644 --- a/src/llmcompressor/modifiers/obcq/base.py +++ b/src/llmcompressor/modifiers/obcq/base.py @@ -1,163 +1,12 @@ -import contextlib -from typing import Dict, Optional, Tuple +import warnings -import torch -from compressed_tensors.utils import ( - align_module_device, - get_execution_device, - update_offload_parameter, -) -from loguru import logger -from pydantic import PrivateAttr +from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier -from llmcompressor.core import State -from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin -from llmcompressor.modifiers.obcq.sgpt_sparsify import ( - accumulate_hessian, - make_empty_hessian, - sparsify_weight, +warnings.warn( + "llmcompressor.modifiers.obcq has been moved to " + "llmcompressor.modifiers.pruning.sparsegpt Please update your paths", + DeprecationWarning, ) -from llmcompressor.utils.metric_logging import CompressionLogger - -__all__ = ["SparseGPTModifier"] - - -class SparseGPTModifier(SparsityModifierMixin, Modifier): - """ - Modifier for applying the one-shot SparseGPT algorithm to a model - - | Sample yaml: - | test_stage: - | obcq_modifiers: - | SparseGPTModifier: - | sparsity: 0.5 - | mask_structure: "2:4" - | dampening_frac: 0.001 - | block_size: 128 - | targets: ['Linear'] - | ignore: ['re:.*lm_head'] - - Lifecycle: - - on_initialize - - register_hook(module, calibrate_module, "forward") - - run_sequential / run_layer_sequential / run_basic - - make_empty_hessian - - accumulate_hessian - - on_sequential_batch_end - - sparsify_weight - - on_finalize - - remove_hooks() - - :param sparsity: Sparsity to compress model to - :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed - Layerwise Sparsity (OWL), more information can be found - in the paper https://arxiv.org/pdf/2310.05175 - :param mask_structure: 
String to define the structure of the mask to apply. - Must be of the form N:M where N, M are integers that define a custom block - shape. Defaults to 0:0 which represents an unstructured mask. - :param owl_m: Number of outliers to use for OWL - :param owl_lmbda: Lambda value to use for OWL - :param block_size: Used to determine number of columns to compress in one pass - :param dampening_frac: Amount of dampening to apply to H, as a fraction of the - diagonal norm - :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask - during when applying sparsegpt, this becomes useful when starting from a - previously pruned model, defaults to False. - :param offload_hessians: Set to True for decreased memory usage but increased - runtime. - :param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__' - to compress every layer in the model. Alias for `targets` - :param targets: list of layer names to compress during OBCQ, or '__ALL__' - to compress every layer in the model. Alias for `sequential_targets` - :param ignore: optional list of module class names or submodule names to not - quantize even if they match a target. Defaults to empty list. - """ - - # modifier arguments - block_size: int = 128 - dampening_frac: Optional[float] = 0.01 - preserve_sparsity_mask: bool = False - offload_hessians: bool = False - - # private variables - _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) - _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) - - def calibrate_module( - self, - module: torch.nn.Module, - args: Tuple[torch.Tensor, ...], - _output: torch.Tensor, - ): - # Assume that the first argument is the input - inp = args[0] - # Initialize hessian if not present - if module not in self._num_samples: - device = get_execution_device(module) - self._hessians[module] = make_empty_hessian(module, device=device) - self._num_samples[module] = 0 - # Accumulate hessian with input with optional offloading - with self._maybe_onload_hessian(module): - self._hessians[module], self._num_samples[module] = accumulate_hessian( - inp, - module, - self._hessians[module], - self._num_samples[module], - ) - - def on_sequential_batch_end(self): - """ - Sparsify modules - TODO: implement with event callback - """ - for module in list(self._num_samples.keys()): - name = self._module_names[module] - sparsity = self._module_sparsities[module] - num_samples = self._num_samples[module] - - logger.info(f"Sparsifying {name} using {num_samples} samples") - with ( - torch.no_grad(), - align_module_device(module), - CompressionLogger(module) as comp_logger, - ): - loss, sparsified_weight = sparsify_weight( - module=module, - hessians_dict=self._hessians, - sparsity=sparsity, - prune_n=self._prune_n, - prune_m=self._prune_m, - block_size=self.block_size, - dampening_frac=self.dampening_frac, - preserve_sparsity_mask=self.preserve_sparsity_mask, - ) - comp_logger.set_loss(loss) - - update_offload_parameter(module, "weight", sparsified_weight) - - # self._hessians[module] already deleted by sparsify_weight - del self._num_samples[module] - - @contextlib.contextmanager - def _maybe_onload_hessian(self, module: torch.nn.Module): - if self.offload_hessians: - device = get_execution_device(module) - self._hessians[module] = self._hessians[module].to(device=device) - - yield - - if self.offload_hessians: - if module in self._hessians: # may have been deleted in context - self._hessians[module] = self._hessians[module].to(device="cpu") - - def 
on_finalize(self, state: State, **kwargs) -> bool: - self.remove_hooks() - self._hessians = dict() - self._num_samples = dict() - self._module_names = dict() - self._module_sparsities = dict() - - return True +__all__ = ["SparseGPTModifier"] diff --git a/src/llmcompressor/modifiers/pruning/__init__.py b/src/llmcompressor/modifiers/pruning/__init__.py index d54a770f1..664215219 100644 --- a/src/llmcompressor/modifiers/pruning/__init__.py +++ b/src/llmcompressor/modifiers/pruning/__init__.py @@ -2,4 +2,5 @@ from .constant import * from .magnitude import * +from .sparsegpt import * from .wanda import * diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py b/src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/base.py b/src/llmcompressor/modifiers/pruning/sparsegpt/base.py new file mode 100644 index 000000000..fb70cc442 --- /dev/null +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/base.py @@ -0,0 +1,163 @@ +import contextlib +from typing import Dict, Optional, Tuple + +import torch +from compressed_tensors.utils import ( + align_module_device, + get_execution_device, + update_offload_parameter, +) +from loguru import logger +from pydantic import PrivateAttr + +from llmcompressor.core import State +from llmcompressor.modifiers import Modifier +from llmcompressor.modifiers.pruning.sparsegpt.sgpt_mixin import SparsityModifierMixin +from llmcompressor.modifiers.pruning.sparsegpt.sgpt_sparsify import ( + accumulate_hessian, + make_empty_hessian, + sparsify_weight, +) +from llmcompressor.utils.metric_logging import CompressionLogger + +__all__ = ["SparseGPTModifier"] + + +class SparseGPTModifier(SparsityModifierMixin, Modifier): + """ + Modifier for applying the one-shot SparseGPT algorithm to a model + + | Sample yaml: + | test_stage: + | modifiers: + | SparseGPTModifier: + | sparsity: 0.5 + | mask_structure: "2:4" + | dampening_frac: 0.001 + | block_size: 128 + | targets: ['Linear'] + | ignore: ['re:.*lm_head'] + + Lifecycle: + - on_initialize + - register_hook(module, calibrate_module, "forward") + - run_sequential / run_layer_sequential / run_basic + - make_empty_hessian + - accumulate_hessian + - on_sequential_batch_end + - sparsify_weight + - on_finalize + - remove_hooks() + + :param sparsity: Sparsity to compress model to + :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed + Layerwise Sparsity (OWL), more information can be found + in the paper https://arxiv.org/pdf/2310.05175 + :param mask_structure: String to define the structure of the mask to apply. + Must be of the form N:M where N, M are integers that define a custom block + shape. Defaults to 0:0 which represents an unstructured mask. + :param owl_m: Number of outliers to use for OWL + :param owl_lmbda: Lambda value to use for OWL + :param block_size: Used to determine number of columns to compress in one pass + :param dampening_frac: Amount of dampening to apply to H, as a fraction of the + diagonal norm + :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask + during when applying sparsegpt, this becomes useful when starting from a + previously pruned model, defaults to False. + :param offload_hessians: Set to True for decreased memory usage but increased + runtime. 
+ :param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__' + to compress every layer in the model. Alias for `targets` + :param targets: list of layer names to compress during OBCQ, or '__ALL__' + to compress every layer in the model. Alias for `sequential_targets` + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target. Defaults to empty list. + """ + + # modifier arguments + block_size: int = 128 + dampening_frac: Optional[float] = 0.01 + preserve_sparsity_mask: bool = False + offload_hessians: bool = False + + # private variables + _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) + _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) + + def calibrate_module( + self, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + _output: torch.Tensor, + ): + # Assume that the first argument is the input + inp = args[0] + + # Initialize hessian if not present + if module not in self._num_samples: + device = get_execution_device(module) + self._hessians[module] = make_empty_hessian(module, device=device) + self._num_samples[module] = 0 + + # Accumulate hessian with input with optional offloading + with self._maybe_onload_hessian(module): + self._hessians[module], self._num_samples[module] = accumulate_hessian( + inp, + module, + self._hessians[module], + self._num_samples[module], + ) + + def on_sequential_batch_end(self): + """ + Sparsify modules + TODO: implement with event callback + """ + for module in list(self._num_samples.keys()): + name = self._module_names[module] + sparsity = self._module_sparsities[module] + num_samples = self._num_samples[module] + + logger.info(f"Sparsifying {name} using {num_samples} samples") + with ( + torch.no_grad(), + align_module_device(module), + CompressionLogger(module) as comp_logger, + ): + loss, sparsified_weight = sparsify_weight( + module=module, + hessians_dict=self._hessians, + sparsity=sparsity, + prune_n=self._prune_n, + prune_m=self._prune_m, + block_size=self.block_size, + dampening_frac=self.dampening_frac, + preserve_sparsity_mask=self.preserve_sparsity_mask, + ) + comp_logger.set_loss(loss) + + update_offload_parameter(module, "weight", sparsified_weight) + + # self._hessians[module] already deleted by sparsify_weight + del self._num_samples[module] + + @contextlib.contextmanager + def _maybe_onload_hessian(self, module: torch.nn.Module): + if self.offload_hessians: + device = get_execution_device(module) + self._hessians[module] = self._hessians[module].to(device=device) + + yield + + if self.offload_hessians: + if module in self._hessians: # may have been deleted in context + self._hessians[module] = self._hessians[module].to(device="cpu") + + def on_finalize(self, state: State, **kwargs) -> bool: + self.remove_hooks() + self._hessians = dict() + self._num_samples = dict() + self._module_names = dict() + self._module_sparsities = dict() + + return True diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_mixin.py similarity index 100% rename from src/llmcompressor/modifiers/obcq/sgpt_mixin.py rename to src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_mixin.py diff --git a/src/llmcompressor/modifiers/obcq/sgpt_sparsify.py b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py similarity index 100% rename from src/llmcompressor/modifiers/obcq/sgpt_sparsify.py rename to 
src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index 3b0eb9f58..ff785f82c 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -11,7 +11,7 @@ from llmcompressor.core import State from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin +from llmcompressor.modifiers.pruning.sparsegpt.sgpt_mixin import SparsityModifierMixin from llmcompressor.modifiers.pruning.wanda.wanda_sparsify import ( accumulate_row_scalars, make_empty_row_scalars, diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 65e1c90e0..d32219e45 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -42,7 +42,7 @@ class GPTQModifier(Modifier, HooksMixin): | Sample yaml: | test_stage: - | obcq_modifiers: + | modifiers: | GPTQModifier: | block_size: 128 | dampening_frac: 0.001 diff --git a/src/llmcompressor/pipelines/sequential/README.md b/src/llmcompressor/pipelines/sequential/README.md index 41209f34b..5fc3949fe 100644 --- a/src/llmcompressor/pipelines/sequential/README.md +++ b/src/llmcompressor/pipelines/sequential/README.md @@ -1,7 +1,7 @@ # Sequential Pipeline # The sequential pipeline is a data pipeline, primarily used for compressing models with the [GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py) or the -[SparseGPTModifier](/src/llmcompressor/modifiers/obcq/base.py). +[SparseGPTModifier](/src/llmcompressor/modifiers/pruning/sparsegpt/base.py). If, when using this pipeline, you encounter a `torch.fx.proxy.TraceError`, see the [Model Tracing Guide](/src/llmcompressor/transformers/tracing/GUIDE.md). 
\ No newline at end of file diff --git a/src/llmcompressor/transformers/finetune/README.md b/src/llmcompressor/transformers/finetune/README.md index 453fb91cb..387da51f1 100644 --- a/src/llmcompressor/transformers/finetune/README.md +++ b/src/llmcompressor/transformers/finetune/README.md @@ -45,7 +45,7 @@ See [configure_fsdp.md](../../../../examples/finetuning/configure_fsdp.md) for a ```python from llmcompressor.transformers import train -model = "./obcq_deployment" +model = "./model_path" teacher_model = "Xenova/llama2.c-stories15M" dataset_name = "open_platypus" concatenate_data = False @@ -74,10 +74,9 @@ train( Finetuning arguments are split up into 3 groups: -* ModelArguments: `src/llmcompressor/transformers/utils/arg_parser/model_arguments.py` -* TrainingArguments: `src/llmcompressor/transformers/utils/arg_parser/training_arguments.py` -* DatasetArguments: `src/llmcompressor/transformers/utils/arg_parser/dataset_arguments.py` -* RecipeArguments: `src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py` +* ModelArguments: `src/llmcompressor/transformers/finetune/model_args.py` +* TrainingArguments: `src/llmcompressor/transformers/finetune/training_args.py` +* DataTrainingArguments: `src/llmcompressor/transformers/finetune/data/data_training_args.py` ## Running One-Shot with FSDP diff --git a/src/llmcompressor/transformers/tracing/GUIDE.md b/src/llmcompressor/transformers/tracing/GUIDE.md index 2657037eb..27ba9b899 100644 --- a/src/llmcompressor/transformers/tracing/GUIDE.md +++ b/src/llmcompressor/transformers/tracing/GUIDE.md @@ -12,7 +12,7 @@ such as [GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py) ## 1. Why is Tracing Required? ## Due to the memory-intensive nature of some modifiers such as [GPTQModifier](/src/llmcompressor/modifiers/quantization/gptq/base.py) -and [SparseGPTModifier](/src/llmcompressor/modifiers/obcq/base.py), a [Sequential Pipeline](/src/llmcompressor/pipelines/sequential/pipeline.py) +and [SparseGPTModifier](/src/llmcompressor/modifiers/pruning/sparsegpt/base.py), a [Sequential Pipeline](/src/llmcompressor/pipelines/sequential/pipeline.py) is required in order to offload activations and reduce memory usage as well as propagate the activation error induced by compression. 
diff --git a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml index 92cc85ae7..7523b09a7 100644 --- a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml +++ b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml @@ -4,8 +4,7 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - targets: ["Linear"] - ignore: ["re:.*lm_head"] + sequential_update: false quantization_stage: run_type: oneshot quantization_modifiers: diff --git a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml index dc7e18b6e..b8a4402d8 100644 --- a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml +++ b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml @@ -4,8 +4,7 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - targets: ["Linear"] - ignore: ["re:.*lm_head"] + sequential_update: false quantization_stage: run_type: oneshot quantization_modifiers: diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py index 6c42f82df..b31bfb007 100644 --- a/tests/e2e/vLLM/test_vllm.py +++ b/tests/e2e/vLLM/test_vllm.py @@ -1,13 +1,12 @@ import os -import re import shutil from pathlib import Path +from typing import Callable import pytest import yaml from huggingface_hub import HfApi from loguru import logger -from parameterized import parameterized_class from llmcompressor.core import active_session from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing @@ -21,24 +20,19 @@ vllm_installed = False logger.warning("vllm is not installed. This test will be skipped") - HF_MODEL_HUB_NAME = "nm-testing" - TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", "") -SKIP_HF_UPLOAD = os.environ.get("SKIP_HF_UPLOAD", "") -EXPECTED_SAVED_FILES = [ - "config.json", - r"^model(?:-\d{5}-of-\d{5})?\.safetensors$", - "recipe.yaml", - "tokenizer.json", -] + +@pytest.fixture +def record_config_file(record_testsuite_property: Callable[[str, object], None]): + test_data_file_name = TEST_DATA_FILE.split("configs/")[-1] + record_testsuite_property("TEST_DATA_FILE_NAME", test_data_file_name) # Will run each test case in its own process through run_tests.sh # emulating vLLM CI testing @requires_gpu_count(1) -@parameterized_class("test_data_file", [(TEST_DATA_FILE,)]) @pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test") class TestvLLM: """ @@ -58,9 +52,7 @@ class TestvLLM: """ # noqa: E501 def set_up(self): - eval_config = yaml.safe_load( - Path(self.test_data_file).read_text(encoding="utf-8") - ) + eval_config = yaml.safe_load(Path(TEST_DATA_FILE).read_text(encoding="utf-8")) if os.environ.get("CADENCE", "commit") != eval_config.get("cadence"): pytest.skip("Skipping test; cadence mismatch") @@ -73,7 +65,6 @@ def set_up(self): self.recipe = eval_config.get("recipe") self.quant_type = eval_config.get("quant_type") self.save_dir = eval_config.get("save_dir") - self.save_compressed = eval_config.get("save_compressed", True) logger.info("========== RUNNING ==============") logger.info(self.scheme) @@ -88,6 +79,7 @@ def set_up(self): ] self.api = HfApi() + @pytest.mark.usefixtures("record_config_file") def test_vllm(self): # Run vLLM with saved model import torch @@ -108,19 +100,11 @@ def test_vllm(self): quant_type=self.quant_type, ) - # check that session contains recipe - self._check_session_contains_recipe() - logger.info("================= SAVING TO DISK ======================") - 
oneshot_model.save_pretrained( - self.save_dir, save_compressed=self.save_compressed - ) + oneshot_model.save_pretrained(self.save_dir) tokenizer.save_pretrained(self.save_dir) recipe_path = os.path.join(self.save_dir, "recipe.yaml") - # check that expected files exist - self._check_save_dir_has_expected_files() - # Use the session to fetch the recipe; # Reset session for next test case session = active_session() @@ -129,22 +113,12 @@ def test_vllm(self): fp.write(recipe_yaml_str) session.reset() - if SKIP_HF_UPLOAD.lower() != "yes": - logger.info("================= UPLOADING TO HUB ======================") + logger.info("================= UPLOADING TO HUB ======================") - stub = f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e" - - self.api.create_repo( - repo_id=stub, - exist_ok=True, - repo_type="model", - private=False, - ) - - self.api.upload_folder( - repo_id=stub, - folder_path=self.save_dir, - ) + self.api.upload_folder( + repo_id=f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e", + folder_path=self.save_dir, + ) logger.info("================= RUNNING vLLM =========================") @@ -172,35 +146,3 @@ def test_vllm(self): def tear_down(self): if self.save_dir is not None: shutil.rmtree(self.save_dir) - - def _check_session_contains_recipe(self) -> None: - session = active_session() - recipe_yaml_str = session.get_serialized_recipe() - assert recipe_yaml_str is not None - - def _check_save_dir_has_expected_files(self): - files = os.listdir(self.save_dir) - logger.debug("Saved files: ", files) - - matched_patterns = set() - - for expected in EXPECTED_SAVED_FILES: - # Find all files matching the expected pattern - matches = [ - file - for file in files - if ( - re.fullmatch(expected, file) - if expected.startswith("^") - else file == expected - ) - ] - if len(matches) > 0: - matched_patterns.add(expected) - - assert len(matched_patterns) == len(EXPECTED_SAVED_FILES), ( - "expected: ", - EXPECTED_SAVED_FILES, - "\n saved: ", - list(matched_patterns), - ) diff --git a/tests/examples/utils.py b/tests/examples/utils.py index 29eba8dd4..38ff98d64 100644 --- a/tests/examples/utils.py +++ b/tests/examples/utils.py @@ -68,10 +68,7 @@ def copy_and_run_command( def copy_and_run_script( - tmp_path: Path, - example_dir: str, - script_filename: str, - flags: Optional[list[str]] = None, + tmp_path: Path, example_dir: str, script_filename: str ) -> Tuple[List[str], CompletedProcess[str]]: """ Copies the contents of example_dir (relative to the current working directory) to @@ -84,8 +81,6 @@ def copy_and_run_script( :return: subprocess.CompletedProcess object """ command = [sys.executable, script_filename] - if flags: - command.extend(flags) return command, copy_and_run_command(tmp_path, example_dir, command) diff --git a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py index 898c342f5..6ea024037 100644 --- a/tests/llmcompressor/modifiers/calibration/test_cache.py +++ b/tests/llmcompressor/modifiers/calibration/test_cache.py @@ -28,7 +28,7 @@ def test_is_quantized_cache_singleton(): args = QuantizationArgs() cache = QuantizedKVParameterCache(args) - observer = args.observer + observer = args.get_observer() observer = Observer.load_from_registry(observer, quantization_args=args) tensor = torch.tensor([1, 2, 3]) diff --git a/tests/llmcompressor/modifiers/conf.py b/tests/llmcompressor/modifiers/conf.py index 0a910788c..3eab9b85c 100644 --- a/tests/llmcompressor/modifiers/conf.py +++ b/tests/llmcompressor/modifiers/conf.py @@ -1,7 +1,3 @@ 
-from unittest.mock import MagicMock - -from torch.utils.data import DataLoader - from llmcompressor.core import State from llmcompressor.core.events import EventType from llmcompressor.core.lifecycle import CallbacksEventLifecycle @@ -28,7 +24,7 @@ def __init__( optimizer=optimizer, start=start, steps_per_epoch=1, - calib_data=DataLoader(MagicMock(__len__=lambda _: 0, column_names=[])), + calib_data=[], ) self.event_lifecycle = CallbacksEventLifecycle( diff --git a/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py b/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py index 2126baa99..e4baccc13 100644 --- a/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py +++ b/tests/llmcompressor/modifiers/pruning/sparsegpt/test_base.py @@ -3,7 +3,7 @@ import pytest from llmcompressor.modifiers.factory import ModifierFactory -from llmcompressor.modifiers.obcq.base import SparseGPTModifier +from llmcompressor.modifiers.pruning.sparsegpt.base import SparseGPTModifier from tests.llmcompressor.modifiers.conf import setup_modifier_factory diff --git a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py b/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py new file mode 100644 index 000000000..203d1fe03 --- /dev/null +++ b/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py @@ -0,0 +1,41 @@ +from collections import OrderedDict + +import torch +from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config +from compressed_tensors.quantization.quant_config import QuantizationConfig +from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme +from loguru import logger + +from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper + + +def test_ignore(): + model = torch.nn.Sequential( + OrderedDict( + [ + ("first_layer", torch.nn.Linear(2, 3)), + ("second_layer", torch.nn.Linear(3, 5)), + ] + ) + ) + + config = QuantizationConfig( + config_groups={"group_0": preset_name_to_scheme("W8A8", targets=["Linear"])}, + ignore=["first_layer"], + ) + apply_quantization_config(model, config) + + messages = [] + logger.add(lambda m: messages.append(m)) + + with torch.no_grad(): + first_compressor = GPTQWrapper("first_layer", model.first_layer) + first_compressor.add_batch(torch.ones(2), None) + first_compressor.compress() + + second_compressor = GPTQWrapper("second_layer", model.second_layer) + second_compressor.add_batch(torch.ones(3), None) + second_compressor.compress() + + assert sum("Skipping unquantized layer first_layer" in m for m in messages) == 1 + assert sum("Skipping unquantized layer second_layer" in m for m in messages) == 0 diff --git a/tests/llmcompressor/modifiers/smoothquant/test_utils.py b/tests/llmcompressor/modifiers/smoothquant/test_utils.py index 457b64cdb..95be6bd30 100644 --- a/tests/llmcompressor/modifiers/smoothquant/test_utils.py +++ b/tests/llmcompressor/modifiers/smoothquant/test_utils.py @@ -12,10 +12,7 @@ @pytest.mark.unit def test_handle_mapping_resolution_errors(): - README_LOCATION = ( - "https://github.com/vllm-project/llm-compressor/tree/main/" - "src/llmcompressor/modifiers/smoothquant" - ) + README_LOCATION = "llmcompressor/modifiers/smoothquant/README.md" @handle_mapping_resolution_errors def func_that_raises_exception(): diff --git a/tests/llmcompressor/modifiers/utils/test_hooks.py b/tests/llmcompressor/modifiers/utils/test_hooks.py index 2a402e980..5c4fc5891 100644 --- 
a/tests/llmcompressor/modifiers/utils/test_hooks.py +++ b/tests/llmcompressor/modifiers/utils/test_hooks.py @@ -64,27 +64,6 @@ def test_remove_hooks(): assert mod_a.hook_called and not mod_b.hook_called -def test_remove_hooks_parameterized(): - model = DummyModel() - - mod_a = ModA() - mod_a_pre_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward_pre") - mod_a_post_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward") - - mod_b = ModB() - mod_b_pre_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") - mod_b_post_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward") - - mod_a.remove_hooks(set([mod_a_post_hook])) - mod_b.remove_hooks(set([mod_b_pre_hook])) - - assert len(mod_a._hooks) == 1 and next(iter(mod_a._hooks)) == mod_a_pre_hook - assert len(mod_b._hooks) == 1 and next(iter(mod_b._hooks)) == mod_b_post_hook - - model(model.dummy_inputs) - assert mod_a.hook_called and mod_b.hook_called - - def test_disable_hooks(): model = DummyModel() @@ -102,75 +81,3 @@ def test_disable_hooks(): mod_b.hook_called = False model(model.dummy_inputs) assert mod_a.hook_called and mod_b.hook_called - - -def test_disable_hooks_keep(): - model = DummyModel() - - mod_a = ModA() - handle_a = mod_a.register_hook(model.linear1, mod_a.hook, "forward") - - mod_b = ModB() - handle_b = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") - - with HooksMixin.disable_hooks(keep=set([handle_b])): - model(model.dummy_inputs) - assert not mod_a.hook_called and mod_b.hook_called - - mod_a.hook_called = False - mod_b.hook_called = False - with HooksMixin.disable_hooks(keep=set([handle_a])): - model(model.dummy_inputs) - assert mod_a.hook_called and not mod_b.hook_called - - mod_a.hook_called = False - mod_b.hook_called = False - model(model.dummy_inputs) - assert mod_a.hook_called and mod_b.hook_called - - -def test_disable_hooks_composable(): - model = DummyModel() - - mod_a = ModA() - handle_a = mod_a.register_hook(model.linear1, mod_a.hook, "forward") - - mod_b = ModB() - handle_b = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") - - # composing two keeps - with ( - HooksMixin.disable_hooks(keep=set([handle_b])), - HooksMixin.disable_hooks(keep=set([handle_a])), - ): - model(model.dummy_inputs) - assert mod_a.hook_called and mod_b.hook_called - - mod_a.hook_called = False - mod_b.hook_called = False - model(model.dummy_inputs) - assert mod_a.hook_called and mod_b.hook_called - - mod_a.hook_called = False - mod_b.hook_called = False - with HooksMixin.disable_hooks(): - model(model.dummy_inputs) - assert not mod_a.hook_called and not mod_b.hook_called - - # composing a keep and an empty keep - mod_a.hook_called = False - mod_b.hook_called = False - with HooksMixin.disable_hooks(keep=set([handle_a])), HooksMixin.disable_hooks(): - model(model.dummy_inputs) - assert mod_a.hook_called and not mod_b.hook_called - - mod_a.hook_called = False - mod_b.hook_called = False - model(model.dummy_inputs) - assert mod_a.hook_called and mod_b.hook_called - - mod_a.hook_called = False - mod_b.hook_called = False - with HooksMixin.disable_hooks(): - model(model.dummy_inputs) - assert not mod_a.hook_called and not mod_b.hook_called diff --git a/tests/llmcompressor/observers/test_min_max.py b/tests/llmcompressor/observers/test_min_max.py index b592579f6..f23a06dba 100644 --- a/tests/llmcompressor/observers/test_min_max.py +++ b/tests/llmcompressor/observers/test_min_max.py @@ -37,7 +37,7 @@ def test_min_max_observer(symmetric, expected_scale, expected_zero_point): 
num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric) - observer = weights.observer + observer = weights.get_observer() observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) @@ -52,7 +52,7 @@ def test_min_max_observer_symmetric_scale_range(): num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - observer = weights.observer + observer = weights.get_observer() observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) @@ -80,7 +80,7 @@ def test_min_max_observer_value_update(): tensor = inp num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - observer = weights.observer + observer = weights.get_observer() observer = Observer.load_from_registry(observer, quantization_args=weights) curr_max = 1 curr_min = 1 @@ -107,7 +107,7 @@ def test_g_idx(): weights = QuantizationArgs(num_bits=8, group_size=group_size) g_idx = make_dummy_g_idx(tensor.shape[1], group_size) - observer = weights.observer + observer = weights.get_observer() observer = Observer.load_from_registry(observer, quantization_args=weights) scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx) diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index 4447813b3..ec2ecf1b5 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -32,7 +32,7 @@ def test_mse_observer(symmetric, expected_scale, expected_zero_point): num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") - observer = weights.observer + observer = weights.get_observer() observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) @@ -48,7 +48,7 @@ def test_mse_observer_symmetric_scale_range(): num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - observer = weights.observer + observer = weights.get_observer() observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) diff --git a/tests/llmcompressor/pytorch/helpers.py b/tests/llmcompressor/pytorch/helpers.py index 341c18f11..d7b52a836 100644 --- a/tests/llmcompressor/pytorch/helpers.py +++ b/tests/llmcompressor/pytorch/helpers.py @@ -1,6 +1,5 @@ from collections import OrderedDict, namedtuple from typing import List -from unittest.mock import Mock import pytest import torch @@ -97,7 +96,6 @@ def __init__(self): ] ) ) - self.config = Mock(use_cache=False) def forward(self, inp: Tensor): return self.seq(inp) diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index c1f0cb425..0752f2a30 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -4,7 +4,7 @@ from compressed_tensors.quantization import QuantizationScheme from parameterized import parameterized -from llmcompressor.modifiers.obcq import SparseGPTModifier +from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier from llmcompressor.modifiers.quantization.gptq import GPTQModifier from llmcompressor.modifiers.quantization.quantization import QuantizationModifier from llmcompressor.utils.pytorch.module import qat_active @@ -29,11 +29,12 @@ def setUp(self): ) def 
test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets): setup_modifier_factory() - modifier = SparseGPTModifier( + kwargs = dict( sparsity=sparsity, block_size=128, targets=targets, ) + modifier = SparseGPTModifier(**kwargs) testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) # confirm invalid layerwise recipes fail at initialization @@ -49,16 +50,16 @@ def setUp(self): def test_successful_layerwise_recipe(self): sparsities = [0.5, 0.2] targets = ["seq.fc1", "seq.fc2"] - modifier = SparseGPTModifier( - sparsity=sparsities, block_size=128, targets=targets - ) - testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) - modifier.initialize(testing_harness.get_state()) + kwargs = dict(sparsity=sparsities, block_size=128, targets=targets) + modifier = SparseGPTModifier(**kwargs) + modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None} + modifier.model = LinearNet() + found_compressible_layers = modifier.compressible_layers() + modifier.compressible_layers_ = found_compressible_layers + modifier._validate_layerwise_sparsity() - model = testing_harness.state.model - num_hooks = len(modifier._hooks) - num_found = sum(len(module._forward_hooks) > 0 for module in model.modules()) - self.assertEqual(num_hooks, num_found) + # ensure layers names successfully match up with model + self.assertEqual(len(found_compressible_layers), len(targets)) @pytest.mark.unit @@ -67,16 +68,18 @@ def setUp(self): setup_modifier_factory() def test_create_default_quant_modifier(self): - modifier = GPTQModifier(block_size=128) - assert modifier._quantization_modifier is None + kwargs = dict(block_size=128) + + modifier = GPTQModifier(**kwargs) + assert modifier.quantization_modifier_ is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - assert isinstance(modifier._quantization_modifier, QuantizationModifier) - modifier._quantization_modifier.create_init_config() + assert isinstance(modifier.quantization_modifier_, QuantizationModifier) + modifier.quantization_modifier_.create_init_config() default_config_group_name = "group_0" - should_be_default_quant_scheme = modifier._quantization_modifier.config_groups[ + should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[ default_config_group_name ] assert should_be_default_quant_scheme.input_activations is None @@ -103,8 +106,9 @@ def test_set_quant_if_modifer_already_exists(self): modifier.initialize(testing_harness.get_state()) assert qat_active(testing_harness.get_state().model) - modifier = GPTQModifier(block_size=128) - assert not modifier._quantization_modifier + kwargs = dict(block_size=128) + modifier = GPTQModifier(**kwargs) + assert not modifier.quantization_modifier_ modifier.on_initialize_structure(testing_harness.get_state()) # since quantization modifier is already applied, quantization must be set in @@ -138,15 +142,17 @@ def setUp(self): self.quant_config = {"QuantizationModifier": self.quant_kwargs} def test_set_quant_in_gptq(self): - modifier = GPTQModifier(block_size=128, quantize=self.quant_config) - assert modifier._quantization_modifier is None + kwargs = dict(block_size=128, quantize=self.quant_config) + + modifier = GPTQModifier(**kwargs) + assert modifier.quantization_modifier_ is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - 
self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) + self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) - dict_scheme = dict(modifier._quantization_modifier.config_groups) + dict_scheme = dict(modifier.quantization_modifier_.config_groups) self._check_config( dict(dict_scheme["config_group_0"].weights), self.quant_kwargs["config_groups"]["config_group_0"]["weights"], diff --git a/tests/llmcompressor/recipe/test_recipe.py b/tests/llmcompressor/recipe/test_recipe.py index 7a3674052..729543eee 100644 --- a/tests/llmcompressor/recipe/test_recipe.py +++ b/tests/llmcompressor/recipe/test_recipe.py @@ -4,7 +4,7 @@ import yaml from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.obcq.base import SparseGPTModifier +from llmcompressor.modifiers.pruning.sparsegpt.base import SparseGPTModifier from llmcompressor.recipe import Recipe from llmcompressor.recipe.recipe import create_recipe_string_from_modifiers from tests.llmcompressor.helpers import valid_recipe_strings diff --git a/tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml b/tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml index cd1280eb1..060d0d06c 100644 --- a/tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml +++ b/tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml @@ -1,5 +1,5 @@ pruning_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" diff --git a/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml index 0d3c8bad5..75276f842 100644 --- a/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml +++ b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml @@ -1,5 +1,5 @@ pruning_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml index 926c31ec3..d516616bf 100644 --- a/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml @@ -1,4 +1,4 @@ cadence: "commit" test_type: "regression" -compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-Dynamic-compressed -uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-Dynamic-uncompressed \ No newline at end of file +model_stub: "nm-testing/tinyllama-fp8-dynamic-compressed" +empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml index 51d9ec25b..7e9bc3f2f 100644 --- a/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml @@ -1,4 +1,4 @@ cadence: "commit" test_type: "regression" -compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-compressed -uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-uncompressed \ No newline at end of file +model_stub: "nm-testing/tinyllama-w4a16-compressed" +empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file 
diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml new file mode 100644 index 000000000..af1e5df8b --- /dev/null +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a16_dense.yaml @@ -0,0 +1,4 @@ +cadence: "commit" +test_type: "regression" +model_stub: "nm-testing/tinyllama-w8a16-dense" +empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml index 3c1646b16..086a67ed6 100644 --- a/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml @@ -1,4 +1,4 @@ cadence: "commit" test_type: "regression" -compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-compressed -uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-uncompressed \ No newline at end of file +model_stub: "nm-testing/tinyllama-w8a8-compressed" +empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py index 1eb3bf202..7db2f0687 100644 --- a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py +++ b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py @@ -1,4 +1,5 @@ import pytest +from compressed_tensors.config import SparsityCompressionConfig from compressed_tensors.quantization import preset_name_to_scheme from llmcompressor.transformers.compression.quantization_format import ( @@ -19,6 +20,9 @@ ], ) def test_infer_quant_format(preset, sparsity_structure, expected_format): + sparsity_config = SparsityCompressionConfig( + format="dense", sparsity_structure=sparsity_structure + ) quant_scheme = preset_name_to_scheme(preset, targets=["Linear"]) dummy_model = LinearNet() @@ -26,6 +30,6 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format): module.quantization_scheme = quant_scheme inferred_format = infer_quantization_format( - dummy_model, save_compressed=True, sparsity_structure=sparsity_structure + dummy_model, save_compressed=True, sparsity_config=sparsity_config ) assert inferred_format.value == expected_format diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 0d34d1ca0..13eab66c9 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -10,10 +10,10 @@ from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator -from llmcompressor.args import DatasetArguments from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.transformers import oneshot from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/compression/configs" @@ -74,6 +74,7 @@ def _run_oneshot(model, 
recipe, dataset, output_dir): ) from llmcompressor.pytorch.model_load.helpers import get_session_model + # note: get_session_model() is None outside of function scope return get_session_model() def _get_quant_info(self, model): @@ -146,7 +147,7 @@ def _get_dataloader(self, data_args, tokenizer): @torch.no_grad() def test_perplexity(self): tokenizer = AutoTokenizer.from_pretrained(self.model_stub) - data_args = DatasetArguments( + data_args = DataTrainingArguments( dataset="ultrachat-200k", max_seq_length=self.max_seq_length, ) diff --git a/tests/llmcompressor/transformers/compression/test_run_compressed.py b/tests/llmcompressor/transformers/compression/test_run_compressed.py index 616dd0dfe..0c2a0ab0e 100644 --- a/tests/llmcompressor/transformers/compression/test_run_compressed.py +++ b/tests/llmcompressor/transformers/compression/test_run_compressed.py @@ -1,133 +1,79 @@ -import copy import shutil import tempfile import unittest +import torch from compressed_tensors import QUANTIZATION_CONFIG_NAME from compressed_tensors.compressors import ModelCompressor from compressed_tensors.quantization import QuantizationStatus from parameterized import parameterized_class from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -from transformers.utils.quantization_config import CompressedTensorsConfig from tests.testing_utils import parse_params, requires_gpu -CONFIG_DIR = "tests/llmcompressor/transformers/compression/decompression_configs" +CONFIG_DIR = "tests/llmcompressor/transformers/compression/run_compressed_configs" @requires_gpu @parameterized_class(parse_params(CONFIG_DIR)) -class TestDecompression(unittest.TestCase): - """ - Check that HFQuantizer decompression is working as expected. - Manually decompress a compressed model and compare the generations - - Decompression: - Given a skeleton model and path to the optimized model, - write the optimized model's safetensors to the skeleton model and decompress - Ex. 
write weight_scale to the skeleton model and then convert from fp4 to fp16 - - """ - - compressed_model_stub = None - skeleton_model_stub = None - - SAMPLE_INPUTS = [ - "I love 4-bit quantization because", - "What is the capital of France?", - "def fibonacci(n):", - ] +class TestQuantizationMatches(unittest.TestCase): + model_stub = None + empty_model = None @classmethod - def setUpClass(self): - self.test_dir = tempfile.mkdtemp() - self.tokenizer = AutoTokenizer.from_pretrained(self.compressed_model_stub) + def setUpClass(cls): + cls.test_dir = tempfile.mkdtemp() - # Decompress using HFQuantizer from AutoModelForCausalLM - self.decompressed_model_hf_quantizer = AutoModelForCausalLM.from_pretrained( - self.compressed_model_stub, + # TODO: Give option on HFQuantizer to run run_compressed True/False + # currently hardcoded to True + cls.compressed_model = AutoModelForCausalLM.from_pretrained( + cls.model_stub, torch_dtype="auto", device_map="auto", - quantization_config=CompressedTensorsConfig(run_compressed=False), + # run_compressed=True, # TODO: Give option on HFQuantizer ) - - # Manually decompress this model - self.dense_model = AutoModelForCausalLM.from_pretrained( - self.skeleton_model_stub, - torch_dtype=self.decompressed_model_hf_quantizer.dtype, - device_map=self.decompressed_model_hf_quantizer.device, - ) - - # decompression from HFQuantizer should populate weight_scale - assert hasattr( - self.decompressed_model_hf_quantizer.model.layers[0].self_attn.q_proj, - "weight_scale", - ) - - # dense model should not have weight_scale populated - assert not hasattr( - self.dense_model.model.layers[0].self_attn.q_proj, "weight_scale" + # TODO: Use ModelCompressor until decompression is supported through + # HFQuant/run_compressed can be turned off. 
+ cls.uncompressed_model = AutoModelForCausalLM.from_pretrained( + cls.empty_model, + torch_dtype=cls.compressed_model.dtype, + device_map=cls.compressed_model.device, ) - - config = AutoConfig.from_pretrained(self.compressed_model_stub) - + config = AutoConfig.from_pretrained(cls.model_stub) compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) - self.compressor = ModelCompressor.from_compression_config(compression_config) - self.compressor.quantization_config.quantization_status = ( + cls.compressor = ModelCompressor.from_compression_config(compression_config) + cls.compressor.quantization_config.quantization_status = ( QuantizationStatus.FROZEN ) - - # use the model_path to load the decompressed weights into dense_model - dense_model = copy.deepcopy(self.dense_model) - - # overwrite the weights of the dense model - self.compressor.decompress( - model_path=self.compressed_model_stub, - model=self.dense_model, + cls.compressor.decompress( + model_path=cls.model_stub, model=cls.uncompressed_model ) - # self.dense_model should be decompressed - assert dense_model is not self.dense_model + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_stub) - self.decompressed_model_manual = self.dense_model + def test_compressed_matches_uncompressed(self): + SAMPLE_INPUT = [ + "I love 4-bit quantization because", + "What is the capital of France?", + "def fibonacci(n):", + ] - assert hasattr( - self.decompressed_model_manual.model.layers[0].self_attn.q_proj, - "weight_scale", + inputs = self.tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to( + self.compressed_model.device ) - - def test_hf_quantizer_decompress_match_manual_decompress(self): - manual_device = self.decompressed_model_manual.device - decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.device - - self.decompressed_model_manual = self.decompressed_model_manual.to( - manual_device + compressed_output = self.tokenizer.batch_decode( + self.compressed_model.generate(**inputs, max_length=50) ) - self.decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.to( - decompressed_model_hf_quantizer + uncompressed_output = self.tokenizer.batch_decode( + self.uncompressed_model.generate(**inputs, max_length=50) ) - for input in self.SAMPLE_INPUTS: - inputs = self.tokenizer(input, return_tensors="pt", padding=True).to( - self.decompressed_model_manual.device - ) - inputs = inputs.to(self.decompressed_model_manual.device) - - decompressed_model_manual_output = self.tokenizer.batch_decode( - self.decompressed_model_manual.generate(**inputs, max_length=50) - ) - - decompressed_model_hf_quantizer_out = self.tokenizer.batch_decode( - self.decompressed_model_hf_quantizer.generate(**inputs, max_length=50) - ) - - assert ( - decompressed_model_hf_quantizer_out == decompressed_model_manual_output - ) + for idx in range(len(SAMPLE_INPUT)): + assert compressed_output[idx] == uncompressed_output[idx] @classmethod - def tearDownClass(self): - shutil.rmtree(self.test_dir) - del self.dense_model - del self.decompressed_model_hf_quantizer - del self.decompressed_model_manual + def tearDownClass(cls): + shutil.rmtree(cls.test_dir) + del cls.compressed_model + del cls.uncompressed_model + torch.cuda.empty_cache() diff --git a/tests/llmcompressor/transformers/finetune/data/conftest.py b/tests/llmcompressor/transformers/finetune/data/conftest.py index aa2f056bc..a7a347d99 100644 --- a/tests/llmcompressor/transformers/finetune/data/conftest.py +++ b/tests/llmcompressor/transformers/finetune/data/conftest.py @@ 
-1,7 +1,7 @@ import pytest from transformers import AutoTokenizer -from llmcompressor.args import ModelArguments +from llmcompressor.transformers.finetune.model_args import ModelArguments @pytest.fixture diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 7eb74f9f9..812b26a56 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,6 +1,6 @@ import pytest -from llmcompressor.args import DatasetArguments +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( get_raw_dataset, make_dataset_splits, @@ -9,7 +9,7 @@ @pytest.mark.unit def test_combined_datasets(): - data_args = DatasetArguments( + data_args = DataTrainingArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) raw_wikitext2 = get_raw_dataset(data_args) @@ -33,7 +33,7 @@ def test_combined_datasets(): @pytest.mark.unit def test_separate_datasets(): splits = {"train": "train[:10%]", "validation": "train[10%:20%]"} - data_args = DatasetArguments( + data_args = DataTrainingArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) datasets = {} diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index dcc602877..64514b252 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -5,13 +5,12 @@ from datasets import IterableDataset, load_dataset from parameterized import parameterized -from llmcompressor.args import ( - DatasetArguments, +from llmcompressor.transformers import ( + DataTrainingArguments, ModelArguments, - RecipeArguments, + TextGenerationDataset, TrainingArguments, ) -from llmcompressor.transformers import TextGenerationDataset from llmcompressor.transformers.finetune.data.data_helpers import ( format_calibration_data, ) @@ -21,7 +20,7 @@ @pytest.mark.unit class TestConcentrationTokenization(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.data_args = DataTrainingArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -54,7 +53,7 @@ def test_concatenation_tokenization(self): @pytest.mark.unit class TestNoPaddingTokenization(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.data_args = DataTrainingArguments( dataset="open_platypus", pad_to_max_length=False ) @@ -97,7 +96,9 @@ def test_no_padding_tokenization(self): @pytest.mark.unit class TestMaxSeqLenClipped(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments(dataset="open_platypus", max_seq_length=4096) + self.data_args = DataTrainingArguments( + dataset="open_platypus", max_seq_length=4096 + ) @pytest.fixture(autouse=True) def prepare_fixture(self, tiny_llama_tokenizer): @@ -119,7 +120,7 @@ def test_max_seq_len_clipped(self): @pytest.mark.unit class TestDatasetKwargsAndPercent(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.data_args = DataTrainingArguments( dataset="wikitext", raw_kwargs={ "data_files": { @@ -166,7 +167,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): ] ) def test_datasets(self, dataset_key, dataset_config, split, do_concat): - data_args = 
DatasetArguments( + data_args = DataTrainingArguments( dataset=dataset_key, dataset_config_name=dataset_config, concatenate_data=do_concat, @@ -205,7 +206,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer def setUp(self): - self.data_args = DatasetArguments( + self.data_args = DataTrainingArguments( dataset="evolcodealpaca", dataset_config_name=None, concatenate_data=False, @@ -234,7 +235,7 @@ def test_evol(self): @pytest.mark.unit class TestStreamLoading(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.data_args = DataTrainingArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -275,19 +276,15 @@ def prepare_fixture(self, tiny_llama_tokenizer): [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] ) def test_split_loading(self, split_def): - data_args = DatasetArguments( + data_args = DataTrainingArguments( dataset="open_platypus", splits=split_def, trust_remote_code_data=True, ) training_args = TrainingArguments(do_train=True, output_dir="dummy") model_args = ModelArguments(model=None) - recipe_args = RecipeArguments() stage_runner = StageRunner( - model_args=model_args, - data_args=data_args, - training_args=training_args, - recipe_args=recipe_args, + model_args=model_args, data_args=data_args, training_args=training_args ) stage_runner.populate_datasets(processor=self.tiny_llama_tokenizer) @@ -321,11 +318,10 @@ def preprocess(sample): ) stage_runner = StageRunner( model_args=None, - data_args=DatasetArguments( + data_args=DataTrainingArguments( dataset=tokenized_dataset, shuffle_calibration_samples=False ), training_args=TrainingArguments(do_oneshot=True), - recipe_args=RecipeArguments(), ) stage_runner.populate_datasets(processor=None) calib_dataset = stage_runner.get_dataset_split("calibration") diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index 694a9b6d3..9aee4c20f 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -1,17 +1,17 @@ import pytest -from llmcompressor.args import DatasetArguments from llmcompressor.transformers.finetune.data import ( C4Dataset, OpenPlatypusDataset, TextGenerationDataset, WikiTextDataset, ) +from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_c4_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments(dataset="c4", concatenate_data=True) + data_args = DataTrainingArguments(dataset="c4", concatenate_data=True) c4_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, @@ -27,7 +27,7 @@ def test_c4_initializes(tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_wikitext_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments( + data_args = DataTrainingArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) wiki_manager = TextGenerationDataset.load_from_registry( @@ -45,7 +45,7 @@ def test_wikitext_initializes(tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_open_platypus_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False) + data_args = DataTrainingArguments(dataset="open_platypus", pad_to_max_length=False) op_manager = 
TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, diff --git a/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml b/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml index b5ef20aca..4f9d4293d 100644 --- a/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml +++ b/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml @@ -1,12 +1,12 @@ test_oneshot_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.7 block_size: 128 + sequential_update: False percdamp: 0.01 mask_structure: "0:0" - targets: ["Linear"] - ignore: ["re:.*lm_head"] + target_ids: ["attention_mask", "position_ids"] test_train_stage: pruning_modifiers: ConstantPruningModifier: diff --git a/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py b/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py index ec517e2d6..5eb49009b 100644 --- a/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py +++ b/tests/llmcompressor/transformers/finetune/test_finetune_oneshot_with_modifier.py @@ -21,7 +21,7 @@ def setUp(self): self.output = Path("./finetune_output") def test_oneshot_with_modifier_object(self): - from llmcompressor.modifiers.obcq.base import SparseGPTModifier + from llmcompressor.modifiers.pruning.sparsegpt.base import SparseGPTModifier from llmcompressor.transformers import oneshot recipe_str = [ diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py index 76ea21706..e9c3d7c5c 100644 --- a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py +++ b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py @@ -1,23 +1,28 @@ +import os import shutil import unittest from pathlib import Path import pytest -from transformers import AutoModelForCausalLM -from transformers.utils.quantization_config import CompressedTensorsConfig - -from llmcompressor.core import create_session -from llmcompressor.modifiers.quantization import QuantizationModifier -from llmcompressor.transformers import oneshot, train @pytest.mark.unit +@pytest.mark.skipif( + "CADENCE" in os.environ + and (os.environ["CADENCE"] == "weekly" or os.environ["CADENCE"] == "nightly"), + reason="Don't run for weekly and nightly tests as those use multi gpu " + "runners and this test fails when ngpu>1", +) class TestOneshotThenFinetune(unittest.TestCase): def setUp(self): self.output = Path("./finetune_output") - self.quantization_config = CompressedTensorsConfig(run_compressed=False) - def test_oneshot_sparsification_then_finetune(self): + def test_oneshot_then_finetune(self): + from transformers import AutoModelForCausalLM + + from llmcompressor.core import create_session + from llmcompressor.transformers import oneshot, train + recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml" model = AutoModelForCausalLM.from_pretrained( "Xenova/llama2.c-stories15M", device_map="auto" @@ -42,12 +47,8 @@ def test_oneshot_sparsification_then_finetune(self): recipe_str = ( "tests/llmcompressor/transformers/finetune/test_finetune_recipe.yaml" ) - - # Explictly decompress the model for training using quantization_config model = AutoModelForCausalLM.from_pretrained( - self.output / "oneshot_out", - device_map="auto", - quantization_config=self.quantization_config, + self.output / "oneshot_out", device_map="auto" ) distill_teacher = 
AutoModelForCausalLM.from_pretrained( "Xenova/llama2.c-stories15M", device_map="auto" @@ -72,12 +73,7 @@ def test_oneshot_sparsification_then_finetune(self): ) # test reloading checkpoint and final model - # verify checkpoint reloading and can carry out finetune - # with the saved model - # Explictly decompress the model for training using quantization_config - model = AutoModelForCausalLM.from_pretrained( - output_dir, device_map="auto", quantization_config=self.quantization_config - ) + model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto") with create_session(): train( model=model, @@ -92,71 +88,5 @@ def test_oneshot_sparsification_then_finetune(self): resume_from_checkpoint=True, # use last checkpoint ) - def test_oneshot_quantization_then_finetune(self): - recipe = QuantizationModifier( - targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"] - ) - - model = AutoModelForCausalLM.from_pretrained( - "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - device_map="auto", - ) - dataset = "open_platypus" - concatenate_data = False - num_calibration_samples = 64 - output_dir = self.output / "oneshot_out" - splits = {"calibration": "train[:10%]"} - - with create_session(): - oneshot( - model=model, - dataset=dataset, - output_dir=output_dir, - num_calibration_samples=num_calibration_samples, - recipe=recipe, - concatenate_data=concatenate_data, - splits=splits, - ) - - from transformers.utils.quantization_config import CompressedTensorsConfig - - quantization_config = CompressedTensorsConfig(run_compressed=False) - model = AutoModelForCausalLM.from_pretrained( - output_dir, - device_map="auto", - quantization_config=quantization_config, - ) - dataset = "open_platypus" - concatenate_data = False - output_dir = self.output / "finetune_out" - splits = {"calibration": "train[:10%]", "train": "train[:10%]"} - - with create_session(): - train( - model=model, - dataset=dataset, - output_dir=output_dir, - num_calibration_samples=num_calibration_samples, - recipe=recipe, - concatenate_data=concatenate_data, - splits=splits, - ) - - # test reloading checkpoint and final model - model = AutoModelForCausalLM.from_pretrained( - output_dir, device_map="auto", quantization_config=quantization_config - ) - with create_session(): - train( - model=model, - dataset=dataset, - output_dir=output_dir, - num_calibration_samples=num_calibration_samples, - recipe=recipe, - concatenate_data=concatenate_data, - splits=splits, - resume_from_checkpoint=True, # use last checkpoint - ) - def tearDown(self): shutil.rmtree(self.output) diff --git a/tests/llmcompressor/transformers/finetune/test_session_mixin.py b/tests/llmcompressor/transformers/finetune/test_session_mixin.py index 93bd74cd1..69a9acd44 100644 --- a/tests/llmcompressor/transformers/finetune/test_session_mixin.py +++ b/tests/llmcompressor/transformers/finetune/test_session_mixin.py @@ -14,8 +14,6 @@ def __init__( model: Module, recipe: Optional[str], recipe_args: Optional[Union[Dict[str, Any], str]] = None, - model_args: Optional[Union[Dict[str, Any], str]] = None, - data_args: Optional[Union[Dict[str, Any], str]] = None, teacher: Optional[Union[Module, str]] = None, **kwargs, ): @@ -23,8 +21,6 @@ def __init__( model=model, recipe=recipe, recipe_args=recipe_args, - model_args=model_args, - data_args=data_args, teacher=teacher, **kwargs, ) diff --git a/tests/llmcompressor/transformers/gptq/test_oneshot.py b/tests/llmcompressor/transformers/gptq/test_oneshot.py index c391890b2..7f1a1ec99 100644 --- 
a/tests/llmcompressor/transformers/gptq/test_oneshot.py +++ b/tests/llmcompressor/transformers/gptq/test_oneshot.py @@ -74,8 +74,8 @@ def test_oneshot_application(self): oneshot( model=self.model, dataset=self.dataset, - overwrite_output_dir=True, output_dir=self.output, + overwrite_output_dir=True, recipe=self.recipe, oneshot_device=self.device, num_calibration_samples=9, diff --git a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml index 474b021b3..64ce30250 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml @@ -1,8 +1,9 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.7 block_size: 128 + sequential_update: True percdamp: 0.01 mask_structure: "0:0" targets: ["model.layers.0"] diff --git a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml index afd2f045c..027c56363 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: QuantizationModifier: config_groups: group_0: @@ -11,6 +11,7 @@ test_stage: SparseGPTModifier: sparsity: 0.7 block_size: 128 + sequential_update: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/llmcompressor/transformers/obcq/recipes/quant.yaml b/tests/llmcompressor/transformers/obcq/recipes/quant.yaml index 435503e50..1df51c804 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/quant.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/quant.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SmoothQuantModifier: smoothing_strength: 0.6 GPTQModifier: diff --git a/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml b/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml index 0e738a943..eb02ea81d 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml @@ -1,5 +1,5 @@ test_stage: - obcq_modifiers: + modifiers: SmoothQuantModifier: smoothing_strength: 0.5 mappings: [ @@ -18,6 +18,7 @@ test_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False percdamp: 0.01 mask_structure: "0:0" targets: ["model.layers.0"] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml b/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml index 2a727caef..d485064fa 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml @@ -1,8 +1,9 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.3 block_size: 128 + sequential_update: False percdamp: 0.01 targets: ["model.layers.0", "model.layers.1"] mask_structure: "0:0" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml b/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml index 980fb4173..20c4c9397 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml @@ -1,8 +1,9 @@ 
test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False percdamp: 0.01 mask_structure: "2:4" targets: [ diff --git a/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml b/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml index 05487b104..8a97ff733 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml @@ -1,8 +1,9 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index 16c9003be..2f6c51ebb 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -5,11 +5,7 @@ import pytest import yaml from parameterized import parameterized_class -from transformers import AutoModelForCausalLM -from transformers.utils.quantization_config import CompressedTensorsConfig -from llmcompressor.transformers.utils import is_model_ct_quantized_from_path -from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs" @@ -19,15 +15,13 @@ class TestConsecutiveRuns(unittest.TestCase): - quantization_config = CompressedTensorsConfig(run_compressed=False) - def _test_consecutive_runs( self, tolerance: float, num_calibration_samples: int = 16 ): import math from llmcompressor.core import active_session - from llmcompressor.pytorch.model_load.helpers import initialize_recipe + from llmcompressor.pytorch.model_load.helpers import get_session_model from llmcompressor.pytorch.utils.helpers import tensor_sparsity from llmcompressor.transformers import oneshot from llmcompressor.utils.pytorch import qat_active @@ -42,18 +36,12 @@ def _test_consecutive_runs( oneshot_device=self.device, clear_sparse_session=False, ) - - first_model = AutoModelForCausalLM.from_pretrained( - self.output_first, - device_map="auto", - quantization_config=self.quantization_config, - ) - + first_tiny_model = get_session_model() layer_0_sparse = tensor_sparsity( - first_model.model.layers[0].self_attn.k_proj.weight + first_tiny_model.model.layers[0].self_attn.k_proj.weight ) assert math.isclose(layer_0_sparse.item(), 0.5, rel_tol=tolerance) - assert qat_active(first_model) + assert qat_active(first_tiny_model) session = active_session() session_recipe = session.lifecycle.recipe_container.compiled_recipe @@ -61,10 +49,6 @@ def _test_consecutive_runs( self.assertEqual(len(stages), 1) session.reset() - recipe = infer_recipe_from_model_path(model_path=self.output_first) - if recipe: - initialize_recipe(model=first_model, recipe_path=recipe) - # reload saved model and up sparsity to 0.7 oneshot( model=self.output_first, @@ -73,19 +57,15 @@ def _test_consecutive_runs( recipe=self.second_recipe, output_dir=self.output_second, oneshot_device=self.device, + clear_sparse_session=False, ) - second_model = AutoModelForCausalLM.from_pretrained( - self.output_second, - device_map="auto", - quantization_config=self.quantization_config, - ) - + second_tiny_model = get_session_model() layer_0_sparse = tensor_sparsity( - second_model.model.layers[0].self_attn.k_proj.weight + 
second_tiny_model.model.layers[0].self_attn.k_proj.weight ) assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=tolerance) - assert qat_active(second_model) + assert qat_active(second_tiny_model) session = active_session() session_recipe = session.lifecycle.recipe_container.compiled_recipe @@ -138,14 +118,8 @@ class TestConsecutiveRunsGPU(TestConsecutiveRuns): def setUp(self): from transformers import AutoModelForCausalLM - self.assertFalse( - is_model_ct_quantized_from_path(self.model), - "The provided model is quantized. Please use a dense model.", - ) - self.model = AutoModelForCausalLM.from_pretrained( - self.model, - device_map=self.device, + self.model, device_map=self.device ) self.output = "./oneshot_output" diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index e4974a956..fe699570a 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -22,11 +22,13 @@ def labeled_dataloader(self, dataset_name, model_name): from torch.utils.data import DataLoader from transformers import AutoTokenizer, DefaultDataCollator - from llmcompressor.args import DatasetArguments from llmcompressor.transformers.finetune.data import TextGenerationDataset + from llmcompressor.transformers.finetune.data.data_args import ( + DataTrainingArguments, + ) tokenizer = AutoTokenizer.from_pretrained(model_name) - data_args = DatasetArguments( + data_args = DataTrainingArguments( dataset=dataset_name, max_seq_length=512, pad_to_max_length=False, diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py index fef5ebc37..483a65f2d 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py @@ -1,15 +1,27 @@ +import unittest + import pytest -from accelerate import init_empty_weights -from transformers import AutoModelForCausalLM -from llmcompressor.modifiers.obcq import SparseGPTModifier +from llmcompressor.utils.pytorch.module import get_no_split_params @pytest.mark.integration -def test_infer_targets(): - modifier = SparseGPTModifier(sparsity=0.0) - with init_empty_weights(): +class TestInferTargets(unittest.TestCase): + def setUp(self): + from transformers import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M") + self.modifiable_model = model + self.targets = get_no_split_params(self.modifiable_model) + + def test_infer_targets(self): + from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier + + self.assertEqual(len(self.targets), 1) + self.assertEqual(self.targets[0], "LlamaDecoderLayer") - inferred = modifier._infer_sequential_targets(model) - assert inferred == ["LlamaDecoderLayer"] + modifier = SparseGPTModifier(sparsity=0.5) + modifier.targets = self.targets + modifier.model = self.modifiable_model + compressible_layers = modifier.compressible_layers() + self.assertEqual(len(compressible_layers), 6) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py index 4ddf36a51..ddb6f41ff 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py @@ -1,11 +1,7 @@ import unittest -from unittest.mock import MagicMock import pytest -from 
llmcompressor.core.state import State -from llmcompressor.modifiers.obcq import SparseGPTModifier - @pytest.mark.integration class TestLMHead(unittest.TestCase): @@ -18,7 +14,6 @@ def setUp(self): self.model = AutoModelForCausalLM.from_pretrained( "Xenova/llama2.c-stories15M", device_map=self.device ) - self.kwargs = { "sparsity": 0.5, "block_size": 128, @@ -33,31 +28,21 @@ def setUp(self): ], } - dataset = MagicMock() - dataset.column_names = [] - self.dataloader = MagicMock() - self.dataloader.dataset = dataset - self.dataloader.__iter__.return_value = iter([]) + def test_lm_head_target(self): + from llmcompressor.core.state import State + from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier - def test_no_lm_head_target(self): - modifier = SparseGPTModifier(**self.kwargs) + sparsegpt_modifier_no_head = SparseGPTModifier(**self.kwargs) state = State() - state.update(model=self.model, device=self.device, calib_data=self.dataloader) - modifier.on_initialize(state) - - assert len(self.model.lm_head._forward_hooks) <= 0 - - modifier.finalize(state) + state.update(model=self.model, device=self.device) + sparsegpt_modifier_no_head.initialize_compression(state.model) - def test_lm_head_target(self): self.kwargs["targets"].append("lm_head") - modifier = SparseGPTModifier(**self.kwargs) - - state = State() - state.update(model=self.model, device=self.device, calib_data=self.dataloader) - modifier.on_initialize(state) - - assert len(self.model.lm_head._forward_hooks) == 1 + sparsegpt_modifier_head = SparseGPTModifier(**self.kwargs) + sparsegpt_modifier_head.initialize_compression(state.model) - modifier.finalize(state) + # check we pick up the lm_head layer + layers_no_head = len(sparsegpt_modifier_no_head.compressible_layers_) + layers_head = len(sparsegpt_modifier_head.compressible_layers_) + self.assertEqual(layers_head, layers_no_head + 1) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py index 4948c6da3..060032bc9 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py @@ -3,7 +3,7 @@ from datasets import Dataset from transformers import AutoModelForCausalLM -from llmcompressor.modifiers.obcq import SparseGPTModifier +from llmcompressor.modifiers.pruning import SparseGPTModifier from llmcompressor.transformers.finetune.data.data_helpers import ( format_calibration_data, ) diff --git a/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py b/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py new file mode 100644 index 000000000..d4d0ba280 --- /dev/null +++ b/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py @@ -0,0 +1,23 @@ +import unittest + +import pytest + + +@pytest.mark.integration +class TestSGPTDefaults(unittest.TestCase): + def test_sgpt_defaults(self): + from llmcompressor.core.state import State + from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier + + kwargs = {"sparsity": 0.5} + sparsegpt_modifier_only_sparsity = SparseGPTModifier(**kwargs) + self.assertEqual(sparsegpt_modifier_only_sparsity.block_size, 128) + self.assertEqual(sparsegpt_modifier_only_sparsity.sparsity, 0.5) + + # fail if we don't pass a sparsity or enable quantization + kwargs = {} + sparsegpt_invalid = SparseGPTModifier(**kwargs) + state_test = State() + sparsegpt_invalid.initialized_structure_ = True + with self.assertRaises(ValueError): + sparsegpt_invalid.on_initialize(state=state_test) diff --git 
a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml index 54239b3b4..b9aa59e06 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml @@ -1,8 +1,9 @@ test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml index 7b795ba8e..b4f61ff9f 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml @@ -5,10 +5,11 @@ model: "Xenova/llama2.c-stories15M" dataset: open_platypus recipe: | test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml index 712413a31..6443c09c7 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml @@ -6,10 +6,11 @@ dataset: "gsm8k" dataset_config_name: "main" recipe: | test_stage: - obcq_modifiers: + modifiers: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/oneshot/test_cli.py b/tests/llmcompressor/transformers/oneshot/test_cli.py index 08273b367..5780ca46f 100644 --- a/tests/llmcompressor/transformers/oneshot/test_cli.py +++ b/tests/llmcompressor/transformers/oneshot/test_cli.py @@ -49,7 +49,6 @@ def test_one_shot_cli(self): if len(self.additional_args) > 0: cmd.extend(self.additional_args) res = run_cli_command(cmd) - self.assertEqual(res.returncode, 0) print(res.stdout) diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index eeb6e95ae..df9726647 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -6,18 +6,12 @@ import torch from accelerate import cpu_offload from accelerate.accelerator import get_state_dict_offloaded_model -from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat +from compressed_tensors import QUANTIZATION_CONFIG_NAME from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig -from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - quantize, -) +from compressed_tensors.quantization import QuantizationStatus from compressed_tensors.utils import get_offloaded_device, update_prefix_dict -from torch import nn from transformers import AutoConfig, AutoModelForCausalLM -from transformers.utils.quantization_config import CompressedTensorsConfig 
from llmcompressor.core import reset_session from llmcompressor.pytorch.utils.helpers import tensor_sparsity @@ -26,7 +20,6 @@ SparsityConfigMetadata, ) from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( - get_model_compressor, modify_save_pretrained, patch_tied_tensors_bug, ) @@ -178,8 +171,9 @@ def test_quant_model_reload(format, dtype, tmp_path): device = "cpu" dataset = "open_platypus" concatenate_data = False - num_calibration_samples = 16 + num_calibration_samples = 64 splits = {"calibration": "train[:10%]"} + empty_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype) # create a quantized model oneshot( @@ -197,7 +191,7 @@ def test_quant_model_reload(format, dtype, tmp_path): # Fetch the oneshot model model = get_session_model() og_state_dict = model.state_dict() - save_path_compressed = tmp_path / "compressed" + path = tmp_path / "compressed" for _, module in model.named_modules(): if hasattr(module, "quantization_scheme"): @@ -206,24 +200,32 @@ def test_quant_model_reload(format, dtype, tmp_path): # Save to disk model.save_pretrained( - save_path_compressed, + path, quantization_format=format, save_compressed=True, ) # Verify config on disk - config = AutoConfig.from_pretrained(save_path_compressed) + config = AutoConfig.from_pretrained(path) compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) quant_config = ModelCompressor.parse_quantization_config(compression_config) assert quant_config["format"] == format - decompressed_model = AutoModelForCausalLM.from_pretrained( - save_path_compressed, - torch_dtype=dtype, - quantization_config=CompressedTensorsConfig(run_compressed=False), - ) + # As HFQuantizer doesn't decompress the model, use the compressor to decompress + # the model instead + compressor = ModelCompressor.from_compression_config(compression_config) + compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN + compressor.decompress(model_path=path, model=empty_model) - reconstructed_state_dict = decompressed_model.state_dict() + # eventually use this pathway once HFQuant Decompression works + """ + dense_model = SparseAutoModelForCausalLM.from_pretrained( + "compress_out", torch_dtype="auto", device_map=device + ) + """ + # Verify the abs difference between the decompressed model + # and the original model + reconstructed_state_dict = empty_model.state_dict() assert len(og_state_dict) == len(reconstructed_state_dict) for key in og_state_dict.keys(): dense_tensor = og_state_dict[key].to(device) @@ -362,346 +364,3 @@ def test_model_shared_tensors_gpu( test_model_shared_tensors( offload, torch_dtype, tie_word_embeddings, device_map, tmp_path ) - - -@pytest.mark.parametrize( - "model_stub, recipe, sparse_format, quant_format", - [ - ( - "Xenova/llama2.c-stories15M", - "tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml", - CompressionFormat.sparse_24_bitmask.value, - CompressionFormat.float_quantized.value, - ), - ], -) -def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path): - from llmcompressor.pytorch.model_load.helpers import get_session_model - - device = "cuda" - if not torch.cuda.is_available(): - device = "cpu" - dataset = "open_platypus" - concatenate_data = False - num_calibration_samples = 64 - splits = {"calibration": "train[:10%]"} - empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto") - - oneshot( - model=model_stub, - dataset=dataset, - num_calibration_samples=num_calibration_samples, - 
recipe=recipe, - concatenate_data=concatenate_data, - splits=splits, - oneshot_device=device, - clear_sparse_session=False, - ) - - # Fetch the oneshot model - model = get_session_model() - og_state_dict = model.state_dict() - path = tmp_path / "compressed" - - # Compress and save - model.save_pretrained( - path, - quantization_format=quant_format, - save_compressed=True, - ) - - # Verify config on disk - config = AutoConfig.from_pretrained(path) - compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) - quant_config = ModelCompressor.parse_quantization_config(compression_config) - - # As HFQuantizer doesn't decompress the model, use the compressor to decompress - # the model instead - compressor = ModelCompressor.from_compression_config(compression_config) - - assert ( - compressor.sparsity_compressor is not None - ), "Sparse compressor not initialized" - assert compressor.sparsity_config.format == sparse_format - - assert ( - compressor.quantization_compressor is not None - ), "Quantization compressor not initialized" - assert quant_config["format"] == quant_format - - compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN - compressor.decompress(model_path=path, model=empty_model) - - # Verify the abs difference between the decompressed model - # and the original model - reconstructed_state_dict = empty_model.state_dict() - assert len(og_state_dict) == len(reconstructed_state_dict) - for key in og_state_dict.keys(): - dense_tensor = og_state_dict[key].to(device) - reconstructed_tensor = reconstructed_state_dict[key].to(device) - assert dense_tensor.dtype == reconstructed_tensor.dtype - if key.endswith("weight") and quant_format != "dense": - # we don't expect an exact match for compressed - diff = torch.abs(dense_tensor - reconstructed_tensor) - # max diff value found empirically - assert not torch.any(diff > 0.022), f"Max diff: {torch.max(diff)}" - else: - assert torch.equal(dense_tensor, reconstructed_tensor) - shutil.rmtree(tmp_path) - - -@pytest.mark.parametrize( - "model_stub, recipe, sparse_format", - [ - ( - "Xenova/llama2.c-stories15M", - "tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml", - CompressionFormat.sparse_24_bitmask.value, - ), - ], -) -def test_sparse_24_compressor_is_lossless(model_stub, recipe, sparse_format, tmp_path): - from llmcompressor.pytorch.model_load.helpers import get_session_model - - device = "cuda" - if not torch.cuda.is_available(): - device = "cpu" - dataset = "open_platypus" - concatenate_data = False - num_calibration_samples = 64 - splits = {"calibration": "train[:10%]"} - empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto") - - oneshot( - model=model_stub, - dataset=dataset, - num_calibration_samples=num_calibration_samples, - recipe=recipe, - concatenate_data=concatenate_data, - splits=splits, - oneshot_device=device, - clear_sparse_session=False, - ) - - # Fetch the oneshot model - model = get_session_model() - og_state_dict = model.state_dict() - path = tmp_path / "compressed" - - # Compress and save - model.save_pretrained( - path, - save_compressed=True, - ) - - # Verify config on disk - config = AutoConfig.from_pretrained(path) - compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) - - # As HFQuantizer doesn't decompress the model, use the compressor to decompress - # the model instead - compressor = ModelCompressor.from_compression_config(compression_config) - - assert ( - compressor.sparsity_compressor is not None - ), "Sparse compressor 
not initialized" - assert compressor.sparsity_config.format == sparse_format - - compressor.decompress(model_path=path, model=empty_model) - - # Verify the abs difference between the decompressed model - # and the original model - reconstructed_state_dict = empty_model.state_dict() - assert len(og_state_dict) == len(reconstructed_state_dict) - for key in og_state_dict.keys(): - dense_tensor = og_state_dict[key].to(device) - reconstructed_tensor = reconstructed_state_dict[key].to(device) - assert dense_tensor.dtype == reconstructed_tensor.dtype - if key.endswith("weight"): - assert torch.equal(dense_tensor, reconstructed_tensor) - shutil.rmtree(tmp_path) - - -def test_disable_sparse_compression_flag(tmp_path): - two_four_sparse_model_id = "nm-testing/llama2.c-stories42M-pruned2.4" - two_four_sparse_model = AutoModelForCausalLM.from_pretrained( - two_four_sparse_model_id, torch_dtype="auto" - ) - modify_save_pretrained(two_four_sparse_model) - - save_path = tmp_path / "no_sparse_compression_model" - two_four_sparse_model.save_pretrained(save_path, disable_sparse_compression=True) - - config = AutoConfig.from_pretrained(save_path) - quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) - - assert quantization_config - sparsity_config = quantization_config.get("sparsity_config") - - assert sparsity_config - assert sparsity_config["format"] == "dense" - shutil.rmtree(tmp_path) - - -class DummyLinearModel(nn.Module): - """ - A dummy linear model for testing purposes, simulating a quantized linear layer. - """ - - def __init__(self, weights, weight_scale=None, weight_zero_point=None): - super().__init__() - out_features, in_features = weights.shape - - # Linear layer without bias - self.linear = nn.Linear(in_features, out_features, bias=False) - self.linear.weight = nn.Parameter(weights, requires_grad=False) - - # Attach scale and zero-point if provided - if weight_scale is not None: - self.linear.weight_scale = nn.Parameter( - torch.tensor(weight_scale), requires_grad=False - ) - if weight_zero_point is not None: - self.linear.weight_zero_point = nn.Parameter( - torch.tensor(weight_zero_point), requires_grad=False - ) - - def forward(self, x): - return self.linear(x) - - -def _create_quantization_config( - w_bits=8, - w_type="int", - w_strategy="tensor", - quantize_activations=False, - a_bits=8, - a_type="int", - a_strategy="tensor", -): - """ - Create a quantization configuration for testing. - """ - config_dict = { - "global_compression_ratio": 1.0, - "quant_method": "compressed-tensors", - "config_groups": { - "group_0": { - "targets": ["Linear"], - "weights": { - "num_bits": w_bits, - "strategy": w_strategy, - "symmetric": True, - "type": w_type, - }, - } - }, - } - - if quantize_activations: - config_dict["config_groups"]["group_0"]["input_activations"] = { - "num_bits": a_bits, - "strategy": a_strategy, - "symmetric": True, - "type": a_type, - } - - return QuantizationConfig.model_validate(config_dict) - - -def _quantization_config_from_string(config_str, q_type): - """ - Parse quantization config from string and type. - """ - w_bits = int(config_str[1]) - a_bits = int(config_str[3:]) - quantize_activations = a_bits < 16 - - return _create_quantization_config( - w_bits=w_bits, - w_type=q_type, - w_strategy="channel", - quantize_activations=quantize_activations, - a_bits=a_bits, - a_type=q_type, - a_strategy="channel", - ) - - -def _make_24_sparse(tensor): - """ - Apply 2:4 sparsity pattern to the given tensor. 
- """ - reshaped_tensor = tensor.view(tensor.size(0), -1, 4) - mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool) - mask[..., :2] = True - sparsified_tensor = torch.where( - mask, reshaped_tensor, torch.tensor(0.0, dtype=tensor.dtype) - ) - return sparsified_tensor.view_as(tensor) - - -@pytest.mark.parametrize( - "quant_style, quant_type, is_24, expected_quant_compressor, " - "expected_sparsity_compressor", - [ - ("W8A8", "int", False, "int-quantized", "dense"), - ("W4A16", "int", False, "pack-quantized", "dense"), - ("W8A16", "int", False, "pack-quantized", "dense"), - ("W8A8", "int", True, "int-quantized", "sparse-24-bitmask"), - ("W4A16", "int", True, "marlin-24", "dense"), - ("W8A16", "int", True, "marlin-24", "dense"), - ("W8A8", "float", False, "float-quantized", "dense"), - ("W8A16", "float", False, "naive-quantized", "dense"), - ("W8A8", "float", True, "float-quantized", "sparse-24-bitmask"), - ("W8A16", "float", True, "naive-quantized", "dense"), - ], -) -def test_correct_compressor_inferred( - quant_style, - quant_type, - is_24, - expected_quant_compressor, - expected_sparsity_compressor, -): - """ - Test if the correct compressor is inferred based on - quantization and sparsity configurations. - """ - weights = torch.rand(10, 4) - if is_24: - weights = _make_24_sparse(weights) - else: - weights[0, :] = torch.ones( - 4, - ) # guarantee not 24 sparse - - quantization_config = _quantization_config_from_string(quant_style, quant_type) - quantization_args = quantization_config.config_groups["group_0"].weights - - scale = ( - torch.ones((weights.shape[0], 1)) - if quantization_args.strategy == "channel" - else torch.tensor([1.0]) - ) - zero_point = torch.zeros_like(scale) - - quantized_weights = quantize( - weights, scale=scale, zero_point=zero_point, args=quantization_args - ) - - model = DummyLinearModel(quantized_weights, scale, zero_point) - model.linear.quantization_scheme = quantization_config.config_groups["group_0"] - model.linear.quantization_status = QuantizationStatus.FROZEN - - compressor = get_model_compressor(model) - - assert compressor.quantization_config.format == expected_quant_compressor - - if expected_sparsity_compressor == "dense": - assert ( - compressor.sparsity_config is None - or compressor.sparsity_config.format == expected_sparsity_compressor - ) - else: - assert compressor.sparsity_config.format == expected_sparsity_compressor diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 257506784..a6103a73c 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -135,8 +135,7 @@ def preprocess_tokenize_dataset( :param tokenizer: tokenizer to be used for tokenization :param max_seq_length: maximum sequence length of samples """ - ds_name = ds.info.dataset_name.lower() - if ds_name == "gsm8k": + if ds.info.dataset_name == "gsm8k": def preprocess(example): return example @@ -149,8 +148,7 @@ def tokenize(sample): truncation=True, add_special_tokens=False, ) - - elif ds_name == "ultrachat_200k": + elif ds.info.dataset_name == "ultrachat_200k": def preprocess(example): return { @@ -168,69 +166,6 @@ def tokenize(sample): truncation=True, add_special_tokens=False, ) - - elif ds_name == "llm_compression_calibration": - - def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["text"], - tokenize=False, - ) - } - - def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=max_seq_length, - truncation=True, - add_special_tokens=False, - ) - - elif ds_name == "open-platypus": - # 
use the output rather than the instruction - def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["output"], - tokenize=False, - ) - } - - def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=max_seq_length, - truncation=True, - add_special_tokens=False, - ) - - elif ds_name == "slimorca-deduped-cleaned-corrected": - # find the first element corresponding to a message from a human - def preprocess(example): - conversation_idx = 0 - for idx, conversation in enumerate(example["conversations"]): - if conversation["from"] == "human": - conversation_idx = idx - break - return { - "text": tokenizer.apply_chat_template( - example["conversations"][conversation_idx]["value"], - tokenize=False, - ) - } - - def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=max_seq_length, - truncation=True, - add_special_tokens=False, - ) - else: raise NotImplementedError(f"Cannot preprocess dataset {ds.info.dataset_name}") From 5b384bbf9039b08a79fb9ee09907db9a0b409dbf Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 13 Feb 2025 17:08:23 -0500 Subject: [PATCH 2/2] merge with main Signed-off-by: Kyle Sayers --- examples/automodelforcausallm/README.md | 13 - .../run_automodelforcausallm.py | 11 - .../example_alternating_recipe.yaml | 12 +- .../2of4_w4a16_group-128_recipe.yaml | 3 +- .../2of4_w4a16_recipe.yaml | 3 +- examples/quantizing_moe/deepseek_moe_w4a16.py | 4 + .../quantizing_moe/deepseek_moe_w8a8_fp8.py | 4 + .../quantizing_moe/deepseek_moe_w8a8_int8.py | 4 + .../sparse_2of4_quantization_fp8/README.md | 13 +- .../llama3_8b_2of4.py | 4 +- examples/trl_mixin/ex_trl_constant.py | 2 +- examples/trl_mixin/ex_trl_distillation.py | 9 +- examples/trl_mixin/sft_trainer.py | 2 +- .../transformers/finetune/README.md | 7 +- .../2of4_w4a16_group-128_recipe.yaml | 3 +- .../recipes/WNA16_2of4/2of4_w4a16_recipe.yaml | 3 +- tests/e2e/vLLM/test_vllm.py | 86 +++- tests/examples/utils.py | 7 +- .../modifiers/calibration/test_cache.py | 2 +- tests/llmcompressor/modifiers/conf.py | 6 +- .../gptq/utils/test_gptq_wrapper.py | 41 -- .../modifiers/smoothquant/test_utils.py | 5 +- .../modifiers/utils/test_hooks.py | 93 +++++ tests/llmcompressor/observers/test_min_max.py | 8 +- tests/llmcompressor/observers/test_mse.py | 4 +- tests/llmcompressor/pytorch/helpers.py | 2 + .../pruning/sparsegpt/test_pytorch.py | 50 +-- .../run_compressed_configs/fp8_dynamic.yaml | 4 +- .../run_compressed_configs/w4a16.yaml | 4 +- .../run_compressed_configs/w8a8.yaml | 4 +- .../compression/test_infer_quant_format.py | 6 +- .../compression/test_quantization.py | 5 +- .../compression/test_run_compressed.py | 140 +++++-- .../transformers/finetune/data/conftest.py | 2 +- .../finetune/data/test_dataset_helpers.py | 6 +- .../finetune/data/test_dataset_loading.py | 34 +- .../finetune/data/test_registry.py | 8 +- .../finetune/test_alternate_recipe.yaml | 4 +- .../finetune/test_oneshot_then_finetune.py | 100 ++++- .../finetune/test_session_mixin.py | 4 + .../transformers/gptq/test_oneshot.py | 2 +- .../obcq/recipes/additional_sparsity.yaml | 1 - .../additional_sparsity_with_quant.yaml | 1 - .../obcq/recipes/quant_and_sparse.yaml | 1 - .../transformers/obcq/recipes/sparse.yaml | 1 - .../recipes/sparse_with_mask_structure.yaml | 1 - .../transformers/obcq/recipes/test_tiny2.yaml | 1 - .../obcq/test_consecutive_runs.py | 44 +- .../transformers/obcq/test_obcq_completion.py | 6 +- .../obcq/test_obcq_infer_targets.py | 28 +- 
.../transformers/obcq/test_obcq_lm_head.py | 39 +- .../transformers/obcq/test_sgpt_defaults.py | 23 -- .../oneshot_configs/recipes/recipe.yaml | 1 - .../oneshot_configs/tiny_stories_conf1.yaml | 1 - .../oneshot_configs/tiny_stories_conf4.yaml | 1 - .../transformers/oneshot/test_cli.py | 1 + .../test_compress_tensor_utils.py | 383 +++++++++++++++++- tests/testing_utils.py | 69 +++- 58 files changed, 986 insertions(+), 340 deletions(-) delete mode 100644 examples/automodelforcausallm/README.md delete mode 100644 examples/automodelforcausallm/run_automodelforcausallm.py delete mode 100644 tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py delete mode 100644 tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py diff --git a/examples/automodelforcausallm/README.md b/examples/automodelforcausallm/README.md deleted file mode 100644 index e40cb5c2a..000000000 --- a/examples/automodelforcausallm/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# Loading models using `AutoModelForCausalLM` - -Models quantized through `llm-compressor` can be loaded directly through -`AutoModelForCausalLM`. Note: this requires `transformers>=v4.45.0` and -`compressed-tensors>v0.6.0`. - -```python -from transformers import AutoModelForCausalLM - -MODEL_ID = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer" - -model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto") -``` diff --git a/examples/automodelforcausallm/run_automodelforcausallm.py b/examples/automodelforcausallm/run_automodelforcausallm.py deleted file mode 100644 index 791b4d3d5..000000000 --- a/examples/automodelforcausallm/run_automodelforcausallm.py +++ /dev/null @@ -1,11 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer - -MODEL_ID = "nm-testing/tinyllama-w8a8-compressed-hf-quantizer" - -# Use the AutoModelForCausalLM to run the model -model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto") -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids -output = model.generate(input_ids, max_new_tokens=100) -print(tokenizer.decode(output[0])) diff --git a/examples/finetuning/example_alternating_recipe.yaml b/examples/finetuning/example_alternating_recipe.yaml index a3be682a4..5f4b3018e 100644 --- a/examples/finetuning/example_alternating_recipe.yaml +++ b/examples/finetuning/example_alternating_recipe.yaml @@ -4,12 +4,10 @@ initial_sparsity_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False percdamp: 0.01 mask_structure: "0:0" - targets: [ - "re:model.layers.\\d+$" - ] + targets: ["Linear"] + ignore: ["re:.*lm_head"] initial_training_stage: run_type: train pruning_modifiers: @@ -22,12 +20,10 @@ next_sparsity_stage: SparseGPTModifier: sparsity: 0.7 block_size: 128 - sequential_update: False percdamp: 0.01 mask_structure: "0:0" - targets: [ - "re:model.layers.\\d+$" - ] + targets: ["Linear"] + ignore: ["re:.*lm_head"] next_training_stage: run_type: train pruning_modifiers: diff --git a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml index 166e41a66..e59cf8a96 100644 --- a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml +++ b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_group-128_recipe.yaml @@ -4,7 +4,8 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - sequential_update: false + targets: ["Linear"] + ignore: ["re:.*lm_head"] 
finetuning_stage: run_type: train finetuning_modifiers: diff --git a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml index 2ad00b457..4ff5ff26e 100644 --- a/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml +++ b/examples/quantization_2of4_sparse_w4a16/2of4_w4a16_recipe.yaml @@ -4,7 +4,8 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - sequential_update: false + targets: ["Linear"] + ignore: ["re:.*lm_head"] finetuning_stage: run_type: train finetuning_modifiers: diff --git a/examples/quantizing_moe/deepseek_moe_w4a16.py b/examples/quantizing_moe/deepseek_moe_w4a16.py index 3d7d33099..55a7021b4 100644 --- a/examples/quantizing_moe/deepseek_moe_w4a16.py +++ b/examples/quantizing_moe/deepseek_moe_w4a16.py @@ -5,6 +5,10 @@ from llmcompressor.transformers import oneshot from llmcompressor.transformers.compression.helpers import calculate_offload_device_map +# NOTE: transformers 4.48.0 has an import error with DeepSeek. +# Please consider either downgrading your transformers version to a +# previous version or upgrading to a version where this bug is fixed + # select a Mixture of Experts model for quantization MODEL_ID = "deepseek-ai/DeepSeek-V2.5" diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py index 666da8f9a..cda202eb9 100644 --- a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py +++ b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py @@ -4,6 +4,10 @@ from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.transformers import oneshot +# NOTE: transformers 4.48.0 has an import error with DeepSeek. +# Please consider either downgrading your transformers version to a +# previous version or upgrading to a version where this bug is fixed + # select a Mixture of Experts model for quantization MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py index ba215aa9e..289f4234f 100644 --- a/examples/quantizing_moe/deepseek_moe_w8a8_int8.py +++ b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py @@ -6,6 +6,10 @@ from llmcompressor.transformers import oneshot from llmcompressor.transformers.compression.helpers import calculate_offload_device_map +# NOTE: transformers 4.48.0 has an import error with DeepSeek. +# Please consider either downgrading your transformers version to a +# previous version or upgrading to a version where this bug is fixed + # select a Mixture of Experts model for quantization MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" diff --git a/examples/sparse_2of4_quantization_fp8/README.md b/examples/sparse_2of4_quantization_fp8/README.md index 97b8e590e..99fc3c545 100644 --- a/examples/sparse_2of4_quantization_fp8/README.md +++ b/examples/sparse_2of4_quantization_fp8/README.md @@ -93,7 +93,7 @@ oneshot( ) ``` -3. 
**Save the Compressed Model** +### Saving the Compressed Model The compressed model and tokenizer are saved to the output directory: @@ -106,6 +106,17 @@ Output Directories: - Without FP8: `Meta-Llama-3-8B-Instruct-2of4-sparse` - With FP8: `Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token` +#### Saving Without Sparse Compression + +To save the model on disk without sparse compression: + +```python +model.save_pretrained(save_dir, save_compressed=True, disable_sparse_compression=True) +tokenizer.save_pretrained(save_dir) +``` + +> **Note:** Saving a model with both the `save_compressed` and `disable_sparse_compression` options will compress the model using the quantization compressor; however, instead of using the more disk-efficient sparsity compressor(s), the dense sparsity compressor will be used. The `dense` sparsity compressor saves model params as is, and does not leverage sparsity for disk-efficient storage. These options only affect how the model(s) are saved on disk and do not impact the actual pruning or quantization processes. + ### Validation After compression, the script validates the model by generating a sample output: diff --git a/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py b/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py index e8133225f..39620a814 100644 --- a/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py +++ b/examples/sparse_2of4_quantization_fp8/llama3_8b_2of4.py @@ -115,5 +115,7 @@ def get_recipe(fp8_enabled): print("==========================================\n") # Save compressed model and tokenizer -model.save_pretrained(save_dir, save_compressed=args.fp8) +model.save_pretrained( + save_dir, save_compressed=args.fp8, disable_sparse_compression=True +) tokenizer.save_pretrained(save_dir) diff --git a/examples/trl_mixin/ex_trl_constant.py b/examples/trl_mixin/ex_trl_constant.py index b2f597ec8..517d74d71 100644 --- a/examples/trl_mixin/ex_trl_constant.py +++ b/examples/trl_mixin/ex_trl_constant.py @@ -3,7 +3,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from trl import DataCollatorForCompletionOnlyLM -from llmcompressor.transformers import TrainingArguments +from llmcompressor.args import TrainingArguments model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data" diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py index ff3ddf000..96cc78846 100644 --- a/examples/trl_mixin/ex_trl_distillation.py +++ b/examples/trl_mixin/ex_trl_distillation.py @@ -1,11 +1,8 @@ from sft_trainer import SFTTrainer from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator -from llmcompressor.transformers import ( - DataTrainingArguments, - TextGenerationDataset, - TrainingArguments, -) +from llmcompressor.args import DatasetArguments, TrainingArguments +from llmcompressor.transformers import TextGenerationDataset model_path = "neuralmagic/Llama-2-7b-pruned50-retrained" teacher_path = "neuralmagic/Llama-2-7b-gsm8k" @@ -21,7 +18,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_path) # Load gsm8k using SparseML dataset tools -data_args = DataTrainingArguments( +data_args = DatasetArguments( dataset="gsm8k", dataset_config_name="main", max_seq_length=512 ) dataset_manager = TextGenerationDataset.load_from_registry( diff --git a/examples/trl_mixin/sft_trainer.py b/examples/trl_mixin/sft_trainer.py index c311cf8dc..2577c0cc7 100644 --- a/examples/trl_mixin/sft_trainer.py +++ b/examples/trl_mixin/sft_trainer.py @@ 
-1,7 +1,7 @@ from trl import SFTConfig as TRLSFTConfig from trl import SFTTrainer as TRLSFTTrainer -from llmcompressor.transformers import TrainingArguments +from llmcompressor.args import TrainingArguments from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn __all__ = ["SFTTrainer"] diff --git a/src/llmcompressor/transformers/finetune/README.md b/src/llmcompressor/transformers/finetune/README.md index 387da51f1..8669d810e 100644 --- a/src/llmcompressor/transformers/finetune/README.md +++ b/src/llmcompressor/transformers/finetune/README.md @@ -74,9 +74,10 @@ train( Finetuning arguments are split up into 3 groups: -* ModelArguments: `src/llmcompressor/transformers/finetune/model_args.py` -* TrainingArguments: `src/llmcompressor/transformers/finetune/training_args.py` -* DataTrainingArguments: `src/llmcompressor/transformers/finetune/data/data_training_args.py` +* ModelArguments: `src/llmcompressor/transformers/utils/arg_parser/model_arguments.py` +* TrainingArguments: `src/llmcompressor/transformers/utils/arg_parser/training_arguments.py` +* DatasetArguments: `src/llmcompressor/transformers/utils/arg_parser/dataset_arguments.py` +* RecipeArguments: `src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py` ## Running One-Shot with FSDP diff --git a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml index 7523b09a7..92cc85ae7 100644 --- a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml +++ b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_group-128_recipe.yaml @@ -4,7 +4,8 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - sequential_update: false + targets: ["Linear"] + ignore: ["re:.*lm_head"] quantization_stage: run_type: oneshot quantization_modifiers: diff --git a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml index b8a4402d8..dc7e18b6e 100644 --- a/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml +++ b/tests/e2e/vLLM/recipes/WNA16_2of4/2of4_w4a16_recipe.yaml @@ -4,7 +4,8 @@ sparsity_stage: SparseGPTModifier: sparsity: 0.5 mask_structure: "2:4" - sequential_update: false + targets: ["Linear"] + ignore: ["re:.*lm_head"] quantization_stage: run_type: oneshot quantization_modifiers: diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py index b31bfb007..6c42f82df 100644 --- a/tests/e2e/vLLM/test_vllm.py +++ b/tests/e2e/vLLM/test_vllm.py @@ -1,12 +1,13 @@ import os +import re import shutil from pathlib import Path -from typing import Callable import pytest import yaml from huggingface_hub import HfApi from loguru import logger +from parameterized import parameterized_class from llmcompressor.core import active_session from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing @@ -20,19 +21,24 @@ vllm_installed = False logger.warning("vllm is not installed. 
This test will be skipped") + HF_MODEL_HUB_NAME = "nm-testing" -TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", "") +TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", "") +SKIP_HF_UPLOAD = os.environ.get("SKIP_HF_UPLOAD", "") -@pytest.fixture -def record_config_file(record_testsuite_property: Callable[[str, object], None]): - test_data_file_name = TEST_DATA_FILE.split("configs/")[-1] - record_testsuite_property("TEST_DATA_FILE_NAME", test_data_file_name) +EXPECTED_SAVED_FILES = [ + "config.json", + r"^model(?:-\d{5}-of-\d{5})?\.safetensors$", + "recipe.yaml", + "tokenizer.json", +] # Will run each test case in its own process through run_tests.sh # emulating vLLM CI testing @requires_gpu_count(1) +@parameterized_class("test_data_file", [(TEST_DATA_FILE,)]) @pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test") class TestvLLM: """ @@ -52,7 +58,9 @@ class TestvLLM: """ # noqa: E501 def set_up(self): - eval_config = yaml.safe_load(Path(TEST_DATA_FILE).read_text(encoding="utf-8")) + eval_config = yaml.safe_load( + Path(self.test_data_file).read_text(encoding="utf-8") + ) if os.environ.get("CADENCE", "commit") != eval_config.get("cadence"): pytest.skip("Skipping test; cadence mismatch") @@ -65,6 +73,7 @@ def set_up(self): self.recipe = eval_config.get("recipe") self.quant_type = eval_config.get("quant_type") self.save_dir = eval_config.get("save_dir") + self.save_compressed = eval_config.get("save_compressed", True) logger.info("========== RUNNING ==============") logger.info(self.scheme) @@ -79,7 +88,6 @@ def set_up(self): ] self.api = HfApi() - @pytest.mark.usefixtures("record_config_file") def test_vllm(self): # Run vLLM with saved model import torch @@ -100,11 +108,19 @@ def test_vllm(self): quant_type=self.quant_type, ) + # check that session contains recipe + self._check_session_contains_recipe() + logger.info("================= SAVING TO DISK ======================") - oneshot_model.save_pretrained(self.save_dir) + oneshot_model.save_pretrained( + self.save_dir, save_compressed=self.save_compressed + ) tokenizer.save_pretrained(self.save_dir) recipe_path = os.path.join(self.save_dir, "recipe.yaml") + # check that expected files exist + self._check_save_dir_has_expected_files() + # Use the session to fetch the recipe; # Reset session for next test case session = active_session() @@ -113,12 +129,22 @@ def test_vllm(self): fp.write(recipe_yaml_str) session.reset() - logger.info("================= UPLOADING TO HUB ======================") + if SKIP_HF_UPLOAD.lower() != "yes": + logger.info("================= UPLOADING TO HUB ======================") - self.api.upload_folder( - repo_id=f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e", - folder_path=self.save_dir, - ) + stub = f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e" + + self.api.create_repo( + repo_id=stub, + exist_ok=True, + repo_type="model", + private=False, + ) + + self.api.upload_folder( + repo_id=stub, + folder_path=self.save_dir, + ) logger.info("================= RUNNING vLLM =========================") @@ -146,3 +172,35 @@ def test_vllm(self): def tear_down(self): if self.save_dir is not None: shutil.rmtree(self.save_dir) + + def _check_session_contains_recipe(self) -> None: + session = active_session() + recipe_yaml_str = session.get_serialized_recipe() + assert recipe_yaml_str is not None + + def _check_save_dir_has_expected_files(self): + files = os.listdir(self.save_dir) + logger.debug("Saved files: ", files) + + matched_patterns = set() + + for expected in EXPECTED_SAVED_FILES: + # Find all files 
matching the expected pattern + matches = [ + file + for file in files + if ( + re.fullmatch(expected, file) + if expected.startswith("^") + else file == expected + ) + ] + if len(matches) > 0: + matched_patterns.add(expected) + + assert len(matched_patterns) == len(EXPECTED_SAVED_FILES), ( + "expected: ", + EXPECTED_SAVED_FILES, + "\n saved: ", + list(matched_patterns), + ) diff --git a/tests/examples/utils.py b/tests/examples/utils.py index 38ff98d64..29eba8dd4 100644 --- a/tests/examples/utils.py +++ b/tests/examples/utils.py @@ -68,7 +68,10 @@ def copy_and_run_command( def copy_and_run_script( - tmp_path: Path, example_dir: str, script_filename: str + tmp_path: Path, + example_dir: str, + script_filename: str, + flags: Optional[list[str]] = None, ) -> Tuple[List[str], CompletedProcess[str]]: """ Copies the contents of example_dir (relative to the current working directory) to @@ -81,6 +84,8 @@ def copy_and_run_script( :return: subprocess.CompletedProcess object """ command = [sys.executable, script_filename] + if flags: + command.extend(flags) return command, copy_and_run_command(tmp_path, example_dir, command) diff --git a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py index 6ea024037..898c342f5 100644 --- a/tests/llmcompressor/modifiers/calibration/test_cache.py +++ b/tests/llmcompressor/modifiers/calibration/test_cache.py @@ -28,7 +28,7 @@ def test_is_quantized_cache_singleton(): args = QuantizationArgs() cache = QuantizedKVParameterCache(args) - observer = args.get_observer() + observer = args.observer observer = Observer.load_from_registry(observer, quantization_args=args) tensor = torch.tensor([1, 2, 3]) diff --git a/tests/llmcompressor/modifiers/conf.py b/tests/llmcompressor/modifiers/conf.py index 3eab9b85c..0a910788c 100644 --- a/tests/llmcompressor/modifiers/conf.py +++ b/tests/llmcompressor/modifiers/conf.py @@ -1,3 +1,7 @@ +from unittest.mock import MagicMock + +from torch.utils.data import DataLoader + from llmcompressor.core import State from llmcompressor.core.events import EventType from llmcompressor.core.lifecycle import CallbacksEventLifecycle @@ -24,7 +28,7 @@ def __init__( optimizer=optimizer, start=start, steps_per_epoch=1, - calib_data=[], + calib_data=DataLoader(MagicMock(__len__=lambda _: 0, column_names=[])), ) self.event_lifecycle = CallbacksEventLifecycle( diff --git a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py b/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py deleted file mode 100644 index 203d1fe03..000000000 --- a/tests/llmcompressor/modifiers/quantization/gptq/utils/test_gptq_wrapper.py +++ /dev/null @@ -1,41 +0,0 @@ -from collections import OrderedDict - -import torch -from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config -from compressed_tensors.quantization.quant_config import QuantizationConfig -from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme -from loguru import logger - -from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper - - -def test_ignore(): - model = torch.nn.Sequential( - OrderedDict( - [ - ("first_layer", torch.nn.Linear(2, 3)), - ("second_layer", torch.nn.Linear(3, 5)), - ] - ) - ) - - config = QuantizationConfig( - config_groups={"group_0": preset_name_to_scheme("W8A8", targets=["Linear"])}, - ignore=["first_layer"], - ) - apply_quantization_config(model, config) - - messages = [] - logger.add(lambda m: messages.append(m)) 
- - with torch.no_grad(): - first_compressor = GPTQWrapper("first_layer", model.first_layer) - first_compressor.add_batch(torch.ones(2), None) - first_compressor.compress() - - second_compressor = GPTQWrapper("second_layer", model.second_layer) - second_compressor.add_batch(torch.ones(3), None) - second_compressor.compress() - - assert sum("Skipping unquantized layer first_layer" in m for m in messages) == 1 - assert sum("Skipping unquantized layer second_layer" in m for m in messages) == 0 diff --git a/tests/llmcompressor/modifiers/smoothquant/test_utils.py b/tests/llmcompressor/modifiers/smoothquant/test_utils.py index 95be6bd30..457b64cdb 100644 --- a/tests/llmcompressor/modifiers/smoothquant/test_utils.py +++ b/tests/llmcompressor/modifiers/smoothquant/test_utils.py @@ -12,7 +12,10 @@ @pytest.mark.unit def test_handle_mapping_resolution_errors(): - README_LOCATION = "llmcompressor/modifiers/smoothquant/README.md" + README_LOCATION = ( + "https://github.com/vllm-project/llm-compressor/tree/main/" + "src/llmcompressor/modifiers/smoothquant" + ) @handle_mapping_resolution_errors def func_that_raises_exception(): diff --git a/tests/llmcompressor/modifiers/utils/test_hooks.py b/tests/llmcompressor/modifiers/utils/test_hooks.py index 5c4fc5891..2a402e980 100644 --- a/tests/llmcompressor/modifiers/utils/test_hooks.py +++ b/tests/llmcompressor/modifiers/utils/test_hooks.py @@ -64,6 +64,27 @@ def test_remove_hooks(): assert mod_a.hook_called and not mod_b.hook_called +def test_remove_hooks_parameterized(): + model = DummyModel() + + mod_a = ModA() + mod_a_pre_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward_pre") + mod_a_post_hook = mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + mod_b_pre_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + mod_b_post_hook = mod_b.register_hook(model.linear2, mod_b.hook, "forward") + + mod_a.remove_hooks(set([mod_a_post_hook])) + mod_b.remove_hooks(set([mod_b_pre_hook])) + + assert len(mod_a._hooks) == 1 and next(iter(mod_a._hooks)) == mod_a_pre_hook + assert len(mod_b._hooks) == 1 and next(iter(mod_b._hooks)) == mod_b_post_hook + + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + def test_disable_hooks(): model = DummyModel() @@ -81,3 +102,75 @@ def test_disable_hooks(): mod_b.hook_called = False model(model.dummy_inputs) assert mod_a.hook_called and mod_b.hook_called + + +def test_disable_hooks_keep(): + model = DummyModel() + + mod_a = ModA() + handle_a = mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + handle_b = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + + with HooksMixin.disable_hooks(keep=set([handle_b])): + model(model.dummy_inputs) + assert not mod_a.hook_called and mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + with HooksMixin.disable_hooks(keep=set([handle_a])): + model(model.dummy_inputs) + assert mod_a.hook_called and not mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + +def test_disable_hooks_composable(): + model = DummyModel() + + mod_a = ModA() + handle_a = mod_a.register_hook(model.linear1, mod_a.hook, "forward") + + mod_b = ModB() + handle_b = mod_b.register_hook(model.linear2, mod_b.hook, "forward_pre") + + # composing two keeps + with ( + HooksMixin.disable_hooks(keep=set([handle_b])), + HooksMixin.disable_hooks(keep=set([handle_a])), + ): + 
model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + with HooksMixin.disable_hooks(): + model(model.dummy_inputs) + assert not mod_a.hook_called and not mod_b.hook_called + + # composing a keep and an empty keep + mod_a.hook_called = False + mod_b.hook_called = False + with HooksMixin.disable_hooks(keep=set([handle_a])), HooksMixin.disable_hooks(): + model(model.dummy_inputs) + assert mod_a.hook_called and not mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + model(model.dummy_inputs) + assert mod_a.hook_called and mod_b.hook_called + + mod_a.hook_called = False + mod_b.hook_called = False + with HooksMixin.disable_hooks(): + model(model.dummy_inputs) + assert not mod_a.hook_called and not mod_b.hook_called diff --git a/tests/llmcompressor/observers/test_min_max.py b/tests/llmcompressor/observers/test_min_max.py index f23a06dba..b592579f6 100644 --- a/tests/llmcompressor/observers/test_min_max.py +++ b/tests/llmcompressor/observers/test_min_max.py @@ -37,7 +37,7 @@ def test_min_max_observer(symmetric, expected_scale, expected_zero_point): num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric) - observer = weights.get_observer() + observer = weights.observer observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) @@ -52,7 +52,7 @@ def test_min_max_observer_symmetric_scale_range(): num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - observer = weights.get_observer() + observer = weights.observer observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) @@ -80,7 +80,7 @@ def test_min_max_observer_value_update(): tensor = inp num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - observer = weights.get_observer() + observer = weights.observer observer = Observer.load_from_registry(observer, quantization_args=weights) curr_max = 1 curr_min = 1 @@ -107,7 +107,7 @@ def test_g_idx(): weights = QuantizationArgs(num_bits=8, group_size=group_size) g_idx = make_dummy_g_idx(tensor.shape[1], group_size) - observer = weights.get_observer() + observer = weights.observer observer = Observer.load_from_registry(observer, quantization_args=weights) scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx) diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index ec2ecf1b5..4447813b3 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -32,7 +32,7 @@ def test_mse_observer(symmetric, expected_scale, expected_zero_point): num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") - observer = weights.get_observer() + observer = weights.observer observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) @@ -48,7 +48,7 @@ def test_mse_observer_symmetric_scale_range(): num_bits = 8 weights = QuantizationArgs(num_bits=num_bits, symmetric=True) - observer = weights.get_observer() + observer = weights.observer observer = Observer.load_from_registry(observer, quantization_args=weights) scale, zero_point = observer(tensor) diff --git a/tests/llmcompressor/pytorch/helpers.py 
b/tests/llmcompressor/pytorch/helpers.py index d7b52a836..341c18f11 100644 --- a/tests/llmcompressor/pytorch/helpers.py +++ b/tests/llmcompressor/pytorch/helpers.py @@ -1,5 +1,6 @@ from collections import OrderedDict, namedtuple from typing import List +from unittest.mock import Mock import pytest import torch @@ -96,6 +97,7 @@ def __init__(self): ] ) ) + self.config = Mock(use_cache=False) def forward(self, inp: Tensor): return self.seq(inp) diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 0752f2a30..096a93111 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -4,7 +4,7 @@ from compressed_tensors.quantization import QuantizationScheme from parameterized import parameterized -from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier +from llmcompressor.modifiers.pruning import SparseGPTModifier from llmcompressor.modifiers.quantization.gptq import GPTQModifier from llmcompressor.modifiers.quantization.quantization import QuantizationModifier from llmcompressor.utils.pytorch.module import qat_active @@ -29,12 +29,11 @@ def setUp(self): ) def test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets): setup_modifier_factory() - kwargs = dict( + modifier = SparseGPTModifier( sparsity=sparsity, block_size=128, targets=targets, ) - modifier = SparseGPTModifier(**kwargs) testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) # confirm invalid layerwise recipes fail at initialization @@ -50,16 +49,16 @@ def setUp(self): def test_successful_layerwise_recipe(self): sparsities = [0.5, 0.2] targets = ["seq.fc1", "seq.fc2"] - kwargs = dict(sparsity=sparsities, block_size=128, targets=targets) - modifier = SparseGPTModifier(**kwargs) - modifier.compressible_layers_ = {"seq.fc1": None, "seq.fc2": None} - modifier.model = LinearNet() - found_compressible_layers = modifier.compressible_layers() - modifier.compressible_layers_ = found_compressible_layers - modifier._validate_layerwise_sparsity() + modifier = SparseGPTModifier( + sparsity=sparsities, block_size=128, targets=targets + ) + testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) + modifier.initialize(testing_harness.get_state()) - # ensure layers names successfully match up with model - self.assertEqual(len(found_compressible_layers), len(targets)) + model = testing_harness.state.model + num_hooks = len(modifier._hooks) + num_found = sum(len(module._forward_hooks) > 0 for module in model.modules()) + self.assertEqual(num_hooks, num_found) @pytest.mark.unit @@ -68,18 +67,16 @@ def setUp(self): setup_modifier_factory() def test_create_default_quant_modifier(self): - kwargs = dict(block_size=128) - - modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + modifier = GPTQModifier(block_size=128) + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - assert isinstance(modifier.quantization_modifier_, QuantizationModifier) - modifier.quantization_modifier_.create_init_config() + assert isinstance(modifier._quantization_modifier, QuantizationModifier) + modifier._quantization_modifier.create_init_config() default_config_group_name = "group_0" - should_be_default_quant_scheme = 
modifier.quantization_modifier_.config_groups[ + should_be_default_quant_scheme = modifier._quantization_modifier.config_groups[ default_config_group_name ] assert should_be_default_quant_scheme.input_activations is None @@ -106,9 +103,8 @@ def test_set_quant_if_modifer_already_exists(self): modifier.initialize(testing_harness.get_state()) assert qat_active(testing_harness.get_state().model) - kwargs = dict(block_size=128) - modifier = GPTQModifier(**kwargs) - assert not modifier.quantization_modifier_ + modifier = GPTQModifier(block_size=128) + assert not modifier._quantization_modifier modifier.on_initialize_structure(testing_harness.get_state()) # since quantization modifier is already applied, quantization must be set in @@ -142,17 +138,15 @@ def setUp(self): self.quant_config = {"QuantizationModifier": self.quant_kwargs} def test_set_quant_in_gptq(self): - kwargs = dict(block_size=128, quantize=self.quant_config) - - modifier = GPTQModifier(**kwargs) - assert modifier.quantization_modifier_ is None + modifier = GPTQModifier(block_size=128, quantize=self.quant_config) + assert modifier._quantization_modifier is None testing_harness = LifecyleTestingHarness(model=LinearNet()) modifier.on_initialize_structure(testing_harness.get_state()) assert modifier.quantize - self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier) + self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier) - dict_scheme = dict(modifier.quantization_modifier_.config_groups) + dict_scheme = dict(modifier._quantization_modifier.config_groups) self._check_config( dict(dict_scheme["config_group_0"].weights), self.quant_kwargs["config_groups"]["config_group_0"]["weights"], diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml index d516616bf..926c31ec3 100644 --- a/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/fp8_dynamic.yaml @@ -1,4 +1,4 @@ cadence: "commit" test_type: "regression" -model_stub: "nm-testing/tinyllama-fp8-dynamic-compressed" -empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file +compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-Dynamic-compressed +uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-Dynamic-uncompressed \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml index 7e9bc3f2f..51d9ec25b 100644 --- a/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w4a16.yaml @@ -1,4 +1,4 @@ cadence: "commit" test_type: "regression" -model_stub: "nm-testing/tinyllama-w4a16-compressed" -empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file +compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-compressed +uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-uncompressed \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml index 086a67ed6..3c1646b16 100644 --- 
a/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml +++ b/tests/llmcompressor/transformers/compression/run_compressed_configs/w8a8.yaml @@ -1,4 +1,4 @@ cadence: "commit" test_type: "regression" -model_stub: "nm-testing/tinyllama-w8a8-compressed" -empty_model: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" \ No newline at end of file +compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-compressed +uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-uncompressed \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py index 7db2f0687..1eb3bf202 100644 --- a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py +++ b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py @@ -1,5 +1,4 @@ import pytest -from compressed_tensors.config import SparsityCompressionConfig from compressed_tensors.quantization import preset_name_to_scheme from llmcompressor.transformers.compression.quantization_format import ( @@ -20,9 +19,6 @@ ], ) def test_infer_quant_format(preset, sparsity_structure, expected_format): - sparsity_config = SparsityCompressionConfig( - format="dense", sparsity_structure=sparsity_structure - ) quant_scheme = preset_name_to_scheme(preset, targets=["Linear"]) dummy_model = LinearNet() @@ -30,6 +26,6 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format): module.quantization_scheme = quant_scheme inferred_format = infer_quantization_format( - dummy_model, save_compressed=True, sparsity_config=sparsity_config + dummy_model, save_compressed=True, sparsity_structure=sparsity_structure ) assert inferred_format.value == expected_format diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 13eab66c9..0d34d1ca0 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -10,10 +10,10 @@ from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator +from llmcompressor.args import DatasetArguments from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.transformers import oneshot from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/compression/configs" @@ -74,7 +74,6 @@ def _run_oneshot(model, recipe, dataset, output_dir): ) from llmcompressor.pytorch.model_load.helpers import get_session_model - # note: get_session_model() is None outside of function scope return get_session_model() def _get_quant_info(self, model): @@ -147,7 +146,7 @@ def _get_dataloader(self, data_args, tokenizer): @torch.no_grad() def test_perplexity(self): tokenizer = AutoTokenizer.from_pretrained(self.model_stub) - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="ultrachat-200k", max_seq_length=self.max_seq_length, ) diff --git a/tests/llmcompressor/transformers/compression/test_run_compressed.py b/tests/llmcompressor/transformers/compression/test_run_compressed.py index 0c2a0ab0e..616dd0dfe 100644 --- 
a/tests/llmcompressor/transformers/compression/test_run_compressed.py +++ b/tests/llmcompressor/transformers/compression/test_run_compressed.py @@ -1,79 +1,133 @@ +import copy import shutil import tempfile import unittest -import torch from compressed_tensors import QUANTIZATION_CONFIG_NAME from compressed_tensors.compressors import ModelCompressor from compressed_tensors.quantization import QuantizationStatus from parameterized import parameterized_class from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers.utils.quantization_config import CompressedTensorsConfig from tests.testing_utils import parse_params, requires_gpu -CONFIG_DIR = "tests/llmcompressor/transformers/compression/run_compressed_configs" +CONFIG_DIR = "tests/llmcompressor/transformers/compression/decompression_configs" @requires_gpu @parameterized_class(parse_params(CONFIG_DIR)) -class TestQuantizationMatches(unittest.TestCase): - model_stub = None - empty_model = None +class TestDecompression(unittest.TestCase): + """ + Check that HFQuantizer decompression is working as expected. + Manually decompress a compressed model and compare the generations + + Decompression: + Given a skeleton model and path to the optimized model, + write the optimized model's safetensors to the skeleton model and decompress + Ex. write weight_scale to the skeleton model and then convert from fp4 to fp16 + + """ + + compressed_model_stub = None + skeleton_model_stub = None + + SAMPLE_INPUTS = [ + "I love 4-bit quantization because", + "What is the capital of France?", + "def fibonacci(n):", + ] @classmethod - def setUpClass(cls): - cls.test_dir = tempfile.mkdtemp() + def setUpClass(self): + self.test_dir = tempfile.mkdtemp() + self.tokenizer = AutoTokenizer.from_pretrained(self.compressed_model_stub) - # TODO: Give option on HFQuantizer to run run_compressed True/False - # currently hardcoded to True - cls.compressed_model = AutoModelForCausalLM.from_pretrained( - cls.model_stub, + # Decompress using HFQuantizer from AutoModelForCausalLM + self.decompressed_model_hf_quantizer = AutoModelForCausalLM.from_pretrained( + self.compressed_model_stub, torch_dtype="auto", device_map="auto", - # run_compressed=True, # TODO: Give option on HFQuantizer + quantization_config=CompressedTensorsConfig(run_compressed=False), ) - # TODO: Use ModelCompressor until decompression is supported through - # HFQuant/run_compressed can be turned off. 
- cls.uncompressed_model = AutoModelForCausalLM.from_pretrained( - cls.empty_model, - torch_dtype=cls.compressed_model.dtype, - device_map=cls.compressed_model.device, + + # Manually decompress this model + self.dense_model = AutoModelForCausalLM.from_pretrained( + self.skeleton_model_stub, + torch_dtype=self.decompressed_model_hf_quantizer.dtype, + device_map=self.decompressed_model_hf_quantizer.device, + ) + + # decompression from HFQuantizer should populate weight_scale + assert hasattr( + self.decompressed_model_hf_quantizer.model.layers[0].self_attn.q_proj, + "weight_scale", + ) + + # dense model should not have weight_scale populated + assert not hasattr( + self.dense_model.model.layers[0].self_attn.q_proj, "weight_scale" ) - config = AutoConfig.from_pretrained(cls.model_stub) + + config = AutoConfig.from_pretrained(self.compressed_model_stub) + compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) - cls.compressor = ModelCompressor.from_compression_config(compression_config) - cls.compressor.quantization_config.quantization_status = ( + self.compressor = ModelCompressor.from_compression_config(compression_config) + self.compressor.quantization_config.quantization_status = ( QuantizationStatus.FROZEN ) - cls.compressor.decompress( - model_path=cls.model_stub, model=cls.uncompressed_model + + # use the model_path to load the decompressed weights into dense_model + dense_model = copy.deepcopy(self.dense_model) + + # overwrite the weights of the dense model + self.compressor.decompress( + model_path=self.compressed_model_stub, + model=self.dense_model, ) - cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_stub) + # self.dense_model should be decompressed + assert dense_model is not self.dense_model - def test_compressed_matches_uncompressed(self): - SAMPLE_INPUT = [ - "I love 4-bit quantization because", - "What is the capital of France?", - "def fibonacci(n):", - ] + self.decompressed_model_manual = self.dense_model - inputs = self.tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to( - self.compressed_model.device + assert hasattr( + self.decompressed_model_manual.model.layers[0].self_attn.q_proj, + "weight_scale", ) - compressed_output = self.tokenizer.batch_decode( - self.compressed_model.generate(**inputs, max_length=50) + + def test_hf_quantizer_decompress_match_manual_decompress(self): + manual_device = self.decompressed_model_manual.device + decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.device + + self.decompressed_model_manual = self.decompressed_model_manual.to( + manual_device ) - uncompressed_output = self.tokenizer.batch_decode( - self.uncompressed_model.generate(**inputs, max_length=50) + self.decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.to( + decompressed_model_hf_quantizer ) - for idx in range(len(SAMPLE_INPUT)): - assert compressed_output[idx] == uncompressed_output[idx] + for input in self.SAMPLE_INPUTS: + inputs = self.tokenizer(input, return_tensors="pt", padding=True).to( + self.decompressed_model_manual.device + ) + inputs = inputs.to(self.decompressed_model_manual.device) + + decompressed_model_manual_output = self.tokenizer.batch_decode( + self.decompressed_model_manual.generate(**inputs, max_length=50) + ) + + decompressed_model_hf_quantizer_out = self.tokenizer.batch_decode( + self.decompressed_model_hf_quantizer.generate(**inputs, max_length=50) + ) + + assert ( + decompressed_model_hf_quantizer_out == decompressed_model_manual_output + ) @classmethod - def tearDownClass(cls): - 
shutil.rmtree(cls.test_dir) - del cls.compressed_model - del cls.uncompressed_model - torch.cuda.empty_cache() + def tearDownClass(self): + shutil.rmtree(self.test_dir) + del self.dense_model + del self.decompressed_model_hf_quantizer + del self.decompressed_model_manual diff --git a/tests/llmcompressor/transformers/finetune/data/conftest.py b/tests/llmcompressor/transformers/finetune/data/conftest.py index a7a347d99..aa2f056bc 100644 --- a/tests/llmcompressor/transformers/finetune/data/conftest.py +++ b/tests/llmcompressor/transformers/finetune/data/conftest.py @@ -1,7 +1,7 @@ import pytest from transformers import AutoTokenizer -from llmcompressor.transformers.finetune.model_args import ModelArguments +from llmcompressor.args import ModelArguments @pytest.fixture diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 812b26a56..7eb74f9f9 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,6 +1,6 @@ import pytest -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.args import DatasetArguments from llmcompressor.transformers.finetune.data.data_helpers import ( get_raw_dataset, make_dataset_splits, @@ -9,7 +9,7 @@ @pytest.mark.unit def test_combined_datasets(): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) raw_wikitext2 = get_raw_dataset(data_args) @@ -33,7 +33,7 @@ def test_combined_datasets(): @pytest.mark.unit def test_separate_datasets(): splits = {"train": "train[:10%]", "validation": "train[10%:20%]"} - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) datasets = {} diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index 64514b252..dcc602877 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -5,12 +5,13 @@ from datasets import IterableDataset, load_dataset from parameterized import parameterized -from llmcompressor.transformers import ( - DataTrainingArguments, +from llmcompressor.args import ( + DatasetArguments, ModelArguments, - TextGenerationDataset, + RecipeArguments, TrainingArguments, ) +from llmcompressor.transformers import TextGenerationDataset from llmcompressor.transformers.finetune.data.data_helpers import ( format_calibration_data, ) @@ -20,7 +21,7 @@ @pytest.mark.unit class TestConcentrationTokenization(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -53,7 +54,7 @@ def test_concatenation_tokenization(self): @pytest.mark.unit class TestNoPaddingTokenization(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="open_platypus", pad_to_max_length=False ) @@ -96,9 +97,7 @@ def test_no_padding_tokenization(self): @pytest.mark.unit class TestMaxSeqLenClipped(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( - dataset="open_platypus", max_seq_length=4096 - ) + self.data_args = 
DatasetArguments(dataset="open_platypus", max_seq_length=4096) @pytest.fixture(autouse=True) def prepare_fixture(self, tiny_llama_tokenizer): @@ -120,7 +119,7 @@ def test_max_seq_len_clipped(self): @pytest.mark.unit class TestDatasetKwargsAndPercent(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", raw_kwargs={ "data_files": { @@ -167,7 +166,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): ] ) def test_datasets(self, dataset_key, dataset_config, split, do_concat): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset=dataset_key, dataset_config_name=dataset_config, concatenate_data=do_concat, @@ -206,7 +205,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="evolcodealpaca", dataset_config_name=None, concatenate_data=False, @@ -235,7 +234,7 @@ def test_evol(self): @pytest.mark.unit class TestStreamLoading(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -276,15 +275,19 @@ def prepare_fixture(self, tiny_llama_tokenizer): [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] ) def test_split_loading(self, split_def): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="open_platypus", splits=split_def, trust_remote_code_data=True, ) training_args = TrainingArguments(do_train=True, output_dir="dummy") model_args = ModelArguments(model=None) + recipe_args = RecipeArguments() stage_runner = StageRunner( - model_args=model_args, data_args=data_args, training_args=training_args + model_args=model_args, + data_args=data_args, + training_args=training_args, + recipe_args=recipe_args, ) stage_runner.populate_datasets(processor=self.tiny_llama_tokenizer) @@ -318,10 +321,11 @@ def preprocess(sample): ) stage_runner = StageRunner( model_args=None, - data_args=DataTrainingArguments( + data_args=DatasetArguments( dataset=tokenized_dataset, shuffle_calibration_samples=False ), training_args=TrainingArguments(do_oneshot=True), + recipe_args=RecipeArguments(), ) stage_runner.populate_datasets(processor=None) calib_dataset = stage_runner.get_dataset_split("calibration") diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index 9aee4c20f..694a9b6d3 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -1,17 +1,17 @@ import pytest +from llmcompressor.args import DatasetArguments from llmcompressor.transformers.finetune.data import ( C4Dataset, OpenPlatypusDataset, TextGenerationDataset, WikiTextDataset, ) -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_c4_initializes(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="c4", concatenate_data=True) + data_args = DatasetArguments(dataset="c4", concatenate_data=True) c4_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, @@ -27,7 +27,7 @@ def test_c4_initializes(tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def 
test_wikitext_initializes(tiny_llama_tokenizer):
-    data_args = DataTrainingArguments(
+    data_args = DatasetArguments(
         dataset="wikitext", dataset_config_name="wikitext-2-raw-v1"
     )
     wiki_manager = TextGenerationDataset.load_from_registry(
@@ -45,7 +45,7 @@ def test_wikitext_initializes(tiny_llama_tokenizer):
 
 @pytest.mark.usefixtures("tiny_llama_tokenizer")
 def test_open_platypus_initializes(tiny_llama_tokenizer):
-    data_args = DataTrainingArguments(dataset="open_platypus", pad_to_max_length=False)
+    data_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False)
     op_manager = TextGenerationDataset.load_from_registry(
         data_args.dataset,
         data_args=data_args,
diff --git a/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml b/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml
index 4f9d4293d..c814a7178 100644
--- a/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml
+++ b/tests/llmcompressor/transformers/finetune/test_alternate_recipe.yaml
@@ -3,10 +3,10 @@ test_oneshot_stage:
     SparseGPTModifier:
       sparsity: 0.7
       block_size: 128
-      sequential_update: False
       percdamp: 0.01
       mask_structure: "0:0"
-      target_ids: ["attention_mask", "position_ids"]
+      targets: ["Linear"]
+      ignore: ["re:.*lm_head"]
   test_train_stage:
     pruning_modifiers:
       ConstantPruningModifier:
diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
index e9c3d7c5c..76ea21706 100644
--- a/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
+++ b/tests/llmcompressor/transformers/finetune/test_oneshot_then_finetune.py
@@ -1,28 +1,23 @@
-import os
 import shutil
 import unittest
 from pathlib import Path
 
 import pytest
+from transformers import AutoModelForCausalLM
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+from llmcompressor.core import create_session
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.transformers import oneshot, train
 
 
 @pytest.mark.unit
-@pytest.mark.skipif(
-    "CADENCE" in os.environ
-    and (os.environ["CADENCE"] == "weekly" or os.environ["CADENCE"] == "nightly"),
-    reason="Don't run for weekly and nightly tests as those use multi gpu "
-    "runners and this test fails when ngpu>1",
-)
 class TestOneshotThenFinetune(unittest.TestCase):
     def setUp(self):
         self.output = Path("./finetune_output")
+        self.quantization_config = CompressedTensorsConfig(run_compressed=False)
 
-    def test_oneshot_then_finetune(self):
-        from transformers import AutoModelForCausalLM
-
-        from llmcompressor.core import create_session
-        from llmcompressor.transformers import oneshot, train
-
+    def test_oneshot_sparsification_then_finetune(self):
         recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
         model = AutoModelForCausalLM.from_pretrained(
             "Xenova/llama2.c-stories15M", device_map="auto"
@@ -47,8 +42,12 @@ def test_oneshot_then_finetune(self):
         recipe_str = (
             "tests/llmcompressor/transformers/finetune/test_finetune_recipe.yaml"
         )
+
+        # Explicitly decompress the model for training using quantization_config
         model = AutoModelForCausalLM.from_pretrained(
-            self.output / "oneshot_out", device_map="auto"
+            self.output / "oneshot_out",
+            device_map="auto",
+            quantization_config=self.quantization_config,
         )
         distill_teacher = AutoModelForCausalLM.from_pretrained(
             "Xenova/llama2.c-stories15M", device_map="auto"
@@ -73,7 +72,12 @@ def test_oneshot_then_finetune(self):
         )
 
         # test reloading checkpoint and final model
-        model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
+        # verify that the checkpoint can be reloaded and finetuning can continue
+        # with the saved model
+        # Explicitly decompress the model for training using quantization_config
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto", quantization_config=self.quantization_config
+        )
         with create_session():
             train(
                 model=model,
@@ -88,5 +92,71 @@ def test_oneshot_then_finetune(self):
                 resume_from_checkpoint=True,  # use last checkpoint
             )
 
+    def test_oneshot_quantization_then_finetune(self):
+        recipe = QuantizationModifier(
+            targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            device_map="auto",
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        num_calibration_samples = 64
+        output_dir = self.output / "oneshot_out"
+        splits = {"calibration": "train[:10%]"}
+
+        with create_session():
+            oneshot(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir,
+            device_map="auto",
+            quantization_config=quantization_config,
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        output_dir = self.output / "finetune_out"
+        splits = {"calibration": "train[:10%]", "train": "train[:10%]"}
+
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        # test reloading checkpoint and final model
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto", quantization_config=quantization_config
+        )
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+                resume_from_checkpoint=True,  # use last checkpoint
+            )
+
     def tearDown(self):
         shutil.rmtree(self.output)
diff --git a/tests/llmcompressor/transformers/finetune/test_session_mixin.py b/tests/llmcompressor/transformers/finetune/test_session_mixin.py
index 69a9acd44..93bd74cd1 100644
--- a/tests/llmcompressor/transformers/finetune/test_session_mixin.py
+++ b/tests/llmcompressor/transformers/finetune/test_session_mixin.py
@@ -14,6 +14,8 @@ def __init__(
         model: Module,
         recipe: Optional[str],
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
+        model_args: Optional[Union[Dict[str, Any], str]] = None,
+        data_args: Optional[Union[Dict[str, Any], str]] = None,
         teacher: Optional[Union[Module, str]] = None,
         **kwargs,
     ):
@@ -21,6 +23,8 @@ def __init__(
             model=model,
             recipe=recipe,
             recipe_args=recipe_args,
+            model_args=model_args,
+            data_args=data_args,
             teacher=teacher,
             **kwargs,
         )
diff --git a/tests/llmcompressor/transformers/gptq/test_oneshot.py b/tests/llmcompressor/transformers/gptq/test_oneshot.py
index 7f1a1ec99..c391890b2 100644
--- a/tests/llmcompressor/transformers/gptq/test_oneshot.py
+++ b/tests/llmcompressor/transformers/gptq/test_oneshot.py
@@ -74,8 +74,8 @@ def test_oneshot_application(self):
         oneshot(
             model=self.model,
             dataset=self.dataset,
-            output_dir=self.output,
overwrite_output_dir=True, + output_dir=self.output, recipe=self.recipe, oneshot_device=self.device, num_calibration_samples=9, diff --git a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml index 64ce30250..d96e97f91 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity.yaml @@ -3,7 +3,6 @@ test_stage: SparseGPTModifier: sparsity: 0.7 block_size: 128 - sequential_update: True percdamp: 0.01 mask_structure: "0:0" targets: ["model.layers.0"] diff --git a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml index 027c56363..99d89925c 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/additional_sparsity_with_quant.yaml @@ -11,7 +11,6 @@ test_stage: SparseGPTModifier: sparsity: 0.7 block_size: 128 - sequential_update: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml b/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml index eb02ea81d..102ad768e 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/quant_and_sparse.yaml @@ -18,7 +18,6 @@ test_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False percdamp: 0.01 mask_structure: "0:0" targets: ["model.layers.0"] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml b/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml index d485064fa..d5cf54a6e 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/sparse.yaml @@ -3,7 +3,6 @@ test_stage: SparseGPTModifier: sparsity: 0.3 block_size: 128 - sequential_update: False percdamp: 0.01 targets: ["model.layers.0", "model.layers.1"] mask_structure: "0:0" \ No newline at end of file diff --git a/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml b/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml index 20c4c9397..13c41c33e 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/sparse_with_mask_structure.yaml @@ -3,7 +3,6 @@ test_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False percdamp: 0.01 mask_structure: "2:4" targets: [ diff --git a/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml b/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml index 8a97ff733..dee90a07d 100644 --- a/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml +++ b/tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml @@ -3,7 +3,6 @@ test_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False percdamp: 0.01 mask_structure: "0:0" targets: [ diff --git a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py index 2f6c51ebb..16c9003be 100644 --- a/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py +++ b/tests/llmcompressor/transformers/obcq/test_consecutive_runs.py @@ -5,7 +5,11 @@ import pytest 
import yaml from parameterized import parameterized_class +from transformers import AutoModelForCausalLM +from transformers.utils.quantization_config import CompressedTensorsConfig +from llmcompressor.transformers.utils import is_model_ct_quantized_from_path +from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs" @@ -15,13 +19,15 @@ class TestConsecutiveRuns(unittest.TestCase): + quantization_config = CompressedTensorsConfig(run_compressed=False) + def _test_consecutive_runs( self, tolerance: float, num_calibration_samples: int = 16 ): import math from llmcompressor.core import active_session - from llmcompressor.pytorch.model_load.helpers import get_session_model + from llmcompressor.pytorch.model_load.helpers import initialize_recipe from llmcompressor.pytorch.utils.helpers import tensor_sparsity from llmcompressor.transformers import oneshot from llmcompressor.utils.pytorch import qat_active @@ -36,12 +42,18 @@ def _test_consecutive_runs( oneshot_device=self.device, clear_sparse_session=False, ) - first_tiny_model = get_session_model() + + first_model = AutoModelForCausalLM.from_pretrained( + self.output_first, + device_map="auto", + quantization_config=self.quantization_config, + ) + layer_0_sparse = tensor_sparsity( - first_tiny_model.model.layers[0].self_attn.k_proj.weight + first_model.model.layers[0].self_attn.k_proj.weight ) assert math.isclose(layer_0_sparse.item(), 0.5, rel_tol=tolerance) - assert qat_active(first_tiny_model) + assert qat_active(first_model) session = active_session() session_recipe = session.lifecycle.recipe_container.compiled_recipe @@ -49,6 +61,10 @@ def _test_consecutive_runs( self.assertEqual(len(stages), 1) session.reset() + recipe = infer_recipe_from_model_path(model_path=self.output_first) + if recipe: + initialize_recipe(model=first_model, recipe_path=recipe) + # reload saved model and up sparsity to 0.7 oneshot( model=self.output_first, @@ -57,15 +73,19 @@ def _test_consecutive_runs( recipe=self.second_recipe, output_dir=self.output_second, oneshot_device=self.device, - clear_sparse_session=False, ) - second_tiny_model = get_session_model() + second_model = AutoModelForCausalLM.from_pretrained( + self.output_second, + device_map="auto", + quantization_config=self.quantization_config, + ) + layer_0_sparse = tensor_sparsity( - second_tiny_model.model.layers[0].self_attn.k_proj.weight + second_model.model.layers[0].self_attn.k_proj.weight ) assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=tolerance) - assert qat_active(second_tiny_model) + assert qat_active(second_model) session = active_session() session_recipe = session.lifecycle.recipe_container.compiled_recipe @@ -118,8 +138,14 @@ class TestConsecutiveRunsGPU(TestConsecutiveRuns): def setUp(self): from transformers import AutoModelForCausalLM + self.assertFalse( + is_model_ct_quantized_from_path(self.model), + "The provided model is quantized. 
Please use a dense model.", + ) + self.model = AutoModelForCausalLM.from_pretrained( - self.model, device_map=self.device + self.model, + device_map=self.device, ) self.output = "./oneshot_output" diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index fe699570a..e4974a956 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -22,13 +22,11 @@ def labeled_dataloader(self, dataset_name, model_name): from torch.utils.data import DataLoader from transformers import AutoTokenizer, DefaultDataCollator + from llmcompressor.args import DatasetArguments from llmcompressor.transformers.finetune.data import TextGenerationDataset - from llmcompressor.transformers.finetune.data.data_args import ( - DataTrainingArguments, - ) tokenizer = AutoTokenizer.from_pretrained(model_name) - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset=dataset_name, max_seq_length=512, pad_to_max_length=False, diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py index 483a65f2d..a53bec558 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py @@ -1,27 +1,15 @@ -import unittest - import pytest +from accelerate import init_empty_weights +from transformers import AutoModelForCausalLM -from llmcompressor.utils.pytorch.module import get_no_split_params +from llmcompressor.modifiers.pruning import SparseGPTModifier @pytest.mark.integration -class TestInferTargets(unittest.TestCase): - def setUp(self): - from transformers import AutoModelForCausalLM - +def test_infer_targets(): + modifier = SparseGPTModifier(sparsity=0.0) + with init_empty_weights(): model = AutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M") - self.modifiable_model = model - self.targets = get_no_split_params(self.modifiable_model) - - def test_infer_targets(self): - from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier - - self.assertEqual(len(self.targets), 1) - self.assertEqual(self.targets[0], "LlamaDecoderLayer") - modifier = SparseGPTModifier(sparsity=0.5) - modifier.targets = self.targets - modifier.model = self.modifiable_model - compressible_layers = modifier.compressible_layers() - self.assertEqual(len(compressible_layers), 6) + inferred = modifier._infer_sequential_targets(model) + assert inferred == ["LlamaDecoderLayer"] diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py index ddb6f41ff..7a847ba1e 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py @@ -1,7 +1,11 @@ import unittest +from unittest.mock import MagicMock import pytest +from llmcompressor.core.state import State +from llmcompressor.modifiers.pruning import SparseGPTModifier + @pytest.mark.integration class TestLMHead(unittest.TestCase): @@ -14,6 +18,7 @@ def setUp(self): self.model = AutoModelForCausalLM.from_pretrained( "Xenova/llama2.c-stories15M", device_map=self.device ) + self.kwargs = { "sparsity": 0.5, "block_size": 128, @@ -28,21 +33,31 @@ def setUp(self): ], } - def test_lm_head_target(self): - from llmcompressor.core.state import State - from llmcompressor.modifiers.pruning.sparsegpt import 
SparseGPTModifier + dataset = MagicMock() + dataset.column_names = [] + self.dataloader = MagicMock() + self.dataloader.dataset = dataset + self.dataloader.__iter__.return_value = iter([]) - sparsegpt_modifier_no_head = SparseGPTModifier(**self.kwargs) + def test_no_lm_head_target(self): + modifier = SparseGPTModifier(**self.kwargs) state = State() - state.update(model=self.model, device=self.device) - sparsegpt_modifier_no_head.initialize_compression(state.model) + state.update(model=self.model, device=self.device, calib_data=self.dataloader) + modifier.on_initialize(state) + + assert len(self.model.lm_head._forward_hooks) <= 0 + + modifier.finalize(state) + def test_lm_head_target(self): self.kwargs["targets"].append("lm_head") - sparsegpt_modifier_head = SparseGPTModifier(**self.kwargs) - sparsegpt_modifier_head.initialize_compression(state.model) + modifier = SparseGPTModifier(**self.kwargs) + + state = State() + state.update(model=self.model, device=self.device, calib_data=self.dataloader) + modifier.on_initialize(state) + + assert len(self.model.lm_head._forward_hooks) == 1 - # check we pick up the lm_head layer - layers_no_head = len(sparsegpt_modifier_no_head.compressible_layers_) - layers_head = len(sparsegpt_modifier_head.compressible_layers_) - self.assertEqual(layers_head, layers_no_head + 1) + modifier.finalize(state) diff --git a/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py b/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py deleted file mode 100644 index d4d0ba280..000000000 --- a/tests/llmcompressor/transformers/obcq/test_sgpt_defaults.py +++ /dev/null @@ -1,23 +0,0 @@ -import unittest - -import pytest - - -@pytest.mark.integration -class TestSGPTDefaults(unittest.TestCase): - def test_sgpt_defaults(self): - from llmcompressor.core.state import State - from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier - - kwargs = {"sparsity": 0.5} - sparsegpt_modifier_only_sparsity = SparseGPTModifier(**kwargs) - self.assertEqual(sparsegpt_modifier_only_sparsity.block_size, 128) - self.assertEqual(sparsegpt_modifier_only_sparsity.sparsity, 0.5) - - # fail if we don't pass a sparsity or enable quantization - kwargs = {} - sparsegpt_invalid = SparseGPTModifier(**kwargs) - state_test = State() - sparsegpt_invalid.initialized_structure_ = True - with self.assertRaises(ValueError): - sparsegpt_invalid.on_initialize(state=state_test) diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml index b9aa59e06..7704d5718 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml @@ -3,7 +3,6 @@ test_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml index b4f61ff9f..eee5711bc 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml @@ -9,7 +9,6 @@ recipe: | SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git 
a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml index 6443c09c7..7ccfd5d80 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf4.yaml @@ -10,7 +10,6 @@ recipe: | SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/oneshot/test_cli.py b/tests/llmcompressor/transformers/oneshot/test_cli.py index 5780ca46f..08273b367 100644 --- a/tests/llmcompressor/transformers/oneshot/test_cli.py +++ b/tests/llmcompressor/transformers/oneshot/test_cli.py @@ -49,6 +49,7 @@ def test_one_shot_cli(self): if len(self.additional_args) > 0: cmd.extend(self.additional_args) res = run_cli_command(cmd) + self.assertEqual(res.returncode, 0) print(res.stdout) diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index df9726647..eeb6e95ae 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -6,12 +6,18 @@ import torch from accelerate import cpu_offload from accelerate.accelerator import get_state_dict_offloaded_model -from compressed_tensors import QUANTIZATION_CONFIG_NAME +from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig -from compressed_tensors.quantization import QuantizationStatus +from compressed_tensors.quantization import ( + QuantizationConfig, + QuantizationStatus, + quantize, +) from compressed_tensors.utils import get_offloaded_device, update_prefix_dict +from torch import nn from transformers import AutoConfig, AutoModelForCausalLM +from transformers.utils.quantization_config import CompressedTensorsConfig from llmcompressor.core import reset_session from llmcompressor.pytorch.utils.helpers import tensor_sparsity @@ -20,6 +26,7 @@ SparsityConfigMetadata, ) from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + get_model_compressor, modify_save_pretrained, patch_tied_tensors_bug, ) @@ -171,9 +178,8 @@ def test_quant_model_reload(format, dtype, tmp_path): device = "cpu" dataset = "open_platypus" concatenate_data = False - num_calibration_samples = 64 + num_calibration_samples = 16 splits = {"calibration": "train[:10%]"} - empty_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype) # create a quantized model oneshot( @@ -191,7 +197,7 @@ def test_quant_model_reload(format, dtype, tmp_path): # Fetch the oneshot model model = get_session_model() og_state_dict = model.state_dict() - path = tmp_path / "compressed" + save_path_compressed = tmp_path / "compressed" for _, module in model.named_modules(): if hasattr(module, "quantization_scheme"): @@ -200,32 +206,24 @@ def test_quant_model_reload(format, dtype, tmp_path): # Save to disk model.save_pretrained( - path, + save_path_compressed, quantization_format=format, save_compressed=True, ) # Verify config on disk - config = AutoConfig.from_pretrained(path) + config = AutoConfig.from_pretrained(save_path_compressed) compression_config = 
getattr(config, QUANTIZATION_CONFIG_NAME, None) quant_config = ModelCompressor.parse_quantization_config(compression_config) assert quant_config["format"] == format - # As HFQuantizer doesn't decompress the model, use the compressor to decompress - # the model instead - compressor = ModelCompressor.from_compression_config(compression_config) - compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN - compressor.decompress(model_path=path, model=empty_model) - - # eventually use this pathway once HFQuant Decompression works - """ - dense_model = SparseAutoModelForCausalLM.from_pretrained( - "compress_out", torch_dtype="auto", device_map=device + decompressed_model = AutoModelForCausalLM.from_pretrained( + save_path_compressed, + torch_dtype=dtype, + quantization_config=CompressedTensorsConfig(run_compressed=False), ) - """ - # Verify the abs difference between the decompressed model - # and the original model - reconstructed_state_dict = empty_model.state_dict() + + reconstructed_state_dict = decompressed_model.state_dict() assert len(og_state_dict) == len(reconstructed_state_dict) for key in og_state_dict.keys(): dense_tensor = og_state_dict[key].to(device) @@ -364,3 +362,346 @@ def test_model_shared_tensors_gpu( test_model_shared_tensors( offload, torch_dtype, tie_word_embeddings, device_map, tmp_path ) + + +@pytest.mark.parametrize( + "model_stub, recipe, sparse_format, quant_format", + [ + ( + "Xenova/llama2.c-stories15M", + "tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml", + CompressionFormat.sparse_24_bitmask.value, + CompressionFormat.float_quantized.value, + ), + ], +) +def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path): + from llmcompressor.pytorch.model_load.helpers import get_session_model + + device = "cuda" + if not torch.cuda.is_available(): + device = "cpu" + dataset = "open_platypus" + concatenate_data = False + num_calibration_samples = 64 + splits = {"calibration": "train[:10%]"} + empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto") + + oneshot( + model=model_stub, + dataset=dataset, + num_calibration_samples=num_calibration_samples, + recipe=recipe, + concatenate_data=concatenate_data, + splits=splits, + oneshot_device=device, + clear_sparse_session=False, + ) + + # Fetch the oneshot model + model = get_session_model() + og_state_dict = model.state_dict() + path = tmp_path / "compressed" + + # Compress and save + model.save_pretrained( + path, + quantization_format=quant_format, + save_compressed=True, + ) + + # Verify config on disk + config = AutoConfig.from_pretrained(path) + compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) + quant_config = ModelCompressor.parse_quantization_config(compression_config) + + # As HFQuantizer doesn't decompress the model, use the compressor to decompress + # the model instead + compressor = ModelCompressor.from_compression_config(compression_config) + + assert ( + compressor.sparsity_compressor is not None + ), "Sparse compressor not initialized" + assert compressor.sparsity_config.format == sparse_format + + assert ( + compressor.quantization_compressor is not None + ), "Quantization compressor not initialized" + assert quant_config["format"] == quant_format + + compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN + compressor.decompress(model_path=path, model=empty_model) + + # Verify the abs difference between the decompressed model + # and the original model + 
reconstructed_state_dict = empty_model.state_dict() + assert len(og_state_dict) == len(reconstructed_state_dict) + for key in og_state_dict.keys(): + dense_tensor = og_state_dict[key].to(device) + reconstructed_tensor = reconstructed_state_dict[key].to(device) + assert dense_tensor.dtype == reconstructed_tensor.dtype + if key.endswith("weight") and quant_format != "dense": + # we don't expect an exact match for compressed + diff = torch.abs(dense_tensor - reconstructed_tensor) + # max diff value found empirically + assert not torch.any(diff > 0.022), f"Max diff: {torch.max(diff)}" + else: + assert torch.equal(dense_tensor, reconstructed_tensor) + shutil.rmtree(tmp_path) + + +@pytest.mark.parametrize( + "model_stub, recipe, sparse_format", + [ + ( + "Xenova/llama2.c-stories15M", + "tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml", + CompressionFormat.sparse_24_bitmask.value, + ), + ], +) +def test_sparse_24_compressor_is_lossless(model_stub, recipe, sparse_format, tmp_path): + from llmcompressor.pytorch.model_load.helpers import get_session_model + + device = "cuda" + if not torch.cuda.is_available(): + device = "cpu" + dataset = "open_platypus" + concatenate_data = False + num_calibration_samples = 64 + splits = {"calibration": "train[:10%]"} + empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto") + + oneshot( + model=model_stub, + dataset=dataset, + num_calibration_samples=num_calibration_samples, + recipe=recipe, + concatenate_data=concatenate_data, + splits=splits, + oneshot_device=device, + clear_sparse_session=False, + ) + + # Fetch the oneshot model + model = get_session_model() + og_state_dict = model.state_dict() + path = tmp_path / "compressed" + + # Compress and save + model.save_pretrained( + path, + save_compressed=True, + ) + + # Verify config on disk + config = AutoConfig.from_pretrained(path) + compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) + + # As HFQuantizer doesn't decompress the model, use the compressor to decompress + # the model instead + compressor = ModelCompressor.from_compression_config(compression_config) + + assert ( + compressor.sparsity_compressor is not None + ), "Sparse compressor not initialized" + assert compressor.sparsity_config.format == sparse_format + + compressor.decompress(model_path=path, model=empty_model) + + # Verify the abs difference between the decompressed model + # and the original model + reconstructed_state_dict = empty_model.state_dict() + assert len(og_state_dict) == len(reconstructed_state_dict) + for key in og_state_dict.keys(): + dense_tensor = og_state_dict[key].to(device) + reconstructed_tensor = reconstructed_state_dict[key].to(device) + assert dense_tensor.dtype == reconstructed_tensor.dtype + if key.endswith("weight"): + assert torch.equal(dense_tensor, reconstructed_tensor) + shutil.rmtree(tmp_path) + + +def test_disable_sparse_compression_flag(tmp_path): + two_four_sparse_model_id = "nm-testing/llama2.c-stories42M-pruned2.4" + two_four_sparse_model = AutoModelForCausalLM.from_pretrained( + two_four_sparse_model_id, torch_dtype="auto" + ) + modify_save_pretrained(two_four_sparse_model) + + save_path = tmp_path / "no_sparse_compression_model" + two_four_sparse_model.save_pretrained(save_path, disable_sparse_compression=True) + + config = AutoConfig.from_pretrained(save_path) + quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) + + assert quantization_config + sparsity_config = quantization_config.get("sparsity_config") + + assert 
sparsity_config + assert sparsity_config["format"] == "dense" + shutil.rmtree(tmp_path) + + +class DummyLinearModel(nn.Module): + """ + A dummy linear model for testing purposes, simulating a quantized linear layer. + """ + + def __init__(self, weights, weight_scale=None, weight_zero_point=None): + super().__init__() + out_features, in_features = weights.shape + + # Linear layer without bias + self.linear = nn.Linear(in_features, out_features, bias=False) + self.linear.weight = nn.Parameter(weights, requires_grad=False) + + # Attach scale and zero-point if provided + if weight_scale is not None: + self.linear.weight_scale = nn.Parameter( + torch.tensor(weight_scale), requires_grad=False + ) + if weight_zero_point is not None: + self.linear.weight_zero_point = nn.Parameter( + torch.tensor(weight_zero_point), requires_grad=False + ) + + def forward(self, x): + return self.linear(x) + + +def _create_quantization_config( + w_bits=8, + w_type="int", + w_strategy="tensor", + quantize_activations=False, + a_bits=8, + a_type="int", + a_strategy="tensor", +): + """ + Create a quantization configuration for testing. + """ + config_dict = { + "global_compression_ratio": 1.0, + "quant_method": "compressed-tensors", + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": w_bits, + "strategy": w_strategy, + "symmetric": True, + "type": w_type, + }, + } + }, + } + + if quantize_activations: + config_dict["config_groups"]["group_0"]["input_activations"] = { + "num_bits": a_bits, + "strategy": a_strategy, + "symmetric": True, + "type": a_type, + } + + return QuantizationConfig.model_validate(config_dict) + + +def _quantization_config_from_string(config_str, q_type): + """ + Parse quantization config from string and type. + """ + w_bits = int(config_str[1]) + a_bits = int(config_str[3:]) + quantize_activations = a_bits < 16 + + return _create_quantization_config( + w_bits=w_bits, + w_type=q_type, + w_strategy="channel", + quantize_activations=quantize_activations, + a_bits=a_bits, + a_type=q_type, + a_strategy="channel", + ) + + +def _make_24_sparse(tensor): + """ + Apply 2:4 sparsity pattern to the given tensor. + """ + reshaped_tensor = tensor.view(tensor.size(0), -1, 4) + mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool) + mask[..., :2] = True + sparsified_tensor = torch.where( + mask, reshaped_tensor, torch.tensor(0.0, dtype=tensor.dtype) + ) + return sparsified_tensor.view_as(tensor) + + +@pytest.mark.parametrize( + "quant_style, quant_type, is_24, expected_quant_compressor, " + "expected_sparsity_compressor", + [ + ("W8A8", "int", False, "int-quantized", "dense"), + ("W4A16", "int", False, "pack-quantized", "dense"), + ("W8A16", "int", False, "pack-quantized", "dense"), + ("W8A8", "int", True, "int-quantized", "sparse-24-bitmask"), + ("W4A16", "int", True, "marlin-24", "dense"), + ("W8A16", "int", True, "marlin-24", "dense"), + ("W8A8", "float", False, "float-quantized", "dense"), + ("W8A16", "float", False, "naive-quantized", "dense"), + ("W8A8", "float", True, "float-quantized", "sparse-24-bitmask"), + ("W8A16", "float", True, "naive-quantized", "dense"), + ], +) +def test_correct_compressor_inferred( + quant_style, + quant_type, + is_24, + expected_quant_compressor, + expected_sparsity_compressor, +): + """ + Test if the correct compressor is inferred based on + quantization and sparsity configurations. 
+ """ + weights = torch.rand(10, 4) + if is_24: + weights = _make_24_sparse(weights) + else: + weights[0, :] = torch.ones( + 4, + ) # guarantee not 24 sparse + + quantization_config = _quantization_config_from_string(quant_style, quant_type) + quantization_args = quantization_config.config_groups["group_0"].weights + + scale = ( + torch.ones((weights.shape[0], 1)) + if quantization_args.strategy == "channel" + else torch.tensor([1.0]) + ) + zero_point = torch.zeros_like(scale) + + quantized_weights = quantize( + weights, scale=scale, zero_point=zero_point, args=quantization_args + ) + + model = DummyLinearModel(quantized_weights, scale, zero_point) + model.linear.quantization_scheme = quantization_config.config_groups["group_0"] + model.linear.quantization_status = QuantizationStatus.FROZEN + + compressor = get_model_compressor(model) + + assert compressor.quantization_config.format == expected_quant_compressor + + if expected_sparsity_compressor == "dense": + assert ( + compressor.sparsity_config is None + or compressor.sparsity_config.format == expected_sparsity_compressor + ) + else: + assert compressor.sparsity_config.format == expected_sparsity_compressor diff --git a/tests/testing_utils.py b/tests/testing_utils.py index a6103a73c..257506784 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -135,7 +135,8 @@ def preprocess_tokenize_dataset( :param tokenizer: tokenizer to be used for tokenization :param max_seq_length: maximum sequence length of samples """ - if ds.info.dataset_name == "gsm8k": + ds_name = ds.info.dataset_name.lower() + if ds_name == "gsm8k": def preprocess(example): return example @@ -148,7 +149,8 @@ def tokenize(sample): truncation=True, add_special_tokens=False, ) - elif ds.info.dataset_name == "ultrachat_200k": + + elif ds_name == "ultrachat_200k": def preprocess(example): return { @@ -166,6 +168,69 @@ def tokenize(sample): truncation=True, add_special_tokens=False, ) + + elif ds_name == "llm_compression_calibration": + + def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["text"], + tokenize=False, + ) + } + + def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=max_seq_length, + truncation=True, + add_special_tokens=False, + ) + + elif ds_name == "open-platypus": + # use the output rather than the instruction + def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["output"], + tokenize=False, + ) + } + + def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=max_seq_length, + truncation=True, + add_special_tokens=False, + ) + + elif ds_name == "slimorca-deduped-cleaned-corrected": + # find the first element corresponding to a message from a human + def preprocess(example): + conversation_idx = 0 + for idx, conversation in enumerate(example["conversations"]): + if conversation["from"] == "human": + conversation_idx = idx + break + return { + "text": tokenizer.apply_chat_template( + example["conversations"][conversation_idx]["value"], + tokenize=False, + ) + } + + def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=max_seq_length, + truncation=True, + add_special_tokens=False, + ) + else: raise NotImplementedError(f"Cannot preprocess dataset {ds.info.dataset_name}")
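For reference, the tests in this patch repeatedly rely on one loading pattern: reopening a compressed-tensors checkpoint with decompression enabled so it can be trained further. The sketch below illustrates that pattern in isolation; the checkpoint path is a placeholder, and the snippet mirrors the `CompressedTensorsConfig(run_compressed=False)` usage seen in the tests rather than defining any new API.

```python
# Minimal sketch of the load-for-finetuning pattern used by the tests above.
# Passing CompressedTensorsConfig(run_compressed=False) decompresses a
# compressed-tensors checkpoint into dense weights at load time.
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

CHECKPOINT = "./finetune_output/oneshot_out"  # placeholder path for illustration

model = AutoModelForCausalLM.from_pretrained(
    CHECKPOINT,
    device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)
# `model` now holds dense weights and can be handed to a training entrypoint.
```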