Support LoRAs on quantized models #2828

Closed
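
This PR removes the check that blocked LoRA adapters on quantized base models. A minimal usage sketch of what it enables, assuming the vllm.LLM and LoRARequest interfaces exercised by the tests below (the adapter path is a placeholder, not a file from this PR):

import vllm
from vllm.lora.request import LoRARequest

# Quantized base model plus a LoRA adapter; the AWQ model name is taken from
# the test matrix added below.
llm = vllm.LLM(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
               enable_lora=True,
               max_num_seqs=16,
               max_loras=4,
               quantization="AWQ")

sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)

# LoRARequest(name, id, local_path) -- point local_path at a downloaded adapter
# such as jashing/tinyllama-colorist-lora; "/path/to/adapter" is a placeholder.
outputs = llm.generate(
    ["<|im_start|>user\nGive me a neon pink color<|im_end|>\n<|im_start|>assistant\n"],
    sampling_params,
    lora_request=LoRARequest("colorist", 1, "/path/to/adapter"))
print(outputs[0].outputs[0].text)
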
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -51,4 +51,4 @@ exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"

[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts"
skip = "./tests/prompts,./tests/lora/test_llama.py"
9 changes: 8 additions & 1 deletion tests/lora/conftest.py
@@ -2,6 +2,7 @@
import gc
import tempfile
from collections import OrderedDict
from time import sleep
from unittest.mock import patch, MagicMock

import pytest
@@ -26,9 +27,10 @@ def cleanup():
destroy_model_parallel()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
ray.shutdown()
sleep(1)
gc.collect()
torch.cuda.empty_cache()
ray.shutdown()


@pytest.fixture(autouse=True)
@@ -121,6 +123,11 @@ def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def mixtral_lora_files():
return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
179 changes: 131 additions & 48 deletions tests/lora/test_llama.py
@@ -1,25 +1,43 @@
import pytest
import ray
from dataclasses import dataclass
from typing import List, Optional

import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup

MODEL_PATH = "meta-llama/Llama-2-7b-hf"

@dataclass
class ModelWithQuantization:
model_path: str
quantization: Optional[str]

def do_sample(llm, lora_path: str, lora_id: int):
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501

MODELS: List[ModelWithQuantization] = [
ModelWithQuantization(model_path="TinyLlama/TinyLlama-1.1B-Chat-v0.3",
quantization=None),
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization="AWQ"),
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
quantization="GPTQ"),
]


def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
raw_prompts = [
"Give me an orange-ish brown color",
"Give me a neon pink color",
]

def format_prompt_tuples(prompt):
return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

prompts = [format_prompt_tuples(p) for p in raw_prompts]

sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
max_tokens=max_tokens,
stop=["<|im_end|>"])
outputs = llm.generate(
prompts,
sampling_params,
@@ -35,106 +53,171 @@ def do_sample(llm, lora_path: str, lora_id: int):
return generated_texts


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_llama_lora(sql_lora_files, tp_size):
def test_llama_lora(tinyllama_lora_files, model, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(MODEL_PATH,
llm = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size)

expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
]
expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
" SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501
" SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501
" SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501
" SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501
]
max_model_len=400,
tensor_parallel_size=tp_size,
quantization=model.quantization)

if model.quantization is None:
expected_no_lora_output = [
"Here are some examples of orange-brown colors",
"I'm sorry, I don't have"
]
expected_lora_output = [
"#ff8050",
"#ff8080",
]
elif model.quantization == "AWQ":
expected_no_lora_output = [
"I'm sorry, I don't understand",
"I'm sorry, I don't understand",
]
expected_lora_output = [
"#f07700: A v",
"#f00000: A v",
]
elif model.quantization == "GPTQ":
expected_no_lora_output = [
"I'm sorry, I don't have",
"I'm sorry, I don't have",
]
expected_lora_output = [
"#f08800: This is",
"#f07788 \n#",
]

def expect_match(output, expected_output):
# HACK: GPTQ lora outputs are just incredibly unstable.
# Assert that the outputs changed.
if (model.quantization == "GPTQ"
and expected_output is expected_lora_output):
assert output != expected_no_lora_output
for i, o in enumerate(output):
assert o.startswith(
'#'), f"Expected example {i} to start with # but got {o}"
return
assert output == expected_output

max_tokens = 10

print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)

print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output
output = do_sample(llm,
tinyllama_lora_files,
lora_id=1,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)

print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)

print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output
output = do_sample(llm,
tinyllama_lora_files,
lora_id=2,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)

print("removing lora")

del llm
cleanup()


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(sql_lora_files):
def test_llama_tensor_parallel_equality(tinyllama_lora_files, model):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

llm_tp1 = vllm.LLM(MODEL_PATH,
llm_tp1 = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
tensor_parallel_size=1,
quantization=model.quantization)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

del llm_tp1
cleanup()

llm_tp2 = vllm.LLM(MODEL_PATH,
llm_tp2 = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2)
output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
tensor_parallel_size=2,
quantization=model.quantization)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

del llm_tp2
cleanup()

assert output_tp1 == output_tp2

llm_tp4 = vllm.LLM(MODEL_PATH,
llm_tp4 = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4)
output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
tensor_parallel_size=4,
quantization=model.quantization)
output_tp4 = do_sample(llm_tp4, tinyllama_lora_files, lora_id=1)

del llm_tp4
cleanup()

assert output_tp1 == output_tp4


def test_llama_lora_warmup(sql_lora_files):
@pytest.mark.parametrize("model", MODELS)
def test_llama_lora_warmup(model):
"""Test that the LLM initialization works with a warmup LORA path and
is more conservative"""

@ray.remote(num_gpus=1)
def get_num_gpu_blocks_lora():
llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
llm = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
quantization=model.quantization)
num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks

del llm
cleanup()

return num_gpu_blocks_lora_warmup

@ray.remote(num_gpus=1)
def get_num_gpu_blocks_no_lora():
llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
llm = vllm.LLM(model=model.model_path,
max_num_seqs=16,
quantization=model.quantization)
num_gpu_blocks_no_lora_warmup = (
llm.llm_engine.cache_config.num_gpu_blocks)

del llm
cleanup()

return num_gpu_blocks_no_lora_warmup

num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
3 changes: 0 additions & 3 deletions vllm/config.py
@@ -568,9 +568,6 @@ def verify_with_model_config(self, model_config: ModelConfig):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization is not None:
raise ValueError(
"LoRA is not supported with quantized models yet.")

Review comment from the PR author:
This probably should throw for non-tested quant types. Should have an allowlist.
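
A rough sketch of the suggested allowlist inside LoRAConfig.verify_with_model_config (hypothetical; the constant name and the set of allowed methods are assumptions based on the quantization types exercised by the new tests, not part of this diff):

# Hypothetical allowlist of quantization methods that have LoRA test coverage.
SUPPORTED_QUANTIZATION_WITH_LORA = ("awq", "gptq")

def verify_with_model_config(self, model_config: ModelConfig):
    ...
    if (model_config.quantization is not None
            and model_config.quantization not in SUPPORTED_QUANTIZATION_WITH_LORA):
        raise ValueError(
            f"LoRA is untested with quantization method "
            f"{model_config.quantization!r}; tested methods: "
            f"{SUPPORTED_QUANTIZATION_WITH_LORA}")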

def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528: