[Core] Support dynamically loading Lora adapter from HuggingFace #6234

Merged · 10 commits · Jul 22, 2024
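With this change a LoRARequest accepts either a local directory or a Hugging Face repo id in lora_path, and the adapter is downloaded on demand. A minimal usage sketch (the base-model name and prompt are illustrative; the adapter id is the one used by the tests below):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

# lora_path may now be a Hugging Face repo id; it is resolved (and downloaded
# if necessary) by get_adapter_absolute_path() when the adapter is loaded.
sql_lora = LoRARequest(
    lora_name="sql-lora",
    lora_int_id=1,
    lora_path="yard1/llama-2-7b-sql-lora-test",
)

outputs = llm.generate(
    "Write a SQL query that lists all users.",
    SamplingParams(max_tokens=64),
    lora_request=sql_lora,
)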
4 changes: 2 additions & 2 deletions tests/core/test_scheduler.py
@@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
lora_request=LoRARequest(
lora_name=str(i),
lora_int_id=i + 1,
lora_local_path="abc"))
lora_path="abc"))
waiting.append(seq_group)
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
@@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
lora_request=LoRARequest(
lora_name=str(i),
lora_int_id=i + 1,
lora_local_path="abc"))
lora_path="abc"))
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
10 changes: 8 additions & 2 deletions tests/lora/conftest.py
@@ -159,8 +159,14 @@ def dummy_model_gate_up() -> nn.Module:


@pytest.fixture(scope="session")
def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def sql_lora_huggingface_id():
# The HuggingFace repo id is used to test LoRA runtime downloading.
return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
39 changes: 39 additions & 0 deletions tests/lora/test_lora_huggingface.py
@@ -0,0 +1,39 @@
from typing import List

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

# Provide both a local absolute path and a HuggingFace LoRA id
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_name = request.getfixturevalue(lora_fixture_name)
supported_lora_modules = LlamaForCausalLM.supported_lora_modules
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)

lora_path = get_adapter_absolute_path(lora_name)

# LoRA loading should work for either an absolute path or a HuggingFace id.
lora_model = LoRAModel.from_local_checkpoint(
lora_path,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)

# Assertions to ensure the model is loaded correctly
assert lora_model is not None, "LoRAModel is not loaded correctly"
57 changes: 56 additions & 1 deletion tests/lora/test_utils.py
@@ -1,9 +1,12 @@
from collections import OrderedDict
from unittest.mock import patch

import pytest
from huggingface_hub.utils import HfHubHTTPError
from torch import nn

from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
from vllm.lora.utils import (get_adapter_absolute_path,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache


@@ -182,3 +185,55 @@ def test_lru_cache():
assert 2 in cache
assert 4 in cache
assert 6 in cache


# Unit tests for get_adapter_absolute_path
@patch('os.path.isabs')
def test_get_adapter_absolute_path_absolute(mock_isabs):
path = '/absolute/path/to/lora'
mock_isabs.return_value = True
assert get_adapter_absolute_path(path) == path


@patch('os.path.expanduser')
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
# Path with ~ that needs to be expanded
path = '~/relative/path/to/lora'
absolute_path = '/home/user/relative/path/to/lora'
mock_expanduser.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path


@patch('os.path.exists')
@patch('os.path.abspath')
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
# Relative path that exists locally
path = 'relative/path/to/lora'
absolute_path = '/absolute/path/to/lora'
mock_exist.return_value = True
mock_abspath.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path


@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier
path = 'org/repo'
absolute_path = '/mock/snapshot/path'
mock_exist.return_value = False
mock_snapshot_download.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path


@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier with download error
path = 'org/repo'
mock_exist.return_value = False
mock_snapshot_download.side_effect = HfHubHTTPError(
"failed to query model info")
assert get_adapter_absolute_path(path) == path
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -33,7 +33,7 @@ class PromptAdapterPath:
@dataclass
class LoRAModulePath:
name: str
local_path: str
path: str


class OpenAIServing:
@@ -68,7 +68,7 @@ def __init__(
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_local_path=lora.local_path,
lora_path=lora.path,
) for i, lora in enumerate(lora_modules, start=1)
]

42 changes: 39 additions & 3 deletions vllm/lora/request.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import warnings
from dataclasses import dataclass, field
from typing import Optional

from vllm.adapter_commons.request import AdapterRequest
@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):

lora_name: str
lora_int_id: int
lora_local_path: str
lora_path: str = ""
lora_local_path: Optional[str] = field(default=None, repr=False)
long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__

def __post_init__(self):
if 'lora_local_path' in self.__dict__:
warnings.warn(
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'lora_path' instead.",
DeprecationWarning,
stacklevel=2)
if not self.lora_path:
self.lora_path = self.lora_local_path or ""

# Ensure lora_path is not empty
assert self.lora_path, "lora_path can not be empty"

@property
def adapter_id(self):
return self.lora_int_id
@@ -32,6 +48,26 @@ def adapter_id(self):
def name(self):
return self.lora_name

@property
def path(self):
return self.lora_path

@property
def local_path(self):
warnings.warn(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead.",
DeprecationWarning,
stacklevel=2)
return self.lora_path

@local_path.setter
def local_path(self, value):
warnings.warn(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead.",
DeprecationWarning,
stacklevel=2)
self.lora_path = value
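The __post_init__ hook together with the path/local_path properties keeps old call sites working: passing the deprecated lora_local_path still populates lora_path and emits a DeprecationWarning. A small sketch of the intended behaviour (the adapter path is illustrative):

import warnings

from vllm.lora.request import LoRARequest

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old-style construction: the deprecated keyword is still accepted.
    req = LoRARequest(lora_name="sql-lora",
                      lora_int_id=1,
                      lora_local_path="/path/to/adapter")

assert req.lora_path == "/path/to/adapter"  # copied into the new field
assert req.path == "/path/to/adapter"       # new property
assert any(issubclass(w.category, DeprecationWarning) for w in caught)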
47 changes: 47 additions & 0 deletions vllm/lora/utils.py
@@ -1,5 +1,9 @@
import os
from typing import List, Optional, Set, Tuple, Type

import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, RepositoryNotFoundError)
from torch import nn
from transformers import PretrainedConfig

@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

raise ValueError(f"{name} is unsupported LoRA weight")


def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.

If the lora_path is identified as a Hugging Face model identifier,
it will download the model and return the local snapshot path.
Otherwise, it treats the lora_path as a local file path and
converts it to an absolute path.

Parameters:
lora_path (str): The path to the lora model, which can be an absolute path,
a relative path, or a Hugging Face model identifier.

Returns:
str: The resolved absolute local path to the lora model.
"""

# Check if the path is an absolute path. Return it whether or not it exists.
if os.path.isabs(lora_path):
return lora_path

# If the path starts with ~, expand the user home directory.
if lora_path.startswith('~'):
return os.path.expanduser(lora_path)

# Check if the expanded relative path exists locally.
if os.path.exists(lora_path):
return os.path.abspath(lora_path)

# If the path does not exist locally, assume it's a Hugging Face repo.
try:
local_snapshot_path = huggingface_hub.snapshot_download(
repo_id=lora_path)
except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
HFValidationError):
# Handle errors that may occur during the download;
# return the original path instead of raising an exception here.
logger.exception("Error downloading the HuggingFace model")
return lora_path

return local_snapshot_path
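A short sketch of the resolution order implemented above (the local paths are illustrative; the repo id is the one used by the tests):

from vllm.lora.utils import get_adapter_absolute_path

# 1. Absolute paths are returned as-is, whether or not they exist.
get_adapter_absolute_path("/data/adapters/sql-lora")

# 2. Paths starting with "~" are expanded to the user's home directory.
get_adapter_absolute_path("~/adapters/sql-lora")

# 3. An existing relative path is converted to an absolute path.
get_adapter_absolute_path("adapters/sql-lora")

# 4. Anything else is treated as a Hugging Face repo id and downloaded;
#    if the download fails, the original string is returned unchanged.
get_adapter_absolute_path("yard1/llama-2-7b-sql-lora-test")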
7 changes: 4 additions & 3 deletions vllm/lora/worker_manager.py
@@ -13,6 +13,7 @@
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path

logger = init_logger(__name__)

@@ -89,8 +90,9 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
lora_path = get_adapter_absolute_path(lora_request.lora_path)
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
lora_path,
expected_lora_modules,
max_position_embeddings=self.max_position_embeddings,
lora_model_id=lora_request.lora_int_id,
@@ -102,8 +104,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
embedding_padding_modules=self.embedding_padding_modules,
)
except Exception as e:
raise RuntimeError(
f"Loading lora {lora_request.lora_local_path} failed") from e
raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank "
5 changes: 2 additions & 3 deletions vllm/transformers_utils/tokenizer.py
@@ -134,14 +134,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
if lora_request is None:
return None
try:
tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
**kwargs)
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
except OSError as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)", lora_request.lora_local_path, e)
"(Exception: %s)", lora_request.lora_path, e)
tokenizer = None
return tokenizer

2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
@@ -687,7 +687,7 @@ def profile_run(self) -> None:
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
lora_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)