[LoRA] Adds support for bias in LoRA (vllm-project#5733)
Signed-off-by: Umesh Deshpande <[email protected]>
Co-authored-by: Umesh Deshpande <[email protected]>
followumesh and Umesh Deshpande authored Nov 12, 2024
1 parent aa3ded2 commit 9ee414d
Showing 10 changed files with 456 additions and 20 deletions.
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
@@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
def lora_bias_files():
return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")


@pytest.fixture(scope="session")
def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test
52 changes: 52 additions & 0 deletions tests/lora/test_lora_bias_e2e.py
@@ -0,0 +1,52 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "ibm-granite/granite-3b-code-base"


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
generated_texts: List[str] = []
for output in outputs:
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
return generated_texts


@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_lora_rank=8,
max_loras=1,
enable_lora_bias=lora_bias,
tensor_parallel_size=1,
fully_sharded_loras=fully_sharded)

print("lora adapter created")
output1 = do_sample(llm, lora_bias_files, lora_id=0)

print("lora")
output2 = do_sample(llm, lora_bias_files, lora_id=1)

if lora_bias:
assert output1 != output2
else:
assert output1 == output2
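
As a usage note, the same adapter can be exercised outside the test harness. The sketch below reuses the model and adapter identifiers from the fixtures above; the prompt and the adapter name passed to LoRARequest are illustrative assumptions, not part of this commit.

from huggingface_hub import snapshot_download

import vllm
from vllm.lora.request import LoRARequest

# Download the bias-enabled adapter used by the lora_bias_files fixture.
lora_path = snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")

# enable_lora_bias (added in this commit) makes vLLM load and apply the
# adapter's bias terms on top of the usual lora_A/lora_B matrices.
llm = vllm.LLM("ibm-granite/granite-3b-code-base",
               enable_lora=True,
               enable_lora_bias=True,
               max_lora_rank=8)

outputs = llm.generate(
    "Write a SQL query that counts the rows of table t.",  # illustrative prompt
    vllm.SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql-lora-bias", 1, lora_path))
print(outputs[0].outputs[0].text)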
14 changes: 9 additions & 5 deletions tests/lora/test_utils.py
@@ -12,36 +12,40 @@

def test_parse_fine_tuned_lora_name_valid():
fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
(
"base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens",
True,
False,
),
(
"base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens",
False,
False,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj",
True,
False,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj",
False,
False,
),
}
for name, module_name, is_lora_a in fixture:
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
for name, module_name, is_lora_a, is_bias in fixture:
assert (module_name, is_lora_a,
is_bias) == parse_fine_tuned_lora_name(name)


def test_parse_fine_tuned_lora_name_invalid():
fixture = {
"weight",
"base_model.weight",
"base_model.model.weight",
}
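
To make the new tuple shape concrete, here is a minimal check along the lines of the valid-name fixture above (it assumes parse_fine_tuned_lora_name is importable from vllm.lora.utils, as in these tests; the values are taken verbatim from the fixture).

from vllm.lora.utils import parse_fine_tuned_lora_name

# The third element of the returned tuple is the new is_bias flag.
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
    "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight")
assert (module_name, is_lora_a, is_bias) == ("model.layers.9.mlp.down_proj",
                                             True, False)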
1 change: 1 addition & 0 deletions vllm/config.py
@@ -1687,6 +1687,7 @@ class LoRAConfig:
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
bias_enabled: bool = False

def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
@@ -143,6 +143,7 @@ class EngineArgs:
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_prompt_adapter: bool = False
@@ -584,6 +585,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--enable-lora-bias',
action='store_true',
help='If True, enable bias for LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
@@ -1148,6 +1152,7 @@ def create_engine_config(self) -> VllmConfig:
and parallel_config.use_ray),
policy=self.scheduling_policy)
lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
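
As a quick sanity check of how the new flag reaches the LoRA config, a minimal sketch (the model name is only a placeholder, it assumes the model's Hugging Face config is reachable, and it assumes the returned VllmConfig exposes the LoRAConfig built above as lora_config):

from vllm.engine.arg_utils import EngineArgs

# Build engine args the same way --enable-lora and --enable-lora-bias would
# from the CLI.
engine_args = EngineArgs(model="ibm-granite/granite-3b-code-base",
                         enable_lora=True,
                         enable_lora_bias=True,
                         max_lora_rank=8)
vllm_config = engine_args.create_engine_config()
assert vllm_config.lora_config.bias_enabled  # set from enable_lora_bias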
33 changes: 33 additions & 0 deletions vllm/lora/fully_sharded_layers.py
@@ -70,6 +70,14 @@ def apply(self, x: torch.Tensor,
self.lora_b_stacked,
add_input=True)
# now have column partitioned output

if self.bias_stacked is not None:
self.bias_stacked = self.bias_stacked.view(
-1, self.bias_stacked.shape[-1])
self.bias_stacked = self.bias_stacked[
self.punica_wrapper.token_lora_indices]
output += self.bias_stacked

output = output.view(*out_orig_shape)
return output

@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]

if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.punica_wrapper.token_lora_indices]
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
output[:, left_offset:left_offset + shard_size] += bias

layer.punica_wrapper.add_expand_slice(
output,
buffers[idx],
@@ -295,6 +312,15 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
lora_b = lora_b[:, start_idx:end_idx]
return lora_b

def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
shard_size = self.bias_stacked.shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias

def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)

@@ -318,6 +344,13 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size

if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias

self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx,
shard_size)
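
To make the bias handling in the hunks above easier to follow, here is a standalone sketch of the per-token gather they perform (tensor shapes and values are illustrative assumptions, not vLLM internals): each LoRA slot contributes one bias row, every token looks up the row of its active adapter through token_lora_indices, and tokens with no adapter (index -1) receive a zero bias, as in the _mcp_apply and apply hunks above.

import torch

max_loras, out_features = 4, 8
bias_stacked = torch.randn(max_loras, 1, out_features)  # one bias row per LoRA slot
token_lora_indices = torch.tensor([0, 2, -1, 1, -1])    # -1 means no LoRA for that token

bias = bias_stacked.view(-1, bias_stacked.shape[-1])    # (max_loras, out_features)
bias = bias[token_lora_indices]                         # (num_tokens, out_features)
bias[token_lora_indices == -1] = 0                      # zero rows for tokens without a LoRA
output = torch.zeros(len(token_lora_indices), out_features)
output += bias                                          # added on top of the LoRA low-rank update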
(The diffs for the remaining changed files were not loaded in this view.)
