From 83664e6cfa799be96252c12f3f76ca4e2c937df5 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 24 May 2024 01:37:50 +0000 Subject: [PATCH 1/6] Add base class for LoRA-supported models --- vllm/lora/lora.py | 3 +- vllm/lora/models.py | 5 ++-- vllm/model_executor/model_loader/loader.py | 9 ++++-- vllm/model_executor/models/baichuan.py | 10 ++++--- vllm/model_executor/models/chatglm.py | 8 ++++-- vllm/model_executor/models/decilm.py | 4 +-- vllm/model_executor/models/gemma.py | 9 +++--- vllm/model_executor/models/gpt_bigcode.py | 8 ++++-- vllm/model_executor/models/llama.py | 8 ++++-- vllm/model_executor/models/lora_base.py | 32 ++++++++++++++++++++++ vllm/model_executor/models/minicpm.py | 11 +++++--- vllm/model_executor/models/mixtral.py | 8 ++++-- vllm/model_executor/models/phi.py | 13 +++++---- vllm/model_executor/models/qwen2.py | 9 +++--- vllm/model_executor/models/xverse.py | 2 +- vllm/worker/model_runner.py | 11 +++----- 16 files changed, 101 insertions(+), 49 deletions(-) create mode 100644 vllm/model_executor/models/lora_base.py diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index d7794aa7cd35c..fd3b4e4e8fd82 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -1,6 +1,7 @@ from typing import List, Optional import torch +import torch.types from vllm.utils import is_pin_memory_available @@ -63,7 +64,7 @@ def create_dummy_lora_weights( output_dim: int, rank: int, dtype: torch.dtype, - device: torch.device, + device: torch.types.Device, embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() lora_a = torch.zeros([input_dim, rank], diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3e82856866d85..c218b59ab5d06 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -18,6 +18,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models.lora_base import LoRASupportedModelBase from vllm.utils import LRUCache, is_pin_memory_available logger = init_logger(__name__) @@ -363,7 +364,7 @@ class LoRAModelManager: def __init__( self, - model: nn.Module, + model: LoRASupportedModelBase, max_num_seqs: int, max_num_batched_tokens: int, vocab_size: int, @@ -411,7 +412,7 @@ def __init__( # embeddings_indices self.indices_len: List[Optional[int]] = [None] * 4 - self.model: nn.Module = model + self.model = model if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 45ea8160a801b..7680497d9358a 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -26,6 +26,7 @@ download_weights_from_hf, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.models.lora_base import LoRASupportedModelBase from vllm.model_executor.models.vlm_base import VisionLanguageModelBase logger = init_logger(__name__) @@ -61,7 +62,9 @@ def _get_model_initialization_kwargs( ) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs = {} - if hasattr(model_class, "supported_lora_modules"): + + if issubclass(model_class, LoRASupportedModelBase): + # lora_config=None is used to disable LoRA 
extra_kwargs["lora_config"] = lora_config elif lora_config: raise ValueError( @@ -69,13 +72,15 @@ def _get_model_initialization_kwargs( "but LoRA is enabled. Support for this model may " "be added in the future. If this is important to you, " "please open an issue on github.") - elif issubclass(model_class, VisionLanguageModelBase): + + if issubclass(model_class, VisionLanguageModelBase): if vision_language_config is None: raise ValueError("Provide `image_input_type` and other vision " "related configurations through LLM entrypoint " "or engine arguments.") extra_kwargs["vision_language_config"] = vision_language_config + return extra_kwargs diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index babb92e7cdcef..178513eda2190 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -45,6 +45,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .lora_base import LoRASupportedModelBase + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) @@ -292,7 +294,7 @@ def forward( return hidden_states -class BaiChuanBaseForCausalLM(nn.Module): +class BaiChuanBaseForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ @@ -312,14 +314,14 @@ class BaiChuanBaseForCausalLM(nn.Module): def __init__( self, - config, + config: PretrainedConfig, position_embedding: str, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.quant_config = quant_config self.model = BaiChuanModel(config, position_embedding, cache_config, quant_config) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e3a5e43e23e1c..319d6874383dc 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -28,6 +28,8 @@ from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig +from .lora_base import LoRASupportedModelBase + class GLMAttention(nn.Module): @@ -322,7 +324,7 @@ def forward( return hidden_states -class ChatGLMForCausalLM(nn.Module): +class ChatGLMForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] @@ -344,8 +346,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__() - self.config: ChatGLMConfig = config + super().__init__(config, lora_config) + self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index e293ee491908d..65b409a2a15a0 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -26,7 +26,7 @@ from typing import Iterable, Optional, Tuple import torch -from transformers import PretrainedConfig +from transformers import LlamaConfig from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.quantization.base_config import ( @@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, - config: Optional[PretrainedConfig] = None, + config: LlamaConfig, cache_config: 
Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 27dda00b66af4..f80e29f532d94 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -41,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .lora_base import LoRASupportedModelBase + logger = init_logger(__name__) @@ -288,7 +290,7 @@ def forward( return hidden_states -class GemmaForCausalLM(nn.Module): +class GemmaForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -319,9 +321,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - del lora_config # Unused. - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.quant_config = quant_config self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 69b75763e9a3d..9ebbad0d8a530 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -41,6 +41,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .lora_base import LoRASupportedModelBase + class GPTBigCodeAttention(nn.Module): @@ -230,7 +232,7 @@ def forward( return hidden_states -class GPTBigCodeForCausalLM(nn.Module): +class GPTBigCodeForCausalLM(LoRASupportedModelBase): packed_modules_mapping = {"c_attn": ["c_attn"]} supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] @@ -249,8 +251,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.quant_config = quant_config self.transformer = GPTBigCodeModel(config, cache_config, quant_config, lora_config) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f43a40a0bfd34..d953b90f56063 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -49,6 +49,8 @@ from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once +from .lora_base import LoRASupportedModelBase + class LlamaMLP(nn.Module): @@ -297,7 +299,7 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module): +class LlamaForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -328,8 +330,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.model = LlamaModel(config, cache_config, quant_config, diff --git a/vllm/model_executor/models/lora_base.py b/vllm/model_executor/models/lora_base.py new file mode 100644 index 0000000000000..699093402e5df --- /dev/null +++ b/vllm/model_executor/models/lora_base.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional + +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig + +if TYPE_CHECKING: + from vllm.lora.models import LoRAModelManager + + +class 
LoRASupportedModelBase(nn.Module): + """Base class for all models that support LoRA.""" + + packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} + supported_lora_modules: ClassVar[List[str]] = [] + embedding_modules: ClassVar[Dict[str, str]] = {} + embedding_padding_modules: ClassVar[List[str]] = [] + + # Assigned by LoRAModelManager at runtime + lora_manager: "LoRAModelManager" + + def __init__( + self, + config: PretrainedConfig, + # This is None when LoRA is not enabled + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 59fbf8e1b35f2..5c2292c566c31 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -26,6 +26,7 @@ import torch from torch import nn +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -51,6 +52,8 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput +from .lora_base import LoRASupportedModelBase + class MiniCPMMoE(nn.Module): """A tensor-parallel MoE implementation that shards each expert @@ -388,7 +391,7 @@ def forward( return hidden_states -class MiniCPMForCausalLM(nn.Module): +class MiniCPMForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -418,13 +421,13 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ea95cf7380d54..803c1ec534945 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -52,6 +52,8 @@ from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once +from .lora_base import LoRASupportedModelBase + class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert @@ -439,7 +441,7 @@ def forward( return hidden_states -class MixtralForCausalLM(nn.Module): +class MixtralForCausalLM(LoRASupportedModelBase): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -470,8 +472,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.model = MixtralModel(config, cache_config, quant_config, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index c8e61735a9bb6..6d77e9f86971a 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -39,7 +39,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import PhiConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -59,6 +59,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .lora_base import LoRASupportedModelBase + class 
PhiAttention(nn.Module): @@ -229,7 +231,7 @@ def forward( return hidden_states -class PhiForCausalLM(nn.Module): +class PhiForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -250,14 +252,13 @@ class PhiForCausalLM(nn.Module): def __init__( self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - del lora_config # Unused. - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.quant_config = quant_config self.model = PhiModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index ec203c3b9001a..c54862af7174e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -47,6 +47,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .lora_base import LoRASupportedModelBase + class Qwen2MLP(nn.Module): @@ -271,7 +273,7 @@ def forward( return hidden_states -class Qwen2ForCausalLM(nn.Module): +class Qwen2ForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -301,9 +303,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - del lora_config - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.quant_config = quant_config self.model = Qwen2Model(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index dda13d83f89a3..5140d3b9356d4 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -270,7 +270,7 @@ def forward( return hidden_states -class XverseForCausalLM(nn.Module): +class XverseForCausalLM(LoRASupportedModelBase): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 87d5f5c1b9d67..005b2e0412301 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -18,6 +18,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.lora_base import LoRASupportedModelBase from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) @@ -147,14 +148,10 @@ def load_model(self) -> None: self.model_memory_usage / float(2**30)) if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") - assert hasattr( + assert isinstance( self.model, - "embedding_modules"), "Model does not have embedding_modules" - assert hasattr(self.model, "embedding_padding_modules" - ), "Model does not have embedding_padding_modules" + LoRASupportedModelBase), "Model does not support LoRA" + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, From 40df63459e2ddb3877b256a87e645861ab0a7ca8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 24 May 2024 01:44:57 +0000 Subject: [PATCH 2/6] Update docs --- docs/source/models/lora.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/models/lora.rst 
b/docs/source/models/lora.rst index 2278640481a91..b2dbcc8242225 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -4,6 +4,9 @@ Using LoRA adapters =================== This document shows you how to use `LoRA adapters `_ with vLLM on top of a base model. + +LoRA adapters can be used with any vLLM model that inherits from :class:`~vllm.model_executor.models.lora_base.LoRASupportedModelBase`. + Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save them locally with From 8ae78fe1e5862bf5643ba25479a26bbce319b1fb Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 24 May 2024 01:47:24 +0000 Subject: [PATCH 3/6] Fix missing imports --- vllm/model_executor/models/phi.py | 8 ++++---- vllm/model_executor/models/xverse.py | 8 +++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 6d77e9f86971a..2bd40396d96e0 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -65,7 +65,7 @@ class PhiAttention(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -133,7 +133,7 @@ def forward( class PhiMLP(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -162,7 +162,7 @@ def forward(self, hidden_states): class PhiLayer(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -194,7 +194,7 @@ def forward( class PhiModel(nn.Module): def __init__(self, - config: PretrainedConfig, + config: PhiConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 5140d3b9356d4..ceeada3b5ad91 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -45,6 +45,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .lora_base import LoRASupportedModelBase + class XverseMLP(nn.Module): @@ -303,10 +305,10 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - lora_config=None, + lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__() - self.config = config + super().__init__(config, lora_config) + self.quant_config = quant_config self.model = XverseModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) From acf7cbe8ff46d738deb3f46b5519ec2933a682c7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 11 Jun 2024 03:32:59 +0000 Subject: [PATCH 4/6] Use interfaces instead of base class --- docs/source/models/lora.rst | 2 +- vllm/lora/models.py | 5 +- vllm/model_executor/model_loader/loader.py | 17 +-- vllm/model_executor/models/baichuan.py | 11 +- vllm/model_executor/models/chatglm.py | 11 +- vllm/model_executor/models/gemma.py | 11 +- vllm/model_executor/models/gpt_bigcode.py | 11 +- vllm/model_executor/models/interfaces.py | 118 +++++++++++++++++++++ vllm/model_executor/models/llama.py | 11 +- 
vllm/model_executor/models/llava.py | 22 ++-- vllm/model_executor/models/llava_next.py | 20 ++-- vllm/model_executor/models/lora_base.py | 32 ------ vllm/model_executor/models/minicpm.py | 11 +- vllm/model_executor/models/mixtral.py | 11 +- vllm/model_executor/models/phi.py | 11 +- vllm/model_executor/models/qwen2.py | 11 +- vllm/model_executor/models/vlm_base.py | 12 --- vllm/model_executor/models/xverse.py | 11 +- vllm/worker/model_runner.py | 6 +- 19 files changed, 235 insertions(+), 109 deletions(-) create mode 100644 vllm/model_executor/models/interfaces.py delete mode 100644 vllm/model_executor/models/lora_base.py delete mode 100644 vllm/model_executor/models/vlm_base.py diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index b2dbcc8242225..934887a607a6a 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -5,7 +5,7 @@ Using LoRA adapters This document shows you how to use `LoRA adapters `_ with vLLM on top of a base model. -LoRA adapters can be used with any vLLM model that inherits from :class:`~vllm.model_executor.models.lora_base.LoRASupportedModelBase`. +LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`. Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save them locally with diff --git a/vllm/lora/models.py b/vllm/lora/models.py index c218b59ab5d06..b9a43ea0d05a2 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -18,7 +18,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) -from vllm.model_executor.models.lora_base import LoRASupportedModelBase +from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.utils import LRUCache, is_pin_memory_available logger = init_logger(__name__) @@ -364,7 +364,7 @@ class LoRAModelManager: def __init__( self, - model: LoRASupportedModelBase, + model: SupportsLoRA, max_num_seqs: int, max_num_batched_tokens: int, vocab_size: int, @@ -429,7 +429,6 @@ def __init__( self._active_loras: Dict[int, None] = {} self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() - self.model.lora_manager = self @property def capacity(self) -> int: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 724aa500134b2..dd8e24beb7f6d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -32,8 +32,8 @@ filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.lora_base import LoRASupportedModelBase -from vllm.model_executor.models.vlm_base import VisionLanguageModelBase +from vllm.model_executor.models.interfaces import (supports_lora, + supports_vision) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -64,13 +64,14 @@ def _get_quantization_config( def _get_model_initialization_kwargs( - model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + vlm_config: Optional[VisionLanguageConfig], ) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" 
extra_kwargs = {} - if issubclass(model_class, LoRASupportedModelBase): + if supports_lora(model_class): # lora_config=None is used to disable LoRA extra_kwargs["lora_config"] = lora_config elif lora_config: @@ -80,13 +81,13 @@ def _get_model_initialization_kwargs( "be added in the future. If this is important to you, " "please open an issue on github.") - if issubclass(model_class, VisionLanguageModelBase): - if vision_language_config is None: + if supports_vision(model_class): + if vlm_config is None: raise ValueError("Provide `image_input_type` and other vision " "related configurations through LLM entrypoint " "or engine arguments.") - extra_kwargs["vision_language_config"] = vision_language_config + extra_kwargs["vlm_config"] = vlm_config return extra_kwargs diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 178513eda2190..abaefa3cf7781 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -45,7 +45,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -294,7 +294,9 @@ def forward( return hidden_states -class BaiChuanBaseForCausalLM(LoRASupportedModelBase): +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ @@ -320,7 +322,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.model = BaiChuanModel(config, position_embedding, cache_config, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 319d6874383dc..bf64538ef54a3 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -28,7 +28,7 @@ from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class GLMAttention(nn.Module): @@ -324,7 +324,9 @@ def forward( return hidden_states -class ChatGLMForCausalLM(LoRASupportedModelBase): +class ChatGLMForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] @@ -346,7 +348,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index f80e29f532d94..859e2ea9c8b1f 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -41,7 +41,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA logger = init_logger(__name__) @@ -290,7 +290,9 @@ def forward( return hidden_states -class GemmaForCausalLM(LoRASupportedModelBase): +class 
GemmaForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -321,7 +323,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.model = GemmaModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 9ebbad0d8a530..f182aa49e67ed 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -41,7 +41,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class GPTBigCodeAttention(nn.Module): @@ -232,7 +232,9 @@ def forward( return hidden_states -class GPTBigCodeForCausalLM(LoRASupportedModelBase): +class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = {"c_attn": ["c_attn"]} supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] @@ -251,7 +253,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.transformer = GPTBigCodeModel(config, cache_config, quant_config, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py new file mode 100644 index 0000000000000..cd119fc2251c5 --- /dev/null +++ b/vllm/model_executor/models/interfaces.py @@ -0,0 +1,118 @@ +from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, + Union, overload, runtime_checkable) + +from typing_extensions import TypeGuard + +from vllm.config import LoRAConfig, VisionLanguageConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@runtime_checkable +class SupportsVision(Protocol): + """The interface required for all vision language models (VLMs).""" + + supports_vision: ClassVar[Literal[True]] + + def __init__(self, *, vlm_config: VisionLanguageConfig) -> None: + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsVisionType(Protocol): + supports_vision: Literal[True] + + def __call__(self, *, vlm_config: VisionLanguageConfig) -> None: + ... + + +@overload +def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]: + ... + + +@overload +def supports_vision(model: object) -> TypeGuard[SupportsVision]: + ... 
+ + +def supports_vision( + model: Union[Type[object], object], +) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]: + if isinstance(model, type): + return isinstance(model, _SupportsVisionType) + + return isinstance(model, SupportsVision) + + +@runtime_checkable +class SupportsLoRA(Protocol): + """The interface required for all models that support LoRA.""" + + supports_lora: ClassVar[Literal[True]] + + packed_modules_mapping: ClassVar[Dict[str, List[str]]] + supported_lora_modules: ClassVar[List[str]] + embedding_modules: ClassVar[Dict[str, str]] + embedding_padding_modules: ClassVar[List[str]] + + # lora_config is None when LoRA is not enabled + def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsLoRAType(Protocol): + supports_lora: Literal[True] + + packed_modules_mapping: Dict[str, List[str]] + supported_lora_modules: List[str] + embedding_modules: Dict[str, str] + embedding_padding_modules: List[str] + + def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + ... + + +@overload +def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]: + ... + + +@overload +def supports_lora(model: object) -> TypeGuard[SupportsLoRA]: + ... + + +def supports_lora( + model: Union[Type[object], object], +) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: + result = _supports_lora(model) + + if not result: + lora_attrs = ( + "packed_modules_mapping", + "supported_lora_modules", + "embedding_modules", + "embedding_padding_modules", + ) + if any(hasattr(model, attr) for attr in lora_attrs): + logger.warning( + "The model (%s) contains LoRA-specific attributes, " + "but does not set `supports_lora=True`.", model) + + return result + + +def _supports_lora( + model: Union[Type[object], object], +) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: + if isinstance(model, type): + return isinstance(model, _SupportsLoRAType) + + return isinstance(model, SupportsLoRA) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 283d7050bea59..f4918cbfef294 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -49,7 +49,7 @@ from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class LlamaMLP(nn.Module): @@ -298,7 +298,9 @@ def forward( return hidden_states -class LlamaForCausalLM(LoRASupportedModelBase): +class LlamaForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -337,7 +339,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.model = LlamaModel(config, cache_config, diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 67b32a08833b6..d08851c847d94 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -21,7 +21,7 @@ from vllm.multimodal.image import get_dummy_image_data from vllm.sequence import SamplerOutput -from .vlm_base import VisionLanguageModelBase +from .interfaces import SupportsVision 
_KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", @@ -87,18 +87,21 @@ class LlavaImageFeatureInputs(TypedDict): @MULTIMODAL_REGISTRY.register_image_feature_input() @MULTIMODAL_REGISTRY.register_image_pixel_input() @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) -class LlavaForConditionalGeneration(VisionLanguageModelBase): +class LlavaForConditionalGeneration(nn.Module, SupportsVision): + + supports_vision = True def __init__(self, config: LlavaConfig, - vision_language_config: VisionLanguageConfig, + vlm_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: - super().__init__(vision_language_config) + super().__init__() self.config = config + self.vlm_config = vlm_config - if self.vision_language_config.image_input_type == ( + if self.vlm_config.image_input_type == ( VisionLanguageConfig.ImageInputType.PIXEL_VALUES): self.vision_tower = CLIPVisionModel(config.vision_config) else: @@ -123,11 +126,10 @@ def __init__(self, self.sampler = Sampler() def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != list( - self.vision_language_config.image_input_shape[1:]): + if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]): raise ValueError( f"The expected image tensor shape is batch dimension plus " - f"{self.vision_language_config.image_input_shape[1:]}. " + f"{self.vlm_config.image_input_shape[1:]}. " f"You supplied {data.shape}. " f"If you are using vLLM's entrypoint, make sure your " f"supplied image input is consistent with " @@ -140,7 +142,7 @@ def _parse_and_validate_image_input( pixel_values = kwargs.pop("pixel_values", None) image_features = kwargs.pop("image_features", None) - expected_input_type = self.vision_language_config.image_input_type + expected_input_type = self.vlm_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType if expected_input_type == ImageInputType.PIXEL_VALUES: @@ -272,7 +274,7 @@ def forward( inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, - self.vision_language_config.image_token_id) + self.vlm_config.image_token_id) input_ids = None else: diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 57cbd1e4a6018..b1337c7cd171f 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -26,8 +26,8 @@ from vllm.multimodal.image import ImagePixelData, get_dummy_image_data from vllm.sequence import SamplerOutput, SequenceData +from .interfaces import SupportsVision from .llava import LlavaMultiModalProjector, merge_vision_embeddings -from .vlm_base import VisionLanguageModelBase logger = init_logger(__name__) @@ -107,7 +107,7 @@ def _image_pixel_processor( @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor) @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data) -class LlavaNextForConditionalGeneration(VisionLanguageModelBase): +class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): """ Args to `forward()`: input_ids: Flattened (concatenated) input_ids corresponding to a @@ -118,17 +118,19 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase): [1, num_patches, 1176, 1024]. 
""" + supports_vision = True + def __init__(self, config: LlavaNextConfig, - vision_language_config: VisionLanguageConfig, + vlm_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: - super().__init__(vision_language_config) + super().__init__() - # Update the type annotation from that of its superclass self.config = config + self.vlm_config = vlm_config - if self.vision_language_config.image_input_type == ( + if self.vlm_config.image_input_type == ( VisionLanguageConfig.ImageInputType.PIXEL_VALUES): self.vision_tower = CLIPVisionModel(config.vision_config) else: @@ -156,7 +158,7 @@ def __init__(self, torch.empty(config.text_config.hidden_size)) def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor: - _, num_channels, _, _ = self.vision_language_config.image_input_shape + _, num_channels, _, _ = self.vlm_config.image_input_shape # Note that this is different from that of vLLM vision_language_config # since the image is resized by the HuggingFace preprocessor @@ -187,7 +189,7 @@ def _parse_and_validate_image_input( image_sizes = kwargs.pop("image_sizes", None) image_features = kwargs.pop("image_features", None) - expected_input_type = self.vision_language_config.image_input_type + expected_input_type = self.vlm_config.image_input_type ImageInputType = VisionLanguageConfig.ImageInputType if expected_input_type == ImageInputType.PIXEL_VALUES: @@ -400,7 +402,7 @@ def forward( inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, - self.vision_language_config.image_token_id) + self.vlm_config.image_token_id) input_ids = None else: diff --git a/vllm/model_executor/models/lora_base.py b/vllm/model_executor/models/lora_base.py deleted file mode 100644 index 699093402e5df..0000000000000 --- a/vllm/model_executor/models/lora_base.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional - -from torch import nn -from transformers import PretrainedConfig - -from vllm.config import LoRAConfig - -if TYPE_CHECKING: - from vllm.lora.models import LoRAModelManager - - -class LoRASupportedModelBase(nn.Module): - """Base class for all models that support LoRA.""" - - packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} - supported_lora_modules: ClassVar[List[str]] = [] - embedding_modules: ClassVar[Dict[str, str]] = {} - embedding_padding_modules: ClassVar[List[str]] = [] - - # Assigned by LoRAModelManager at runtime - lora_manager: "LoRAModelManager" - - def __init__( - self, - config: PretrainedConfig, - # This is None when LoRA is not enabled - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - - self.config = config - self.lora_config = lora_config diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 5c2292c566c31..ae17309bd5223 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -52,7 +52,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class MiniCPMMoE(nn.Module): @@ -391,7 +391,9 @@ def forward( return hidden_states -class MiniCPMForCausalLM(LoRASupportedModelBase): +class MiniCPMForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -426,7 +428,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, 
lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 338164e258a7d..0bdcb21e514fd 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -54,7 +54,7 @@ from vllm.sequence import SamplerOutput from vllm.utils import print_warning_once -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class MixtralMoE(nn.Module): @@ -474,7 +474,9 @@ def forward( return hidden_states -class MixtralForCausalLM(LoRASupportedModelBase): +class MixtralForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -505,7 +507,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.model = MixtralModel(config, cache_config, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 2bd40396d96e0..d288bdd9d78f5 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -59,7 +59,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class PhiAttention(nn.Module): @@ -231,7 +231,9 @@ def forward( return hidden_states -class PhiForCausalLM(LoRASupportedModelBase): +class PhiForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -257,7 +259,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index d44aa25f4c68d..5ad7731da6b9b 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class Qwen2MLP(nn.Module): @@ -264,7 +264,9 @@ def forward( return hidden_states -class Qwen2ForCausalLM(LoRASupportedModelBase): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -306,7 +308,10 @@ def __init__( config.num_hidden_layers, )) - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.model = Qwen2Model(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/vlm_base.py b/vllm/model_executor/models/vlm_base.py deleted file mode 100644 index eb0aa96e50d59..0000000000000 --- a/vllm/model_executor/models/vlm_base.py +++ /dev/null @@ -1,12 +0,0 @@ -from torch import nn - -from vllm.config import VisionLanguageConfig - - -class VisionLanguageModelBase(nn.Module): - """Base class for all vision language models 
(VLMs).""" - - def __init__(self, vision_language_config: VisionLanguageConfig) -> None: - super().__init__() - - self.vision_language_config = vision_language_config diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 54b6ee98ada85..639c3443bc369 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -45,7 +45,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from .lora_base import LoRASupportedModelBase +from .interfaces import SupportsLoRA class XverseMLP(nn.Module): @@ -268,7 +268,9 @@ def forward( return hidden_states -class XverseForCausalLM(LoRASupportedModelBase): +class XverseForCausalLM(nn.Module, SupportsLoRA): + supports_lora = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -303,7 +305,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: - super().__init__(config, lora_config) + super().__init__() + + self.config = config + self.lora_config = lora_config self.quant_config = quant_config self.model = XverseModel(config, cache_config, quant_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d6bd9282449ab..8492ed6b0fbd6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,7 +20,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.lora_base import LoRASupportedModelBase +from vllm.model_executor.models.interfaces import supports_lora from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata @@ -161,9 +161,7 @@ def load_model(self) -> None: self.model_memory_usage / float(2**30)) if self.lora_config: - assert isinstance( - self.model, - LoRASupportedModelBase), "Model does not support LoRA" + assert supports_lora(self.model), "Model does not support LoRA" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, From 5b3ee6ca97a259fea26a7c35f03196575426fe3d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 11 Jun 2024 05:40:09 +0000 Subject: [PATCH 5/6] Improve warning message --- vllm/model_executor/models/interfaces.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index cd119fc2251c5..a9eb397a5a97f 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -101,10 +101,22 @@ def supports_lora( "embedding_modules", "embedding_padding_modules", ) - if any(hasattr(model, attr) for attr in lora_attrs): - logger.warning( - "The model (%s) contains LoRA-specific attributes, " - "but does not set `supports_lora=True`.", model) + missing_attrs = tuple(attr for attr in lora_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_lora", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_lora=True`, " + "but is missing LoRA-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all LoRA-specific attributes, " + "but does not set `supports_lora=True`.", model) return result From 3180d7eac322f270bdb5873c7e7bf4a885d5ca31 Mon Sep 
17 00:00:00 2001 From: DarkLight1337 Date: Wed, 12 Jun 2024 23:31:16 +0000 Subject: [PATCH 6/6] Apply formatter --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c42f52a064480..f1732ad46673f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,8 +20,8 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.interfaces import supports_lora from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models.interfaces import supports_lora from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
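
---

For reference, below is a minimal sketch of how the SupportsLoRA interface introduced in patch 4/6 is meant to be consumed once this series is applied. The ToyForCausalLM class and its module lists are hypothetical and exist only to illustrate the duck-typed check that the model loader and model runner now perform via supports_lora(); it is not part of the series itself.

from typing import Optional

from torch import nn

from vllm.config import LoRAConfig
from vllm.model_executor.models.interfaces import SupportsLoRA, supports_lora


class ToyForCausalLM(nn.Module, SupportsLoRA):
    # Class-level flag plus the attributes required by the SupportsLoRA
    # protocol; the values below are placeholders for illustration only.
    supports_lora = True

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    supported_lora_modules = ["qkv_proj", "o_proj", "lm_head"]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        super().__init__()
        # lora_config is None when LoRA is disabled, mirroring the models
        # touched in this series.
        self.lora_config = lora_config


# The helper accepts both the class (as used by the model loader) and an
# instance (as used by the model runner), replacing the old hasattr-based
# and issubclass-based checks.
assert supports_lora(ToyForCausalLM)
assert supports_lora(ToyForCausalLM(lora_config=None))

The isinstance-based helpers are needed because issubclass() cannot be used with runtime-checkable protocols that declare non-method members, which is why interfaces.py defines the _SupportsLoRAType/_SupportsVisionType twins and treats the class object itself as an instance when checking.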