From 63c54c4c91a6dfb11d099e6b896832e5e18d2b1b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 21:57:51 +0800 Subject: [PATCH 01/20] add env vars Signed-off-by: youkaichao --- vllm/envs.py | 13 +++++++++++++ vllm/executor/ray_distributed_executor.py | 17 ++++++++++------- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 5018f6deb7f4f..3cd734b06ef4b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -83,6 +83,8 @@ VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True VLLM_MLA_DISABLE_REQUANTIZATION: bool = False VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False + VLLM_RAY_PER_WORKER_GPUS: float = 1.0 + VLLM_RAY_BUNDLE_INDICES: str = "" def get_default_cache_root(): @@ -539,6 +541,17 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON": lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) ), + + # Number of GPUs per worker in Ray, if it is set to be a fraction, + # it allows ray to schedule multiple actors on a single GPU. + "VLLM_RAY_PER_WORKER_GPUS": + lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")), + + # Bundle indices for Ray, if it is set, it can control precisely + # which indices are used for the Ray bundle, for every worker. + # Format: comma-separated list of integers, e.g. "0,1,2,3" + "VLLM_RAY_BUNDLE_INDICES": + lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), } # end-env-vars-definition diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 80e7a1c405f9f..8b799e4094dd4 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -129,13 +129,7 @@ def _get_env_vars_to_be_updated(self): def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): - if (self.parallel_config.tensor_parallel_size == 1 - and self.parallel_config.pipeline_parallel_size == 1): - # For single GPU case, we use a ray worker with constrained memory. - num_gpus = self.cache_config.gpu_memory_utilization - else: - # Otherwise, the ray workers are allocated with a full GPU. - num_gpus = 1 + num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. @@ -157,10 +151,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. 
driver_ip = get_ip() rank = 0 + bundle_indices: Optional[List[int]] = None + if envs.VLLM_RAY_BUNDLE_INDICES: + bundle_indices = list( + map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) + assert len(bundle_indices) == self.parallel_config.world_size, \ + ("VLLM_RAY_BUNDLE_INDICES must have the same length" + " as the world size.") worker_metadata: List[RayWorkerMetaData] = [] for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get(current_platform.ray_device_key, 0): continue + if bundle_indices is not None: + bundle_id = bundle_indices[rank] scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_capture_child_tasks=True, From 37b44cfb6f52e8e6e0c74979490c2b0b7fdbe29e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:26:56 +0800 Subject: [PATCH 02/20] add examples Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 83 +++++++++++++++++++++ vllm/platforms/cuda.py | 6 ++ vllm/platforms/interface.py | 5 ++ 3 files changed, 94 insertions(+) create mode 100644 examples/offline_inference/ray_placement.py diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py new file mode 100644 index 0000000000000..806225c378db5 --- /dev/null +++ b/examples/offline_inference/ray_placement.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +a simple demonstration to show how to control +the placement of the vLLM workers with Ray. +The key is to set VLLM_RAY_PER_WORKER_GPUS and +VLLM_RAY_BUNDLE_INDICES properly. +""" +import os + +import ray +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from vllm import LLM +from vllm.worker.worker import Worker + + +class MyWorker(Worker): + + def report_device_id(self) -> str: + from vllm.platforms import current_platform + return current_platform.get_device_uuid(self.device.index) + + +class MyLLM(LLM): + + def __init__(self, *args, bundle_indices: list, **kwargs): + # a hack to make the script work. + # stop ray from manipulating CUDA_VISIBLE_DEVICES + # at the top-level + del os.environ["CUDA_VISIBLE_DEVICES"] + # every worker will use 0.4 GPU, so that we can schedule + # 2 instances on the same GPUs. + os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" + os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join( + map(str, bundle_indices)) + super().__init__(*args, **kwargs) + + +# ray manages 4 GPUs +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" +ray.init() + +pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 4) +ray.get(pg_inference.ready()) +scheduling_inference = PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, +) + +llms = [] + +# here we create 4 LLM instances, 2 of them will be scheduled +# on the same GPUs. 
+# GPUs: 0, 1, 2, 3 +# instance 0: GPU 0, 1 +# instance 1: GPU 0, 1 +# instance 2: GPU 2, 3 +# instance 3: GPU 2, 3 +for bundle_indices in [[0, 1], [0, 1], [2, 3], [2, 3]]: + llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=scheduling_inference, + )(MyLLM).remote( + model="facebook/opt-125m", + enforce_eager=True, + worker_cls=MyWorker, + tensor_parallel_size=2, + distributed_executor_backend="ray", + bundle_indices=bundle_indices, + ) + llms.append(llm) + +# check if the device IDs are the same for two instances +device_ids = [] +for llm in llms: + device_ids.append( + ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))) +print(device_ids) + +assert device_ids[0] == device_ids[1] +assert device_ids[2] == device_ids[3] diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b49852a727fa4..01423ca175a06 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -275,6 +275,12 @@ def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) return cls._get_physical_device_name(physical_device_id) + @classmethod + def get_device_uuid(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + return pynvml.nvmlDeviceGetUUID(handle) + @classmethod @lru_cache(maxsize=8) @with_nvml_context diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index dc6545c933de3..211e288b125da 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -183,6 +183,11 @@ def get_device_name(cls, device_id: int = 0) -> str: """Get the name of a device.""" raise NotImplementedError + @classmethod + def get_device_uuid(cls, device_id: int = 0) -> str: + """Get the uuid of a device, e.g. the PCI bus ID.""" + raise NotImplementedError + @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: """Get the total memory of a device in bytes.""" From 84bccc81d75f18347576d65bb6a46be7f86587f3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:34:08 +0800 Subject: [PATCH 03/20] fix ray Signed-off-by: youkaichao --- vllm/executor/ray_distributed_executor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 8b799e4094dd4..f16bc4c6b6be6 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -162,6 +162,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get(current_platform.ray_device_key, 0): continue + if rank >= self.parallel_config.world_size: + # We have created enough workers. 
+ break if bundle_indices is not None: bundle_id = bundle_indices[rank] scheduling_strategy = PlacementGroupSchedulingStrategy( From 97826d968666993f35ada9c18b81a346f7957102 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:37:51 +0800 Subject: [PATCH 04/20] fix example Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 806225c378db5..531e0d8713d92 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -68,6 +68,7 @@ def __init__(self, *args, bundle_indices: list, **kwargs): worker_cls=MyWorker, tensor_parallel_size=2, distributed_executor_backend="ray", + gpu_memory_utilization=0.4, bundle_indices=bundle_indices, ) llms.append(llm) From 982080396619c8db2e45ecb3708ee8956bbceacc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:41:07 +0800 Subject: [PATCH 05/20] print more Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 531e0d8713d92..28c26476aa661 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -78,7 +78,7 @@ def __init__(self, *args, bundle_indices: list, **kwargs): for llm in llms: device_ids.append( ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))) -print(device_ids) +print(f"{device_ids=}") assert device_ids[0] == device_ids[1] assert device_ids[2] == device_ids[3] From 49e72d94797b0bff7ef8ba2f5ac4974042b79124 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:50:34 +0800 Subject: [PATCH 06/20] add to tests Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a847a68a6ef71..08b3b28e36059 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -128,6 +128,7 @@ steps: - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile - examples/offline_inference/rlhf.py + - examples/offline_inference/ray_placement.py commands: - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py @@ -136,6 +137,7 @@ steps: # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - python3 ../examples/offline_inference/rlhf.py + - python3 ../examples/offline_inference/ray_placement.py - label: Metrics, Tracing Test # 10min num_gpus: 2 From 9a3512f46ea55a0d186d7f073c63318ce839559d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:54:12 +0800 Subject: [PATCH 07/20] add more logging Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 28c26476aa661..58d58d47ac2b5 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -39,6 +39,7 @@ def __init__(self, *args, bundle_indices: list, **kwargs): # ray manages 4 GPUs os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" +os.environ["RAY_DEDUP_LOGS"] = "0" ray.init() pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 4) @@ -58,6 +59,7 @@ def 
__init__(self, *args, bundle_indices: list, **kwargs): # instance 2: GPU 2, 3 # instance 3: GPU 2, 3 for bundle_indices in [[0, 1], [0, 1], [2, 3], [2, 3]]: + print(f"creating LLM with bundle_indices={bundle_indices}") llm = ray.remote( num_cpus=0, num_gpus=0, From ac07519c88694a8ca9c6c9b2f01828e89ac6cb5d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:55:14 +0800 Subject: [PATCH 08/20] add more logging Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 58d58d47ac2b5..d33602519a59f 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -44,6 +44,8 @@ def __init__(self, *args, bundle_indices: list, **kwargs): pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 4) ray.get(pg_inference.ready()) +print(f"placement group has bundles {pg_inference.bundle_specs=}") + scheduling_inference = PlacementGroupSchedulingStrategy( placement_group=pg_inference, placement_group_capture_child_tasks=True, From faa7ddae90f6ac01e5ef3c2bc5001628213a4898 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:56:34 +0800 Subject: [PATCH 09/20] add more logging Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index d33602519a59f..77eba7e5aceb1 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -34,6 +34,7 @@ def __init__(self, *args, bundle_indices: list, **kwargs): os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join( map(str, bundle_indices)) + print(f"creating LLM with bundle_indices={bundle_indices}") super().__init__(*args, **kwargs) @@ -61,7 +62,6 @@ def __init__(self, *args, bundle_indices: list, **kwargs): # instance 2: GPU 2, 3 # instance 3: GPU 2, 3 for bundle_indices in [[0, 1], [0, 1], [2, 3], [2, 3]]: - print(f"creating LLM with bundle_indices={bundle_indices}") llm = ray.remote( num_cpus=0, num_gpus=0, From bf85042564a29fe592521c30c026e034e95bf05b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Feb 2025 22:59:11 +0800 Subject: [PATCH 10/20] add more logging Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 2 +- examples/offline_inference/ray_placement.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 08b3b28e36059..7ef40564c5bd2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -137,7 +137,7 @@ steps: # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - python3 ../examples/offline_inference/rlhf.py - - python3 ../examples/offline_inference/ray_placement.py + - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py - label: Metrics, Tracing Test # 10min num_gpus: 2 diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 77eba7e5aceb1..5cdc791b16b25 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -40,7 +40,6 @@ def __init__(self, *args, bundle_indices: list, **kwargs): # ray manages 4 GPUs os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" 
-os.environ["RAY_DEDUP_LOGS"] = "0" ray.init() pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 4) From 31ab75cd228f8a57bec3cd076be20bfc12d578dc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 00:27:15 +0800 Subject: [PATCH 11/20] update tests Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 75 ++++++++++++++------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 5cdc791b16b25..5493303c30c8c 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -38,33 +38,55 @@ def __init__(self, *args, bundle_indices: list, **kwargs): super().__init__(*args, **kwargs) +class RayTrainingActor: + + def report_device_id(self) -> str: + from vllm.platforms import current_platform + return current_platform.get_device_uuid(0) + + # ray manages 4 GPUs os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" ray.init() +# we want to co-locate vLLM instance and the training actor +# on the same set of GPUs. +# the placement plan is as follows: +# GPU 0 and 1: training actor 0, 1, and vLLM instance 0 (with TP=2) +# GPU 2 and 3: training actor 2, 3, and vLLM instance 1 (with TP=2) + pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 4) ray.get(pg_inference.ready()) print(f"placement group has bundles {pg_inference.bundle_specs=}") -scheduling_inference = PlacementGroupSchedulingStrategy( - placement_group=pg_inference, - placement_group_capture_child_tasks=True, -) - -llms = [] - -# here we create 4 LLM instances, 2 of them will be scheduled -# on the same GPUs. -# GPUs: 0, 1, 2, 3 -# instance 0: GPU 0, 1 -# instance 1: GPU 0, 1 -# instance 2: GPU 2, 3 -# instance 3: GPU 2, 3 -for bundle_indices in [[0, 1], [0, 1], [2, 3], [2, 3]]: +training_actors = [] +training_actor_device_ids = [] +inference_engines = [] +inference_engine_device_ids = [] + +for bundle_index in [0, 1, 2, 3]: + training_actor = ray.remote( + num_cpus=0, + num_gpus=0.4, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_index, + ), + )(RayTrainingActor).remote() + training_actors.append(training_actor) + device_id = ray.get(training_actor.report_device_id.remote()) + print(f"training actor {bundle_index} is on {device_id}") + training_actor_device_ids.append(device_id) + +for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]): llm = ray.remote( num_cpus=0, num_gpus=0, - scheduling_strategy=scheduling_inference, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg_inference, + placement_group_capture_child_tasks=True, + ), )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, @@ -74,14 +96,15 @@ def __init__(self, *args, bundle_indices: list, **kwargs): gpu_memory_utilization=0.4, bundle_indices=bundle_indices, ) - llms.append(llm) - -# check if the device IDs are the same for two instances -device_ids = [] -for llm in llms: - device_ids.append( + inference_engines.append(llm) + inference_engine_device_ids.append( ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))) -print(f"{device_ids=}") - -assert device_ids[0] == device_ids[1] -assert device_ids[2] == device_ids[3] + print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") + +# check the placement +# the first two training actors should be +# on the same GPUs as the first inference engine +assert 
training_actor_device_ids[:2] == inference_engine_device_ids[0] +# the last two training actors should be +# on the same GPUs as the second inference engine +assert training_actor_device_ids[2:] == inference_engine_device_ids[1] From 86fa368474ec476214c160506c461044516e3427 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 00:29:46 +0800 Subject: [PATCH 12/20] add comments Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 5493303c30c8c..990241aa74838 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -80,6 +80,10 @@ def report_device_id(self) -> str: training_actor_device_ids.append(device_id) for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]): + # IMPORTANT: when creating vLLM instances, we need to + # make sure there are no GPU activities on the target GPUs, + # otherwise, they will interfere with the vLLM memory profiling, + # and cause unexpected behaviors. llm = ray.remote( num_cpus=0, num_gpus=0, From bb33f83dcce18cb6ecf1e259617fe215cc5448e7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 00:40:32 +0800 Subject: [PATCH 13/20] add decorator Signed-off-by: youkaichao --- vllm/platforms/cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 01423ca175a06..991d55ac861a4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -276,6 +276,8 @@ def get_device_name(cls, device_id: int = 0) -> str: return cls._get_physical_device_name(physical_device_id) @classmethod + @lru_cache(maxsize=8) + @with_nvml_context def get_device_uuid(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) From 517c162cbb4810cce21b5ced935820cf3c2c1822 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 00:43:10 +0800 Subject: [PATCH 14/20] add comments Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 990241aa74838..836ac47150ebb 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -41,6 +41,9 @@ def __init__(self, *args, bundle_indices: list, **kwargs): class RayTrainingActor: def report_device_id(self) -> str: + # the argument for get_device_uuid is the index + # of the GPU in the visible devices. 
+ # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs from vllm.platforms import current_platform return current_platform.get_device_uuid(0) From eab93049c4aaf7b0ed98b475da237ce2cbceb19a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 00:47:41 +0800 Subject: [PATCH 15/20] add comments Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 836ac47150ebb..67dec153b7b21 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -104,6 +104,10 @@ def report_device_id(self) -> str: bundle_indices=bundle_indices, ) inference_engines.append(llm) + # don't call any method on the inference engine here, + # otherwise it will block until the vLLM instance is created. + +for i, llm in enumerate(inference_engines): inference_engine_device_ids.append( ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))) print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") From 1cec2f819d0d8d8ed69c2d7e0307196c80b1d8f2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 21:02:26 +0800 Subject: [PATCH 16/20] rename Signed-off-by: youkaichao --- examples/offline_inference/ray_placement.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/ray_placement.py index 67dec153b7b21..cd801a3c0c858 100644 --- a/examples/offline_inference/ray_placement.py +++ b/examples/offline_inference/ray_placement.py @@ -58,9 +58,9 @@ def report_device_id(self) -> str: # GPU 0 and 1: training actor 0, 1, and vLLM instance 0 (with TP=2) # GPU 2 and 3: training actor 2, 3, and vLLM instance 1 (with TP=2) -pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 4) -ray.get(pg_inference.ready()) -print(f"placement group has bundles {pg_inference.bundle_specs=}") +pg = placement_group([{"GPU": 1, "CPU": 0}] * 4) +ray.get(pg.ready()) +print(f"placement group has bundles {pg.bundle_specs=}") training_actors = [] training_actor_device_ids = [] @@ -72,7 +72,7 @@ def report_device_id(self) -> str: num_cpus=0, num_gpus=0.4, scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg_inference, + placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=bundle_index, ), @@ -91,7 +91,7 @@ def report_device_id(self) -> str: num_cpus=0, num_gpus=0, scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg_inference, + placement_group=pg, placement_group_capture_child_tasks=True, ), )(MyLLM).remote( From 98abef38be79ee63d93b0c0fa9f0fe15d6719965 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 21:06:27 +0800 Subject: [PATCH 17/20] add asserts Signed-off-by: youkaichao --- vllm/executor/ray_distributed_executor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index f16bc4c6b6be6..82fc5caab79ec 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -157,7 +157,11 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) assert len(bundle_indices) == self.parallel_config.world_size, \ ("VLLM_RAY_BUNDLE_INDICES must have the same length" - " as the world size.") + f" as the world size, but 
got {bundle_indices=} " + f"and {self.parallel_config.world_size=}") + assert len(set(bundle_indices)) == len(bundle_indices), \ + ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," + f" but got {bundle_indices=}") worker_metadata: List[RayWorkerMetaData] = [] for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get(current_platform.ray_device_key, 0): From 8957091d84ab366cb6e71d33d32e384d46c318ed Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 21:06:49 +0800 Subject: [PATCH 18/20] add asserts Signed-off-by: youkaichao --- vllm/executor/ray_distributed_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 82fc5caab79ec..6bfd5076a5032 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -156,7 +156,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", bundle_indices = list( map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) assert len(bundle_indices) == self.parallel_config.world_size, \ - ("VLLM_RAY_BUNDLE_INDICES must have the same length" + ("VLLM_RAY_BUNDLE_INDICES must have the same size" f" as the world size, but got {bundle_indices=} " f"and {self.parallel_config.world_size=}") assert len(set(bundle_indices)) == len(bundle_indices), \ From 5cf14b761b76d2aad71e927c7c992321001bcc5c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 5 Feb 2025 21:08:52 +0800 Subject: [PATCH 19/20] add comments Signed-off-by: youkaichao --- vllm/envs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index b1fcdec6e25b8..745b068b7a458 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -553,7 +553,8 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: ), # Number of GPUs per worker in Ray, if it is set to be a fraction, - # it allows ray to schedule multiple actors on a single GPU. + # it allows ray to schedule multiple actors on a single GPU, + # so that users can colocate other actors on the same GPUs as vLLM. "VLLM_RAY_PER_WORKER_GPUS": lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")), From 91bb1465f6a3a04f2e29ae7f8b1deea5371e9716 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Feb 2025 00:55:53 +0800 Subject: [PATCH 20/20] unify bundle_indices Signed-off-by: youkaichao --- vllm/executor/ray_distributed_executor.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 6bfd5076a5032..6a25a4d50fb98 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -149,10 +149,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) # Create the workers. - driver_ip = get_ip() - rank = 0 - bundle_indices: Optional[List[int]] = None + bundle_indices: List[int] if envs.VLLM_RAY_BUNDLE_INDICES: + # Use the bundle indices specified by the user. 
bundle_indices = list( map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) assert len(bundle_indices) == self.parallel_config.world_size, \ @@ -162,15 +161,17 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", assert len(set(bundle_indices)) == len(bundle_indices), \ ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," f" but got {bundle_indices=}") + else: + # use the first N bundles that have GPU resources. + bundle_indices = [] + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if bundle.get(current_platform.ray_device_key, 0): + bundle_indices.append(bundle_id) + bundle_indices = bundle_indices[:self.parallel_config.world_size] + worker_metadata: List[RayWorkerMetaData] = [] - for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get(current_platform.ray_device_key, 0): - continue - if rank >= self.parallel_config.world_size: - # We have created enough workers. - break - if bundle_indices is not None: - bundle_id = bundle_indices[rank] + driver_ip = get_ip() + for rank, bundle_id in enumerate(bundle_indices): scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_capture_child_tasks=True, @@ -197,7 +198,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", rpc_rank=rank) worker_metadata.append( RayWorkerMetaData(worker=worker, created_rank=rank)) - rank += 1 worker_ips = ray.get([ each.worker.get_node_ip.remote() # type: ignore[attr-defined]
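
The final patch above unifies worker creation around a single list of bundle indices: either the user-specified VLLM_RAY_BUNDLE_INDICES or, by default, the first world_size bundles that carry a GPU resource. A minimal standalone sketch of that selection rule follows (not part of the series; bundle_specs is a plain list of dicts here and the literal "GPU" key stands in for current_platform.ray_device_key):

import os
from typing import Dict, List


def select_bundle_indices(bundle_specs: List[Dict[str, float]],
                          world_size: int) -> List[int]:
    raw = os.getenv("VLLM_RAY_BUNDLE_INDICES", "")
    if raw:
        # User-specified indices must be unique and match the world size.
        indices = [int(x) for x in raw.split(",")]
        assert len(indices) == world_size
        assert len(set(indices)) == len(indices)
        return indices
    # Default: the first `world_size` bundles that have a GPU resource.
    gpu_bundles = [i for i, b in enumerate(bundle_specs) if b.get("GPU", 0)]
    return gpu_bundles[:world_size]


# Example: 4 single-GPU bundles, TP=2, engine pinned to bundles 2 and 3.
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "2,3"
print(select_bundle_indices([{"GPU": 1}] * 4, world_size=2))  # -> [2, 3]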