From eedc12e12ed3a4ecf9cc8c6648d9e2bc2caffc23 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sun, 21 Jul 2024 03:09:29 -0700 Subject: [PATCH] Support Deepseek MoE Model (#689) --- .../managers/controller/cuda_graph_runner.py | 57 ++- .../srt/managers/controller/model_runner.py | 7 +- python/sglang/srt/models/deepseek.py | 430 ++++++++++++++++++ python/sglang/srt/server.py | 2 +- python/sglang/srt/utils.py | 46 ++ 5 files changed, 519 insertions(+), 23 deletions(-) create mode 100644 python/sglang/srt/models/deepseek.py diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py index 0066f92b822..1095481ee69 100644 --- a/python/sglang/srt/managers/controller/cuda_graph_runner.py +++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py @@ -1,6 +1,7 @@ """Run the model with cuda graph.""" import bisect +from contextlib import contextmanager import torch from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -15,9 +16,10 @@ InputMetadata, init_flashinfer_args, ) +from sglang.srt.utils import monkey_patch_vllm_all_gather -def _to_torch(model: torch.nn.Module, reverse=False): +def _to_torch(model: torch.nn.Module, reverse: bool = False): for sub in model._modules.values(): if isinstance(sub, CustomOp): if reverse: @@ -28,13 +30,26 @@ def _to_torch(model: torch.nn.Module, reverse=False): _to_torch(sub, reverse) -def get_forward(model: torch.nn.Module, use_torch: bool): - if use_torch: - _to_torch(model, reverse=False) - return torch.compile(model.forward, mode="max-autotune-no-cudagraphs") - else: - _to_torch(model, reverse=True) - return model.forward +@contextmanager +def patch_model( + model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator" +): + backup_ca_comm = None + + try: + if use_compile: + _to_torch(model) + monkey_patch_vllm_all_gather() + backup_ca_comm = tp_group.ca_comm + tp_group.ca_comm = None + yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs") + else: + yield model.forward + finally: + if use_compile: + _to_torch(model, reverse=True) + monkey_patch_vllm_all_gather(reverse=True) + tp_group.ca_comm = backup_ca_comm class CudaGraphRunner: @@ -86,17 +101,21 @@ def capture(self, batch_size_list): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream for bs in batch_size_list: - forward = get_forward(self.model_runner.model, bs in self.compile_bs) - ( - graph, - input_buffers, - output_buffers, - flashinfer_handler, - ) = self.capture_one_batch_size(bs, forward) - self.graphs[bs] = graph - self.input_buffers[bs] = input_buffers - self.output_buffers[bs] = output_buffers - self.flashinfer_handlers[bs] = flashinfer_handler + with patch_model( + self.model_runner.model, + bs in self.compile_bs, + self.model_runner.tp_group, + ) as forward: + ( + graph, + input_buffers, + output_buffers, + flashinfer_handler, + ) = self.capture_one_batch_size(bs, forward) + self.graphs[bs] = graph + self.input_buffers[bs] = input_buffers + self.output_buffers[bs] = output_buffers + self.flashinfer_handlers[bs] = flashinfer_handler def capture_one_batch_size(self, bs, forward): graph = torch.cuda.CUDAGraph() diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index 9fd0f19a3da..b5a7c06163c 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -22,7 +22,6 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config @@ -241,7 +240,9 @@ def init_cuda_graphs(self): self.cuda_graph_runner = None return - logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.") + logger.info( + f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes." + ) batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)] self.cuda_graph_runner = CudaGraphRunner( self, @@ -252,7 +253,7 @@ def init_cuda_graphs(self): self.cuda_graph_runner.capture(batch_size_list) except RuntimeError as e: raise Exception( - f"Capture cuda graph failed {e}. Possible solutions:\n" + f"Capture cuda graph failed: {e}. Possible solutions:\n" f"1. disable cuda graph by --disable-cuda-graph\n" f"2. set --mem-fraction-static to a smaller value\n" f"Open an issue on GitHub with reproducible scripts if you need help.\n" diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py new file mode 100644 index 00000000000..c266a2ea4f1 --- /dev/null +++ b/python/sglang/srt/models/deepseek.py @@ -0,0 +1,430 @@ +# Adapted from: +# https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py +"""Inference-only Deepseek model.""" +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.config import CacheConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.managers.controller.infer_batch import InputMetadata + + +class DeepseekMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}." + ) + + self.experts = nn.ModuleList( + [ + DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + for idx in range(self.n_routed_experts) + ] + ) + self.pack_params() + + self.gate = ReplicatedLinear( + config.hidden_size, self.n_routed_experts, bias=False, quant_config=None + ) + + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe( + hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True, + ) + + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class DeepseekAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = DeepseekAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + self.mlp = DeepseekMoE(config=config, quant_config=quant_config) + else: + self.mlp = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + DeepseekDecoderLayer( + config, layer_id, cache_config, quant_config=quant_config + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, input_metadata, residual + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = DeepseekModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, input_metadata) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = DeepseekForCausalLM diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 3c6a79e305c..b3e0aea583c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -167,7 +167,7 @@ def _set_torch_compile_config(): torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future # FIXME: tmp workaround - torch._dynamo.config.accumulated_cache_size_limit = 128 + torch._dynamo.config.accumulated_cache_size_limit = 256 def launch_server( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index dfaf51a9356..8aaf5c3fbf9 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -411,6 +411,52 @@ def load_model( setattr(DummyModelLoader, "load_model", load_model) +vllm_all_gather_backup = None + + +def monkey_patch_vllm_all_gather(reverse: bool = False): + """Monkey patch all-gather to remove in-place operations.""" + from torch.distributed import _functional_collectives as funcol + from vllm.distributed.parallel_state import GroupCoordinator + + global vllm_all_gather_backup + if vllm_all_gather_backup is None: + vllm_all_gather_backup = GroupCoordinator.all_gather + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty( + (world_size,) + input_size, dtype=input_.dtype, device=input_.device + ) + + output_tensor = funcol.all_gather_tensor( + input_, gather_dim=0, group=self.device_group + ).view((world_size,) + input_size) + + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] + ) + return output_tensor + + if reverse: + setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup) + else: + setattr(GroupCoordinator, "all_gather", all_gather) + + API_KEY_HEADER_NAME = "X-API-Key"