[Test Fix] Add Quantization then finetune tests #964

Merged: 27 commits (merged Jan 23, 2025)
Commits
92fbddc
add quantization then finetune -- run_compressed=False
horheynm Dec 9, 2024
299eed3
add test
horheynm Dec 9, 2024
aebae9a
Merge branch 'main' into quant-then-finetune
horheynm Dec 9, 2024
9ea94ed
clean up
horheynm Dec 9, 2024
a264fc0
Merge branch 'main' into quant-then-finetune
horheynm Dec 11, 2024
ee4c70d
comments
horheynm Dec 23, 2024
0d32d23
Merge branch 'main' into quant-then-finetune
horheynm Dec 23, 2024
72b5431
Merge branch 'main' into quant-then-finetune
horheynm Jan 9, 2025
8696be2
add clarity on loading ckpt and carrying out finetune on saved model
horheynm Jan 10, 2025
cd22e88
Merge branch 'quant-then-finetune' of github.com:vllm-project/llm-com…
horheynm Jan 10, 2025
fa5f3ff
Merge branch 'main' into quant-then-finetune
horheynm Jan 10, 2025
ff318f9
Merge branch 'main' into quant-then-finetune
horheynm Jan 10, 2025
8ebf898
update calculations
horheynm Jan 10, 2025
5e952a0
Merge branch 'quant-then-finetune' of github.com:vllm-project/llm-com…
horheynm Jan 10, 2025
c8f56e6
Merge branch 'main' into quant-then-finetune
horheynm Jan 13, 2025
9270c09
Merge branch 'main' into quant-then-finetune
dsikka Jan 14, 2025
336c867
decompress model explicitly
horheynm Jan 14, 2025
d9c806e
Merge branch 'quant-then-finetune' of github.com:vllm-project/llm-com…
horheynm Jan 14, 2025
ab35528
remove skipif
horheynm Jan 14, 2025
00e4d9f
Merge branch 'main' into quant-then-finetune
horheynm Jan 15, 2025
a219e95
lint
horheynm Jan 15, 2025
ca743a5
Merge branch 'main' into quant-then-finetune
horheynm Jan 20, 2025
837430b
Merge branch 'main' into quant-then-finetune
dsikka Jan 20, 2025
305a93f
Merge branch 'main' into quant-then-finetune
horheynm Jan 20, 2025
15884dc
comment
horheynm Jan 22, 2025
6e7bfa6
unindent
horheynm Jan 22, 2025
6884b78
Merge branch 'main' into quant-then-finetune
dsikka Jan 22, 2025
17 changes: 8 additions & 9 deletions src/llmcompressor/pytorch/utils/sparsification.py
@@ -105,15 +105,14 @@ def params_quantized(self) -> int:
         """
         :return: number of parameters across quantized layers
         """
-        return sum(
-            torch.numel(self.trainable_params[f"{name}.weight"])
-            + (
-                torch.numel(self.trainable_params[f"{name}.bias"])
-                if hasattr(layer, "bias") and layer.bias is not None
-                else 0
-            )
-            for (name, layer) in get_quantized_layers(self.module)
-        )
+        num_params = 0
+        for name, layer in get_quantized_layers(self.module):
+            if getattr(layer, "weight", None) is not None:
+                num_params += torch.numel(layer.weight)
+            if getattr(layer, "bias", None) is not None:
+                num_params += torch.numel(layer.bias)
+
+        return num_params
 
     @property
     def params_quantized_percent(self) -> float:
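The rewritten params_quantized no longer indexes self.trainable_params with f"{name}.weight" / f"{name}.bias" keys, which could fail when a quantized layer's parameters are not registered under those exact names, and it now also guards against layers that have no weight attribute. Below is the same counting pattern in isolation as a toy sketch; the hand-built (name, layer) list merely stands in for whatever get_quantized_layers(self.module) yields.

import torch
from torch import nn

# Stand-in for get_quantized_layers(module): (name, layer) pairs to be counted
quantized_layers = [
    ("fc1", nn.Linear(16, 32)),             # has weight and bias
    ("fc2", nn.Linear(32, 8, bias=False)),  # weight only, bias is None
]

num_params = 0
for name, layer in quantized_layers:
    # count whichever parameter tensors actually exist on the layer
    if getattr(layer, "weight", None) is not None:
        num_params += torch.numel(layer.weight)
    if getattr(layer, "bias", None) is not None:
        num_params += torch.numel(layer.bias)

print(num_params)  # (16*32 + 32) + 32*8 = 800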
@@ -1,28 +1,23 @@
-import os
 import shutil
 import unittest
 from pathlib import Path
 
 import pytest
+from transformers import AutoModelForCausalLM
+from transformers.utils.quantization_config import CompressedTensorsConfig
+
+from llmcompressor.core import create_session
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.transformers import oneshot, train
 
 
 @pytest.mark.unit
-@pytest.mark.skipif(
-    "CADENCE" in os.environ
-    and (os.environ["CADENCE"] == "weekly" or os.environ["CADENCE"] == "nightly"),
-    reason="Don't run for weekly and nightly tests as those use multi gpu "
-    "runners and this test fails when ngpu>1",
-)
 class TestOneshotThenFinetune(unittest.TestCase):
     def setUp(self):
         self.output = Path("./finetune_output")
+        self.quantization_config = CompressedTensorsConfig(run_compressed=False)
 
-    def test_oneshot_then_finetune(self):
-        from transformers import AutoModelForCausalLM
-
-        from llmcompressor.core import create_session
-        from llmcompressor.transformers import oneshot, train
-
+    def test_oneshot_sparsification_then_finetune(self):
         recipe_str = "tests/llmcompressor/transformers/obcq/recipes/test_tiny2.yaml"
         model = AutoModelForCausalLM.from_pretrained(
             "Xenova/llama2.c-stories15M", device_map="auto"
@@ -47,8 +42,12 @@ def test_oneshot_then_finetune(self):
         recipe_str = (
             "tests/llmcompressor/transformers/finetune/test_finetune_recipe.yaml"
         )
+
+        # Explicitly decompress the model for training using quantization_config
         model = AutoModelForCausalLM.from_pretrained(
-            self.output / "oneshot_out", device_map="auto"
+            self.output / "oneshot_out",
+            device_map="auto",
+            quantization_config=self.quantization_config,
         )
         distill_teacher = AutoModelForCausalLM.from_pretrained(
             "Xenova/llama2.c-stories15M", device_map="auto"
@@ -73,7 +72,12 @@
             )
 
         # test reloading checkpoint and final model
-        model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
+        # verify checkpoint reloading and can carry out finetune
+        # with the saved model
+        # Explicitly decompress the model for training using quantization_config
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto", quantization_config=self.quantization_config
+        )
         with create_session():
             train(
                 model=model,
@@ -88,5 +92,71 @@
                 resume_from_checkpoint=True,  # use last checkpoint
             )
 
+    def test_oneshot_quantization_then_finetune(self):
+        recipe = QuantizationModifier(
+            targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+            device_map="auto",
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        num_calibration_samples = 64
+        output_dir = self.output / "oneshot_out"
+        splits = {"calibration": "train[:10%]"}
+
+        with create_session():
+            oneshot(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir,
+            device_map="auto",
+            quantization_config=quantization_config,
+        )
+        dataset = "open_platypus"
+        concatenate_data = False
+        output_dir = self.output / "finetune_out"
+        splits = {"calibration": "train[:10%]", "train": "train[:10%]"}
+
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+            )
+
+        # test reloading checkpoint and final model
+        model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map="auto", quantization_config=quantization_config
+        )
+        with create_session():
+            train(
+                model=model,
+                dataset=dataset,
+                output_dir=output_dir,
+                num_calibration_samples=num_calibration_samples,
+                recipe=recipe,
+                concatenate_data=concatenate_data,
+                splits=splits,
+                resume_from_checkpoint=True,  # use last checkpoint
+            )
+
     def tearDown(self):
         shutil.rmtree(self.output)
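Both tests finish the same way: the finetuned output directory is reloaded, again decompressed via run_compressed=False, and training resumes from the last saved checkpoint. Condensed from the block above into a standalone sketch (string paths again stand in for the test's Path objects, and the recipe is the same FP8_DYNAMIC modifier used earlier):

from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.core import create_session
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import train

recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# Reload the finetuned output in decompressed form and pick up from the last checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "finetune_output/finetune_out",
    device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)
with create_session():
    train(
        model=model,
        dataset="open_platypus",
        output_dir="finetune_output/finetune_out",
        recipe=recipe,
        splits={"calibration": "train[:10%]", "train": "train[:10%]"},
        resume_from_checkpoint=True,  # continue from the last saved checkpoint
    )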