diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst
index 5f14fd2b0ee0a..fcb2646df50d3 100644
--- a/docs/source/serving/distributed_serving.rst
+++ b/docs/source/serving/distributed_serving.rst
@@ -50,7 +50,7 @@ You can also additionally specify :code:`--pipeline-parallel-size` to enable pip
     $ --pipeline-parallel-size 2
 
 .. note::
-    Pipeline parallel is a beta feature. It is only supported for online serving as well as LLaMa, GPT2, and Mixtral style models.
+    Pipeline parallel is a beta feature. It is only supported for online serving as well as LLaMa, GPT2, Mixtral, Qwen, Qwen2, and Nemotron style models.
 
 Multi-Node Inference and Serving
 --------------------------------
diff --git a/vllm/config.py b/vllm/config.py
index e065744592378..ef56e2b6395be 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -42,6 +42,7 @@
     "NemotronForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
+    "QWenLMHeadModel",
 ]
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 47c85c783db7a..eb61adf34e9a7 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -12,7 +12,7 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -30,6 +30,8 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import print_warning_once
 
+from .utils import is_pp_missing_parameter, make_layers
+
 
 class QWenMLP(nn.Module):
 
@@ -186,6 +188,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -195,10 +198,10 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.h = nn.ModuleList([
-            QWenBlock(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: QWenBlock(config, cache_config, quant_config),
+            prefix=f"{prefix}.h")
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
     def forward(
@@ -207,18 +210,29 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.wte(input_ids)
-        residual = None
-        for i in range(len(self.h)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.wte(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.h[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.ln_f(hidden_states, residual)
         return hidden_states
 
@@ -250,9 +264,23 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+                                         attn_metadata, intermediate_tensors)
         return hidden_states
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,
@@ -284,6 +312,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -301,6 +332,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         "Only text inputs are allowed. Images won't be handled "
                         "until Qwen-VL models are fully supported.")
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
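
Note: the pipeline-parallel forward introduced above follows the same shape as the other PP-enabled models in vLLM: the first rank computes the token embedding, later ranks receive hidden_states/residual through IntermediateTensors, each rank runs only its [start_layer, end_layer) slice of self.h, and only the last rank applies ln_f. Below is a minimal, self-contained sketch of that control flow in plain PyTorch; ToyBlock, stage_layers, and run_stage are illustrative stand-ins invented for this sketch, not vLLM APIs (vLLM's actual helpers are make_layers, get_pp_group, and IntermediateTensors as used in the diff).

    # Illustrative sketch only: mimics the stage-aware forward added to QWenModel.
    from typing import Optional

    import torch
    from torch import nn


    class ToyBlock(nn.Module):
        """Stand-in for QWenBlock: a linear layer that threads a residual through."""

        def __init__(self, hidden: int):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor]):
            if residual is None:
                residual = x
            return self.proj(x), residual


    def stage_layers(num_layers: int, pp_rank: int, pp_size: int):
        """Contiguous [start, end) slice of layers owned by one pipeline rank."""
        per_rank = num_layers // pp_size
        start = pp_rank * per_rank
        end = num_layers if pp_rank == pp_size - 1 else start + per_rank
        return start, end


    def run_stage(blocks, x, residual, *, is_first: bool, is_last: bool):
        """Mirror of the diff: start from the embedding-like input on the first
        rank, hand hidden_states/residual downstream on middle ranks, and only
        finalize on the last rank."""
        if is_first:
            residual = None
        for block in blocks:                     # only this rank's layer slice
            x, residual = block(x, residual)
        if not is_last:
            return {"hidden_states": x, "residual": residual}
        return x + residual                      # toy stand-in for ln_f


    # Two-stage toy pipeline over 4 layers, hidden size 8, batch of 3 tokens.
    hidden, num_layers, pp_size = 8, 4, 2
    blocks = nn.ModuleList(ToyBlock(hidden) for _ in range(num_layers))
    x = torch.randn(3, hidden)

    s0, e0 = stage_layers(num_layers, 0, pp_size)
    sent = run_stage(blocks[s0:e0], x, None, is_first=True, is_last=False)

    s1, e1 = stage_layers(num_layers, 1, pp_size)
    out = run_stage(blocks[s1:e1], sent["hidden_states"], sent["residual"],
                    is_first=False, is_last=True)
    print(out.shape)  # torch.Size([3, 8])

In the real model the dict returned by the middle stage corresponds to the IntermediateTensors sent between ranks, and make_empty_intermediate_tensors preallocates those buffers on the receiving rank.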