diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index dfc7143823d5a..43d2c88d3b9ca 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -72,15 +72,24 @@ def vllm_flash_attention_forward(
 ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
 
 
+def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
+    logger.debug("%s: %s -> %s", name, old_module, new_module)
+
+
 def replace_linear_class(
         linear: nn.Linear,
-        style: str,
+        style: Literal["colwise", "rowwise"],
         quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
     """
-    In model configurations, we use a neutral type (string) to specify parallel
-    styles, here we use it to translate nn.Linear into vllm-style tp Linear.
-
-    Quant config is not supported yet
+    Replace nn.Linear with one of vLLM's tensor parallel linear classes.
+
+    `quant_config` is not yet supported.
+    Args:
+        linear (nn.Linear): `nn.Linear` to be replaced.
+        style (str): Tensor parallel style of the new linear, e.g. "colwise".
+        quant_config (QuantConfig): Quantization config for the new linear.
+    Returns:
+        The new tensor parallel linear, or `linear` if `style` is unsupported.
     """
 
     if not isinstance(style, str):
@@ -93,7 +102,10 @@ def replace_linear_class(
     }.get(style)
 
     if vllm_linear_cls is None:
-        raise ValueError(f"Unsupported parallel style value: {style}")
+        logger.warning(
+            "Unsupported parallel style value: %s. "
+            "This layer will not be tensor parallelized.", style)
+        return linear
 
     class HFCompatibleLinear(vllm_linear_cls):
         """
@@ -119,25 +131,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
         logger.info("Using Transformers backend.")
 
-        self.vllm_config = vllm_config
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        self.quant_config = quant_config
+        self.config = config
+        self.quant_config = quant_config
 
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
             attn_implementation="vllm",
-            torch_dtype=vllm_config.model_config.dtype,
             trust_remote_code=vllm_config.model_config.trust_remote_code,
         )
         prefix = self.model.base_model_prefix
 
         # MLP modifications
-        self.tensor_parallelize(self.model)
+        self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
         tp_size = get_tensor_model_parallel_world_size()
@@ -170,13 +181,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
             config.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
-    def log_replacement(self, name: str, old_module: nn.Module,
-                        new_module: nn.Module):
-        logger.debug("%s: %s -> %s", name, old_module, new_module)
-
-    def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
+    def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
+        """
+        Apply the base model tensor parallelization plan to a module.
+        Currently only supports linear layers.
+ """ if (self.config.base_model_tp_plan is None - and self.vllm_config.parallel_config.tensor_parallel_size > 1): + and get_tensor_model_parallel_world_size() > 1): raise ValueError( "Trying to run tensor parallelization but the model does not " "support it yet!") @@ -189,9 +200,9 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""): new_module = replace_linear_class(child_module, style, self.quant_config) setattr(module, child_name, new_module) - self.log_replacement(qual_name, child_module, new_module) + log_replacement(qual_name, child_module, new_module) else: - self.tensor_parallelize(child_module, prefix=qual_name) + self.apply_base_model_tp_plan(child_module, prefix=qual_name) def replace_vocab_embed_class(self, module: nn.Module): # Use native set input embeddings @@ -201,8 +212,8 @@ def replace_vocab_embed_class(self, module: nn.Module): org_num_embeddings=self.config.vocab_size, quant_config=None, ) - self.log_replacement("input embedding", - self.model.get_input_embeddings(), new_module) + log_replacement("input embedding", self.model.get_input_embeddings(), + new_module) self.model.set_input_embeddings(new_module) def forward(