From 435e3b28c9cd23faaa165d50feffaebb31c6f2c5 Mon Sep 17 00:00:00 2001 From: shangmingc Date: Mon, 20 Jan 2025 10:56:43 +0800 Subject: [PATCH] [Bugfix] Fix num_heads value for simple connector when tp enabled (#12074) Signed-off-by: Shangming Cai --- vllm/distributed/kv_transfer/kv_connector/simple_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 4ace03ff1184e..7780e2dfa317d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -35,6 +35,7 @@ def __init__( ): self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size if self.config.kv_connector == "PyNcclConnector": from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( @@ -161,7 +162,7 @@ def send_kv_caches_and_hidden_states( end_layer = model_executable.model.end_layer model_config = model_executable.model.config - num_heads = model_config.num_key_value_heads + num_heads = int(model_config.num_key_value_heads / self.tp_size) hidden_size = model_config.hidden_size num_attention_heads = model_config.num_attention_heads head_size = int(hidden_size / num_attention_heads)