[Core] Faster startup for LoRA enabled models #4634

Merged 4 commits on May 8, 2024
10 changes: 10 additions & 0 deletions vllm/lora/models.py
@@ -119,6 +119,16 @@ def __init__(
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
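
Below is a minimal, self-contained sketch of the intended clone() semantics (not vLLM code: MiniLoRAModel is a toy stand-in for LoRAModel, and plain objects stand in for LoRALayerWeights). It shows that only the id changes while the weight entries, and hence their tensors, are shared rather than copied.

from typing import Dict


class MiniLoRAModel:
    """Toy stand-in for LoRAModel, reduced to what clone() touches."""

    def __init__(self, lora_model_id: int, rank: int,
                 loras: Dict[str, object]) -> None:
        self.id = lora_model_id
        self.rank = rank
        self.loras = loras

    def clone(self, lora_model_id: int) -> "MiniLoRAModel":
        # Same shape as the clone() above: new id, same rank,
        # shallow copy of the weights dict (no tensor duplication).
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )


weights = {"model.layers.0.qkv_proj": object()}  # pretend LoRALayerWeights
base = MiniLoRAModel(1, rank=8, loras=weights)
copy = base.clone(2)

assert (base.id, copy.id) == (1, 2)  # only the id differs
assert copy.rank == base.rank
assert copy.loras is not base.loras  # the dict itself is a new object...
assert copy.loras["model.layers.0.qkv_proj"] is weights["model.layers.0.qkv_proj"]
# ...but the entries (and their tensors, in the real class) are shared.
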
26 changes: 22 additions & 4 deletions vllm/lora/worker_manager.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union

import torch

@@ -25,6 +26,17 @@ def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
        self.device = device
        self.lora_config = lora_config

+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
    @abstractproperty
    def is_enabled(self) -> bool:
        ...
@@ -174,9 +186,15 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        if lora_request.lora_int_id in self.list_loras():

Review thread on this line:
  Collaborator: Add assert self._cached_dummy_lora is not False here?
  Collaborator (Author): That would make caching mandatory.
  Collaborator: Oh, do we have any other use case besides the profiling run (which already uses the cache)?
  Collaborator (Author, @Yard1, May 8, 2024): Unit testing as well.

            return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+        if self._cached_dummy_lora is None:
+            self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id in self.list_loras():
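
For context, here is a self-contained sketch of how the tri-state _cached_dummy_lora cache is meant to behave (a toy stand-in, not vLLM code: dicts replace LoRAModel and no tensors are allocated). False means caching is disabled, None means the cache is enabled but empty, and once a dummy LoRA has been stored, later calls clone it instead of rebuilding it.

from contextlib import contextmanager
from typing import Dict, Literal, Union


class MiniWorkerManager:
    """Toy stand-in for the worker LoRA manager, reduced to the caching logic."""

    def __init__(self) -> None:
        # False = do not cache, None = cache enabled but empty, dict = cached model.
        self._cached_dummy_lora: Union[None, Literal[False], Dict] = False
        self.built_from_scratch = 0

    @contextmanager
    def dummy_lora_cache(self):
        # Enable caching for the duration of the block, then disable it again.
        self._cached_dummy_lora = None
        yield
        self._cached_dummy_lora = False

    def add_dummy_lora(self, lora_int_id: int, rank: int) -> Dict:
        if isinstance(self._cached_dummy_lora, dict):
            # Cheap path: "clone" the cached model, swapping only the id.
            dummy = {**self._cached_dummy_lora, "id": lora_int_id}
        else:
            # Expensive path: build a fresh dummy LoRA (tensor allocation in vLLM).
            self.built_from_scratch += 1
            dummy = {"id": lora_int_id, "rank": rank, "weights": object()}
        if self._cached_dummy_lora is None:
            self._cached_dummy_lora = dummy
        return dummy


manager = MiniWorkerManager()
with manager.dummy_lora_cache():
    loras = [manager.add_dummy_lora(i, rank=8) for i in range(1, 5)]

assert manager.built_from_scratch == 1  # only the first dummy LoRA was built
assert all(l["weights"] is loras[0]["weights"] for l in loras)  # the rest share it
assert manager._cached_dummy_lora is False  # cache is cleared when the block exits
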
29 changes: 15 additions & 14 deletions vllm/worker/model_runner.py
@@ -836,20 +836,21 @@ def profile_run(self) -> None:
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_local_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
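
Rough arithmetic illustrating why this speeds up startup, using made-up numbers rather than measured figures: without the cache, profile_run builds max_loras independent dummy LoRAs, each allocating weights for every target module, whereas with dummy_lora_cache() only the first one allocates and the rest are cheap clones that share those tensors.

# Hypothetical counts; both values below are assumptions, not vLLM measurements.
max_loras = 8           # example lora_config.max_loras
modules_per_lora = 64   # example number of target modules per dummy LoRA

weight_sets_before = max_loras * modules_per_lora  # 512 fresh weight allocations
weight_sets_after = 1 * modules_per_lora           # 64: one real build, 7 clones
print(weight_sets_before, weight_sets_after)       # 512 64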