From 2fb0766450010932436a62ec4c252735ecc8be75 Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Tue, 23 Jul 2024 13:22:09 -0600
Subject: [PATCH] [Model] Pipeline Parallel Support for DeepSeek v2 (#6519)

Signed-off-by: Travis Johnson
---
 vllm/config.py                            |   1 +
 vllm/model_executor/models/deepseek_v2.py | 153 ++++++++++++++++------
 2 files changed, 115 insertions(+), 39 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index c27d26c098b59..6e0283f8379a2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -31,6 +31,7 @@
 _PP_SUPPORTED_MODELS = [
     "AquilaModel",
     "AquilaForCausalLM",
+    "DeepseekV2ForCausalLM",
     "InternLMForCausalLM",
     "LlamaForCausalLM",
     "LLaMAForCausalLM",
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 2d12ceb7f3dbf..2e3e9b6f2792e 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -29,7 +29,8 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -49,6 +50,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
+
 
 class DeepseekV2MLP(nn.Module):
 
@@ -59,17 +62,20 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
-                                           reduce_results=reduce_results)
+                                           reduce_results=reduce_results,
+                                           prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
" "Only silu is supported for now.") @@ -88,6 +94,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -112,12 +119,14 @@ def __init__( quant_config=quant_config, use_grouped_topk=True, num_expert_group=config.n_group, - topk_group=config.topk_group) + topk_group=config.topk_group, + prefix=f"{prefix}.experts") self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, bias=False, - quant_config=None) + quant_config=None, + prefix=f"{prefix}.gate") if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -172,10 +181,9 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - layer_idx=None, + prefix: str = "", ) -> None: super().__init__() - self.layer_idx = layer_idx self.hidden_size = hidden_size self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim @@ -195,38 +203,44 @@ def __init__( self.q_a_proj = ReplicatedLinear(self.hidden_size, self.q_lora_rank, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear(q_lora_rank, self.num_heads * self.qk_head_dim, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") else: self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.q_proj") - self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, - self.kv_lora_rank + - self.qk_rope_head_dim, - bias=False, - quant_config=quant_config) + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") # O projection. self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.o_proj") rope_scaling['type'] = 'deepseek_yarn' self.rotary_emb = get_rope(qk_rope_head_dim, rotary_dim=qk_rope_head_dim, @@ -308,7 +322,7 @@ class DeepseekV2DecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, + prefix: str, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -318,6 +332,9 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. 
+        layer_idx = int(prefix.split(sep='.')[-1])
         self.self_attn = DeepseekV2Attention(
             config=config,
             hidden_size=self.hidden_size,
@@ -333,18 +350,23 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
-            layer_idx=layer_idx,
+            prefix=f"{prefix}.self_attn",
         )
         if (config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
                 and layer_idx % config.moe_layer_freq == 0):
-            self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
+            self.mlp = DeepseekV2MoE(
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+            )
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
             )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -389,23 +411,34 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-        self.layers = nn.ModuleList([
-            DeepseekV2DecoderLayer(config,
-                                   layer_idx,
-                                   cache_config=cache_config,
-                                   quant_config=quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: DeepseekV2DecoderLayer(
+                config,
+                prefix,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ),
+            prefix=f"{prefix}.layers")
+
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def forward(
         self,
@@ -413,14 +446,28 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -436,7 +483,10 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, cache_config, quant_config)
+        self.model = DeepseekV2Model(config,
+                                     cache_config,
+                                     quant_config,
+                                     prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
@@ -452,7 +502,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -469,6 +519,20 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -504,6 +568,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -514,6 +582,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -527,6 +599,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if name.endswith(".bias") and name not in params_dict:
                         continue
 
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
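
Note (reviewer sketch, not part of the patch): the heart of this change is that `make_layers` hands each pipeline-parallel rank a contiguous [start_layer, end_layer) slice of the decoder layers, replaces the remaining slots with `PPMissingLayer`, and has non-final ranks return `IntermediateTensors` ("hidden_states"/"residual") to the next rank, which is also why the forward pass indexes `kv_caches[i - self.start_layer]` and why `is_pp_missing_parameter` skips weights owned by other ranks. The standalone Python sketch below only illustrates that partitioning idea; `partition_layers` and `MissingLayerStub` are hypothetical names for this illustration and are not vLLM APIs (the real helpers are `make_layers` and `PPMissingLayer` from `vllm/model_executor/models/utils.py`, as used above).

    # Illustrative sketch only -- not vLLM code.
    import torch.nn as nn


    class MissingLayerStub(nn.Module):
        """Placeholder for a layer owned by another pipeline rank (hypothetical)."""

        def forward(self, *args, **kwargs):
            raise RuntimeError("layer lives on a different pipeline rank")


    def partition_layers(num_layers: int, pp_rank: int, pp_size: int, make_layer):
        """Return (start_layer, end_layer, layers) for one pipeline rank."""
        base = num_layers // pp_size
        extra = num_layers % pp_size
        # Earlier ranks absorb one extra layer each until the remainder is used up.
        start = pp_rank * base + min(pp_rank, extra)
        end = start + base + (1 if pp_rank < extra else 0)
        layers = nn.ModuleList([
            make_layer(prefix=f"model.layers.{i}") if start <= i < end
            else MissingLayerStub()
            for i in range(num_layers)
        ])
        return start, end, layers


    if __name__ == "__main__":
        # e.g. 60 decoder layers over 4 pipeline ranks: rank 1 owns layers [15, 30).
        start, end, layers = partition_layers(
            60, pp_rank=1, pp_size=4, make_layer=lambda prefix: nn.Identity())
        print(start, end, sum(not isinstance(m, MissingLayerStub) for m in layers))

With a scheme like this, each rank allocates KV cache only for its own slice, which is why the local cache list is indexed with `i - self.start_layer` in the patched `DeepseekV2Model.forward`.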