diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 50d7e9133e0e8..cd45040bcca5d 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -119,6 +119,16 @@ def __init__(
         self.rank = rank
         self.loras: Dict[str, LoRALayerWeights] = loras
 
+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
     @property
     def extra_vocab_size(self) -> int:
         return max(lora.extra_vocab_size
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index ec3c10c591a18..377f561cceaf2 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union
 
 import torch
 
@@ -25,6 +26,17 @@ def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
         self.device = device
         self.lora_config = lora_config
 
+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
     @abstractproperty
     def is_enabled(self) -> bool:
         ...
@@ -174,9 +186,15 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
         if lora_request.lora_int_id in self.list_loras():
             return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+            if self._cached_dummy_lora is None:
+                self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if lora_request.lora_int_id in self.list_loras():
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index c96f13c590fc4..46c6730645c1b 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -835,20 +835,21 @@ def profile_run(self) -> None:
         dummy_lora_requests = []
         dummy_lora_requests_per_seq = []
         if self.lora_config:
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_local_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
 
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
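The pattern introduced above is a tri-state cache toggled by a context manager: False means caching is disabled, None means caching is enabled but empty, and a LoRAModel instance means a dummy adapter has already been built and can be cloned (sharing its tensors) instead of rebuilt for every warmup slot. Below is a minimal, self-contained sketch of that pattern, using hypothetical Adapter/AdapterManager names rather than the real vLLM classes, to illustrate why the profiling loop only pays the construction cost once.

    # Minimal sketch of the tri-state dummy-adapter cache (hypothetical names,
    # not the actual vLLM classes).
    from contextlib import contextmanager
    from typing import Dict, List, Literal, Union


    class Adapter:

        def __init__(self, adapter_id: int, weights: Dict[str, List[float]]):
            self.adapter_id = adapter_id
            # clone() copies only this dict, so the (expensive) values are
            # shared, mirroring how LoRAModel.clone() shares tensors.
            self.weights = weights

        def clone(self, adapter_id: int) -> "Adapter":
            return Adapter(adapter_id, self.weights.copy())


    class AdapterManager:

        def __init__(self):
            self.adapters: Dict[int, Adapter] = {}
            # False: do not cache. None: cache enabled but empty.
            self._cached_dummy: Union[None, Literal[False], Adapter] = False

        @contextmanager
        def dummy_cache(self):
            self._cached_dummy = None
            yield
            self._cached_dummy = False

        def add_dummy(self, adapter_id: int) -> None:
            if isinstance(self._cached_dummy, Adapter):
                dummy = self._cached_dummy.clone(adapter_id)
            else:
                # Stands in for the expensive create_dummy_lora() call.
                dummy = Adapter(adapter_id, {"w": [0.0] * 1024})
                if self._cached_dummy is None:
                    self._cached_dummy = dummy
            self.adapters[adapter_id] = dummy


    manager = AdapterManager()
    with manager.dummy_cache():
        for lora_id in range(1, 4):
            manager.add_dummy(lora_id)  # only the first call builds weights
    # Clones share the underlying weight values but carry distinct ids.
    assert manager.adapters[1].weights["w"] is manager.adapters[2].weights["w"]
    assert manager.adapters[1].adapter_id != manager.adapters[2].adapter_id

Outside the context manager the cache flag reverts to False, so regular (non-warmup) requests are never served a reused dummy adapter.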