From dcfc90ba965d69cb29b800b035fd87020d71d2f6 Mon Sep 17 00:00:00 2001 From: zhaoyang-star Date: Mon, 29 Jan 2024 14:47:39 +0800 Subject: [PATCH] Fix error when tp > 1 (#2644) Co-authored-by: zhaoyang-star --- vllm/engine/llm_engine.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 171a9081644ee..2539395df161c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -237,7 +237,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) - cache_config = copy.deepcopy(self.cache_config) for rank, (worker, (node_id, _)) in enumerate(zip(self.workers, @@ -253,7 +252,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", rank, distributed_init_method, lora_config=self.lora_config, - cache_config=cache_config, + kv_cache_dtype=self.cache_config.cache_dtype, )) driver_rank = 0 @@ -266,7 +265,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", driver_rank, distributed_init_method, lora_config=self.lora_config, - cache_config=cache_config, + kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, )