[Test Fix] Fix Consecutive oneshot (#971)
~~Contingent on merge of
huggingface/transformers#34719~~
^ has been merged and released


Blocked on 
neuralmagic/compressed-tensors#237

SUMMARY:
* In multiple optimization tests, automatically decompress the model if it is
provided as an already-optimized (compressed) model
* Fix recipe stage length
* Revive old code
* When running multiple optimizations (e.g. oneshot then finetune, or oneshot
then oneshot), the recipes need to be added to the session using
`initialize_recipe`; see the sketch after this summary. Example here:
https://github.com/vllm-project/llm-compressor/pull/971/files#diff-c9ae8b3ad24d13abeea5b649a5fd6d0b0925f5c9cc40220cbfbe21ae81242f8dR63-R65
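
For readers following along, a minimal sketch of that consecutive-run flow. The
model stub, dataset name, and recipe paths are illustrative placeholders, not
values from this PR; the imports match the helpers added in this diff:

```python
# Minimal sketch of the consecutive oneshot pattern this PR fixes.
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.pytorch.model_load.helpers import initialize_recipe
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path

# First optimization: produces a compressed checkpoint plus its recipe.yaml.
oneshot(
    model="Xenova/llama2.c-stories15M",  # placeholder model stub
    dataset="open_platypus",             # placeholder calibration dataset
    recipe="first_recipe.yaml",          # placeholder recipe
    output_dir="./output_first",
    clear_sparse_session=False,
)

# Reload decompressed so a second optimization can act on real weights.
model = AutoModelForCausalLM.from_pretrained(
    "./output_first",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)

# Re-register the first run's recipe with the session before the second run.
recipe = infer_recipe_from_model_path(model_path="./output_first")
if recipe:
    initialize_recipe(model=model, recipe_path=recipe)

# Second optimization on the reloaded, decompressed model.
oneshot(
    model=model,
    dataset="open_platypus",
    recipe="second_recipe.yaml",  # placeholder recipe
    output_dir="./output_second",
)
```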


TEST PLAN:
Ran the test using transformers `main`; a minimal local invocation is sketched
below.
Must pass tests/llmcompressor/transformers/obcq/test_consecutive_runs.py
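
One way to invoke it locally (assumes pytest is installed in the environment):

```python
# Run the required test module directly; pytest assumed installed.
import pytest

pytest.main(["-v", "tests/llmcompressor/transformers/obcq/test_consecutive_runs.py"])
```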

---------

Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
3 people authored Jan 23, 2025
1 parent 7610854 commit 0b6782a
Showing 3 changed files with 146 additions and 11 deletions.
13 changes: 12 additions & 1 deletion src/llmcompressor/transformers/finetune/text_generation.py
@@ -30,6 +30,7 @@
    PreTrainedModel,
    set_seed,
)
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.core import pre_initialize_structure, reset_session
from llmcompressor.pytorch.model_load.helpers import (
@@ -52,7 +53,10 @@
from llmcompressor.transformers.sparsification.sparse_model import (
    get_shared_processor_src,
)
-from llmcompressor.transformers.utils.helpers import detect_last_checkpoint
+from llmcompressor.transformers.utils.helpers import (
+    detect_last_checkpoint,
+    is_model_ct_quantized_from_path,
+)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model

@@ -224,6 +228,13 @@ def initialize_model_from_path(
"trust_remote_code": model_args.trust_remote_code_model,
}
# this calls from_pretrained under the hood so should be FSDP safe

# optimized models must be decompressed to carry out oneshot/train/etc
if is_model_ct_quantized_from_path(model_path):
model_kwargs["quantization_config"] = CompressedTensorsConfig(
run_compressed=False
)

model = AutoModelForCausalLM.from_pretrained(
model_path,
**model_kwargs,
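For illustration, a hedged sketch of what this branch enables outside the
trainer path; the model stub below is a hypothetical placeholder:

```python
# Sketch only: load a compressed-tensors checkpoint in decompressed form so a
# subsequent oneshot/train pass can modify the weights.
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.transformers.utils.helpers import is_model_ct_quantized_from_path

model_path = "some-org/model-quantized-w4a16"  # hypothetical stub
model_kwargs = {}
if is_model_ct_quantized_from_path(model_path):
    model_kwargs["quantization_config"] = CompressedTensorsConfig(run_compressed=False)

model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
```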
105 changes: 104 additions & 1 deletion src/llmcompressor/transformers/utils/helpers.py
@@ -4,9 +4,13 @@
"""

import os
-from typing import TYPE_CHECKING, Optional
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional, Union

import requests
from huggingface_hub import HUGGINGFACE_CO_URL_HOME, hf_hub_download
from loguru import logger
from transformers import AutoConfig
from transformers.trainer_utils import get_last_checkpoint

if TYPE_CHECKING:
@@ -15,6 +19,7 @@
__all__ = [
"RECIPE_FILE_NAME",
"detect_last_checkpoint",
"is_model_ct_quantized_from_path",
]

RECIPE_FILE_NAME = "recipe.yaml"
@@ -54,3 +59,101 @@ def detect_last_checkpoint(
            )

    return last_checkpoint


def is_model_ct_quantized_from_path(path: str) -> bool:
    """
    Determine whether the model at the given path is quantized with
    compressed-tensors, based on its config.

    :param path: path to the model directory or HF stub
    :return: True if the config found at the given path contains a
        compressed-tensors quantization_config
    """
    config = AutoConfig.from_pretrained(path)
    if config is not None:
        quantization_config = getattr(config, "quantization_config", None)
        if (
            quantization_config is not None
            and quantization_config.get("quant_method") == "compressed-tensors"
        ):
            return True
    return False


def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]:
    """
    Infer the recipe from the model_path.

    :param model_path: The path to the model to load. It can be one of the
        following:
        - a path to the model directory
        - a path to the model file
        - a Hugging Face model ID
    :return: The path to the recipe file if found, None otherwise.
    """
    model_path = model_path.as_posix() if isinstance(model_path, Path) else model_path

    if os.path.isdir(model_path) or os.path.isfile(model_path):
        # Model path is a local path to the model directory or file
        model_path = (
            os.path.dirname(model_path) if os.path.isfile(model_path) else model_path
        )
        recipe = os.path.join(model_path, RECIPE_FILE_NAME)

        if os.path.isfile(recipe):
            logger.info(f"Found recipe in the model_path: {recipe}")
            return recipe

        logger.debug(f"No recipe found in the model_path: {model_path}")
        return None

    # Otherwise, treat the model path as a Hugging Face model ID
    recipe = recipe_from_huggingface_model_id(hf_stub=model_path)

    if recipe is None:
        logger.info("Failed to infer the recipe from the model_path")

    return recipe


def recipe_from_huggingface_model_id(
    hf_stub: str, recipe_file_name: str = RECIPE_FILE_NAME
) -> Optional[str]:
    """
    Attempts to download the recipe from the Hugging Face model ID.

    :param hf_stub: Assumed to be the Hugging Face model ID.
    :param recipe_file_name: The name of the recipe file to download.
        Defaults to RECIPE_FILE_NAME.
    :return: The path to the recipe file if found, None otherwise.
    """
    model_id_url = os.path.join(HUGGINGFACE_CO_URL_HOME, hf_stub)
    request = requests.head(model_id_url)

    if request.status_code != 200:
        logger.debug(
            "hf_stub is not a valid Hugging Face model ID. "
            "Skipping recipe resolution."
        )
        return None

    try:
        logger.info(
            f"Attempting to download recipe {recipe_file_name} "
            f"for {hf_stub} from {HUGGINGFACE_CO_URL_HOME}"
        )
        recipe = hf_hub_download(repo_id=hf_stub, filename=recipe_file_name)
        logger.info(f"Found recipe: {recipe_file_name} for model ID: {hf_stub}.")
    except Exception as e:
        logger.error(
            f"Unable to find recipe {recipe_file_name} "
            f"for model ID: {hf_stub}: {e}. "
            "Skipping recipe resolution."
        )
        recipe = None

    return recipe
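
For illustration, a short usage sketch of the two helpers added above; the
local path is a placeholder:

```python
# Usage sketch for the new helpers; "./output_first" is a placeholder.
from llmcompressor.transformers.utils.helpers import (
    infer_recipe_from_model_path,
    is_model_ct_quantized_from_path,
)

path = "./output_first"  # local directory, file, or Hugging Face model ID
if is_model_ct_quantized_from_path(path):
    print("compressed-tensors checkpoint: load with run_compressed=False")

recipe = infer_recipe_from_model_path(model_path=path)  # None if no recipe.yaml
print(recipe)
```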
39 changes: 30 additions & 9 deletions tests/llmcompressor/transformers/obcq/test_consecutive_runs.py
@@ -5,7 +5,10 @@
import pytest
import yaml
from parameterized import parameterized_class
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path
from tests.testing_utils import parse_params, requires_gpu

CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/obcq/obcq_configs/consec_runs"
@@ -15,13 +18,15 @@


class TestConsecutiveRuns(unittest.TestCase):
    quantization_config = CompressedTensorsConfig(run_compressed=False)

    def _test_consecutive_runs(
        self, tolerance: float, num_calibration_samples: int = 16
    ):
        import math

        from llmcompressor.core import active_session
-        from llmcompressor.pytorch.model_load.helpers import get_session_model
+        from llmcompressor.pytorch.model_load.helpers import initialize_recipe
        from llmcompressor.pytorch.utils.helpers import tensor_sparsity
        from llmcompressor.transformers import oneshot
        from llmcompressor.utils.pytorch import qat_active
@@ -36,19 +41,29 @@ def _test_consecutive_runs(
            oneshot_device=self.device,
            clear_sparse_session=False,
        )
-        first_tiny_model = get_session_model()

+        first_model = AutoModelForCausalLM.from_pretrained(
+            self.output_first,
+            device_map="auto",
+            quantization_config=self.quantization_config,
+        )

        layer_0_sparse = tensor_sparsity(
-            first_tiny_model.model.layers[0].self_attn.k_proj.weight
+            first_model.model.layers[0].self_attn.k_proj.weight
        )
        assert math.isclose(layer_0_sparse.item(), 0.5, rel_tol=tolerance)
-        assert qat_active(first_tiny_model)
+        assert qat_active(first_model)

        session = active_session()
        session_recipe = session.lifecycle.recipe_container.compiled_recipe
        stages = [stage.group for stage in session_recipe.stages]
        self.assertEqual(len(stages), 1)
        session.reset()

        recipe = infer_recipe_from_model_path(model_path=self.output_first)
        if recipe:
            initialize_recipe(model=first_model, recipe_path=recipe)

        # reload saved model and up sparsity to 0.7
        oneshot(
            model=self.output_first,
@@ -57,15 +72,19 @@
            recipe=self.second_recipe,
            output_dir=self.output_second,
            oneshot_device=self.device,
            clear_sparse_session=False,
        )

-        second_tiny_model = get_session_model()
+        second_model = AutoModelForCausalLM.from_pretrained(
+            self.output_second,
+            device_map="auto",
+            quantization_config=self.quantization_config,
+        )

        layer_0_sparse = tensor_sparsity(
-            second_tiny_model.model.layers[0].self_attn.k_proj.weight
+            second_model.model.layers[0].self_attn.k_proj.weight
        )
        assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=tolerance)
-        assert qat_active(second_tiny_model)
+        assert qat_active(second_model)

        session = active_session()
        session_recipe = session.lifecycle.recipe_container.compiled_recipe
@@ -119,7 +138,9 @@ def setUp(self):
        from transformers import AutoModelForCausalLM

        self.model = AutoModelForCausalLM.from_pretrained(
-            self.model, device_map=self.device
+            self.model,
+            device_map=self.device,
+            quantization_config=self.quantization_config,
        )

        self.output = "./oneshot_output"
