diff --git a/python/pyproject.toml b/python/pyproject.toml
index 5e144f809e3..b5fa4ceead5 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
     "orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart", "torchao", "uvicorn", "uvloop",
     "pyzmq>=25.1.2", "outlines>=0.0.44,<0.1.0", "modelscope"]
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
+srt = ["sglang[runtime_common]", "torch", "vllm==0.6.4.post1"]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 94d48e82b9d..d31dc81ed50 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -38,6 +38,7 @@
 logger = logging.getLogger(__name__)
 
 
+@CustomOp.register("silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +52,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+@CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 3ae392eb9af..3ffa91575c8 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -36,6 +36,7 @@
 logger = logging.getLogger(__name__)
 
 
+@CustomOp.register("rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -78,6 +79,7 @@ def forward_native(
         return x, residual
 
 
+@CustomOp.register("gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 5cde1e942ff..8bef7d187b0 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -28,6 +28,7 @@
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
@@ -59,6 +60,7 @@
     enable_show_time_cost,
     get_available_gpu_memory,
     monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
 
@@ -243,12 +245,14 @@ def load_model(self):
 
         # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
+        monkey_patch_vllm_model_config()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
         self.vllm_model_config = VllmModelConfig(
             model=self.server_args.model_path,
+            task="generate" if self.model_config.is_generation else "embedding",
             quantization=self.server_args.quantization,
             tokenizer=None,
             tokenizer_mode=None,
@@ -263,15 +267,14 @@ def load_model(self):
         )
         self.dtype = self.vllm_model_config.dtype
 
+        self.vllm_config = VllmConfig()
+        self.vllm_config.model_config = self.vllm_model_config
+        self.vllm_config.load_config = self.load_config
+        self.vllm_config.device_config = DeviceConfig(self.device)
+
         # Load the model
         self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
+            vllm_config=self.vllm_config,
         )
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index e04ec7ddffa..13e4b4234e0 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -410,37 +410,23 @@ def monkey_patch_vllm_dummy_weight_loader():
     Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
     """
 
+    from vllm.config import VllmConfig
     from vllm.model_executor.model_loader.loader import (
-        CacheConfig,
-        DeviceConfig,
         DummyModelLoader,
-        LoRAConfig,
-        ModelConfig,
-        ParallelConfig,
-        SchedulerConfig,
         _initialize_model,
         initialize_dummy_weights,
         nn,
         set_default_torch_dtype,
     )
 
-    def load_model(
-        self,
-        *,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        cache_config: CacheConfig,
-    ) -> nn.Module:
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
+        with set_default_torch_dtype(vllm_config.model_config.dtype):
+            with torch.device(vllm_config.device_config.device):
                 model = _initialize_model(
-                    model_config,
+                    vllm_config.model_config,
                     self.load_config,
-                    lora_config,
-                    cache_config,
+                    vllm_config.lora_config,
+                    vllm_config.cache_config,
                 )
 
                 for _, module in model.named_modules():
@@ -512,6 +498,60 @@ def maybe_set_triton_cache_manager() -> None:
         os.environ["TRITON_CACHE_MANAGER"] = manager
 
 
+def monkey_patch_vllm_model_config():
+    from typing import Dict, Set, Tuple, Union
+
+    from transformers import PretrainedConfig
+    from vllm.config import ModelConfig, TaskOption, _Task
+
+    def _resolve_task(
+        self,
+        task_option: Union[TaskOption, _Task],
+        hf_config: PretrainedConfig,
+    ) -> Tuple[Set[_Task], _Task]:
+
+        architectures = getattr(hf_config, "architectures", [])
+        if isinstance(architectures, str):
+            architectures = [architectures]
+
+        non_generation_models = {
+            "LlamaEmbeddingModel",
+            "MistralModel",
+            "LlamaForSequenceClassification",
+            "LlamaForSequenceClassificationWithNormal_Weights",
+            "InternLM2ForRewardModel",
+        }
+
+        is_generation = not any(arch in non_generation_models for arch in architectures)
+
+        auto_map = getattr(hf_config, "auto_map", {})
+        has_sequence_classification = any(
+            "ForSequenceClassification" in v for v in auto_map.values()
+        )
+
+        task_support: Dict[_Task, bool] = {
+            "generate": is_generation,
+            "embedding": (not is_generation) or has_sequence_classification,
+        }
+
+        supported_tasks_lst = [
+            task for task, is_supported in task_support.items() if is_supported
+        ]
+        supported_tasks = set(supported_tasks_lst)
+
+        if task_option not in supported_tasks:
+            msg = (
+                f"This model does not support the '{task_option}' task. "
+                f"Supported tasks: {supported_tasks}"
+            )
+            raise ValueError(msg)
+        selected_task = task_option
+
+        return supported_tasks, selected_task
+
+    setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
 class CustomCacheManager(FileCacheManager):
     # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
     def __init__(self, key, override=False, dump=False):
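
Illustrative sketch (not part of the patch): vllm 0.6.4 replaces the individual loader arguments with a single aggregate VllmConfig, which is why load_model() above now builds self.vllm_config and calls get_model(vllm_config=...), and why DummyModelLoader.load_model and ModelConfig._resolve_task are monkey-patched. The snippet below mirrors that call path outside of ModelRunner. It assumes vllm's distributed environment has already been initialized (ModelRunner does that earlier in its own setup); the model path, dtype, and seed values are placeholders, and the exact required ModelConfig keyword arguments may differ between vllm point releases.

    from vllm.config import DeviceConfig, LoadConfig, VllmConfig
    from vllm.config import ModelConfig as VllmModelConfig
    from vllm.model_executor.model_loader import get_model

    from sglang.srt.utils import (
        monkey_patch_vllm_dummy_weight_loader,
        monkey_patch_vllm_model_config,
    )

    # Apply the patches before any vllm config objects are created,
    # mirroring ModelRunner.load_model() in the diff above.
    monkey_patch_vllm_dummy_weight_loader()
    monkey_patch_vllm_model_config()

    # Placeholder model path and generic settings; in sglang these come from ServerArgs.
    vllm_model_config = VllmModelConfig(
        model="meta-llama/Llama-3.1-8B-Instruct",
        task="generate",  # resolved through the patched ModelConfig._resolve_task
        tokenizer=None,
        tokenizer_mode=None,
        trust_remote_code=True,
        dtype="auto",
        seed=0,
    )

    # Assemble the aggregate config and pass it to get_model as a single argument,
    # instead of the pre-0.6.4 model_config/load_config/device_config/... keywords.
    vllm_config = VllmConfig()
    vllm_config.model_config = vllm_model_config
    vllm_config.load_config = LoadConfig(load_format="auto")
    vllm_config.device_config = DeviceConfig("cuda")

    model = get_model(vllm_config=vllm_config)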