Skip tests on fbcode (#1532)
Summary:
Pull Request resolved: #1532

Skip tests in fbcode: the required model checkpoint and tokenizer files are not available there.

Differential Revision: D67982501
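
For context, a minimal sketch of the module-level guard pattern this change relies on. The helper name check_internal_build and the torch.version.git_version check are illustrative assumptions, not torchao's actual is_fbcode implementation:

import pytest
import torch

def check_internal_build() -> bool:
    # Assumption for illustration: internal (fbcode/Buck) builds of PyTorch
    # commonly lack the OSS-only torch.version.git_version attribute.
    return not hasattr(torch.version, "git_version")

if check_internal_build():
    # allow_module_level=True lets pytest.skip run at import time and skip
    # every test in the module before any checkpoint is touched.
    pytest.skip("model and tokenizer checkpoints unavailable", allow_module_level=True)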
jainapurva authored and facebook-github-bot committed Jan 9, 2025
1 parent 8259a38 commit e8acf22
Showing 1 changed file with 91 additions and 81 deletions.
172 changes: 91 additions & 81 deletions test/quantization/test_gptq_mt.py
@@ -8,6 +8,13 @@
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor
from torchao.quantization.utils import _lm_eval_available
from torchao.utils import is_fbcode
from torch.testing._internal.common_utils import run_tests

if is_fbcode():
    pytest.skip(
        "Skipping the test in fbcode due to missing model and tokenizer files",
        allow_module_level=True,
    )

if _lm_eval_available:
    hqq_core = pytest.importorskip("hqq.core", reason="requires hqq")
@@ -246,89 +253,92 @@ def run_eval(self, tasks, limit):

        return result

def test_gptq_mt():
    precision = torch.bfloat16
    device = "cuda"
    print("Loading model")
    checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
    model = Transformer.from_name(checkpoint_path.parent.name)
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(dtype=precision, device="cpu")
    model.eval()
    print("Model loaded")
    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), tokenizer_path
    tokenizer = get_tokenizer(  # pyre-ignore[28]
        tokenizer_path,
        "Llama-2-7b-chat-hf",
    )
    print("Tokenizer loaded")

    blocksize = 128
    percdamp = 0.01
    groupsize = 64
    calibration_tasks = ["wikitext"]
    calibration_limit = None
    calibration_seq_length = 100
    input_prep_func = prepare_inputs_for_model
    pad_calibration_inputs = False
    print("Recording inputs")
    inputs = (
        InputRecorder(
            tokenizer,
            calibration_seq_length,
            input_prep_func,
            pad_calibration_inputs,
            model.config.vocab_size,
            device="cpu",
        )
        .record_inputs(
            calibration_tasks,
            calibration_limit,
        )
        .get_inputs()
    )
    print("Inputs recorded")
    quantizer = Int4WeightOnlyGPTQQuantizer(
        blocksize,
        percdamp,
        groupsize,
    )

    model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
    multi = [
        MultiTensor([inp for inp, _ in inputs]),
        MultiTensor([inds for _, inds in inputs]),
    ]
    print("Quantizing model")
    model = quantizer.quantize(model, multi).cuda()
    print("Model quantized")
    print("Saving model and fixing state dict")
    regular_state_dict = model.state_dict()  # defaultdict(torch.tensor)
    for key, value in model.state_dict().items():
        if isinstance(value, MultiTensor):
            regular_state_dict[key] = value.values[0]
        else:
            regular_state_dict[key] = value

    model = Transformer.from_name(checkpoint_path.parent.name)
    remove = [k for k in regular_state_dict if "kv_cache" in k]
    for k in remove:
        del regular_state_dict[k]

    model.load_state_dict(regular_state_dict, assign=True)
    torch.save(model.state_dict(), "model.pth")
    print("Running evaluation")
    result = TransformerEvalWrapper(
        model.to(device),  # quantized model needs to run on cuda
        tokenizer,
        model.config.block_size,
        prepare_inputs_for_model,
    ).run_eval(
        ["wikitext"],
        None,
    )


if __name__ == "__main__":
    run_tests()

# wikitext: {'word_perplexity,none': 12.523175352665858, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6042723245990418, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.681919059499152, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
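
As a quick consistency check on the numbers above (an editorial note, not part of the commit): lm-eval reports bits_per_byte as the base-2 log of the byte perplexity, and the reported values agree:

import math

byte_ppl = 1.6042723245990418
bits_per_byte = math.log2(byte_ppl)  # ≈ 0.6819, matching the reported 0.681919...

# The word/byte perplexity pair also implies roughly
# math.log(12.523175352665858) / math.log(byte_ppl) ≈ 5.3 bytes per word
# on the wikitext evaluation set.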
