[Core] Concurrently Poll Ray Driver and Worker Results to Avoid Distributed Init Deadlock #7159

Closed
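In short, the change makes _run_workers poll the in-process driver call and the remote Ray worker futures at the same time, so an exception raised by a Ray worker is observed even while the driver is still blocked (for example inside distributed initialization). Below is a minimal standalone sketch of that polling pattern, not the PR code itself; run_driver and worker_refs are illustrative placeholders, not vLLM APIs.

from concurrent import futures

import ray


def poll_driver_and_workers(run_driver, worker_refs):
    """Run the local driver call and the wait on the remote worker refs in
    two threads and return [driver_result, *worker_results]."""
    with futures.ThreadPoolExecutor(max_workers=2) as executor:
        driver_future = executor.submit(run_driver)
        workers_future = executor.submit(ray.get, worker_refs)

        results = {}
        for fut in futures.as_completed([driver_future, workers_future]):
            # .result() re-raises any exception from the underlying call,
            # so a failed worker is noticed even while the driver call is
            # still outstanding.
            key = "driver" if fut is driver_future else "workers"
            results[key] = fut.result()

    return [results["driver"]] + list(results["workers"])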
63 changes: 46 additions & 17 deletions vllm/executor/ray_gpu_executor.py
@@ -1,6 +1,7 @@
import asyncio
import os
from collections import defaultdict
from concurrent import futures
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

@@ -239,7 +240,10 @@ def sort_by_driver_then_worker_ip(worker):
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

self._run_workers("init_device")
# must run driver in background thread if len(workers) > 0 to avoid
# distributed init deadlock.
# (https://github.com/vllm-project/vllm/pull/7159)
self._run_workers("init_device", run_driver_in_background_thread=True)
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
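The comment added above points at the failure mode tracked in vllm-project/vllm#3455: init_device ends up in torch.distributed initialization, which blocks until every rank has joined, so a Ray worker that crashes before joining leaves the driver stuck while the worker's exception sits uncollected in the Ray futures. A rough, hypothetical illustration of the blocking call (the backend, address, and port here are placeholders, not vLLM's actual configuration):

import torch.distributed as dist


def init_device_on_rank(rank: int, world_size: int) -> None:
    # Blocks until all `world_size` ranks have called this (or a long
    # default timeout expires); if one rank already died, the surviving
    # ranks simply sit here.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=rank,
                            world_size=world_size)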
@@ -309,6 +313,7 @@ def _run_workers(
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
run_driver_in_background_thread: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
@@ -358,33 +363,57 @@
# Just return futures
return ray_worker_outputs

driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

# Start the driver worker after all the ray workers.
# Start the driver worker task after all the ray worker tasks.
if not use_dummy_driver:
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
# Driver task will run in this python process
if run_driver_in_background_thread and ray_worker_outputs:
# Poll driver and worker tasks concurrently in background
# threads.
#
# This can avoid a deadlock where the driver task is blocked on some
# out-of-band communication (e.g. torch.distributed init) that is
# invalidated by a Ray worker exception.
#
# See: https://github.com/vllm-project/vllm/issues/3455

with futures.ThreadPoolExecutor(max_workers=2) as executor:
driver_poll_thread = executor.submit(
self.driver_worker.execute_method, method,
*driver_args, **driver_kwargs)
worker_poll_thread = executor.submit(
ray.get, ray_worker_outputs)

for completed_future in futures.as_completed(
[driver_poll_thread, worker_poll_thread]):
# Will raise exception if underlying thread raises
res = completed_future.result()
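# ray.get on the list of worker refs returns a list, while the driver's
# execute_method is assumed to return a single non-list value; that is
# how the two completed futures are distinguished below.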
if not isinstance(res, list):
driver_output = [res]
else:
worker_outputs = res
all_worker_outputs = driver_output + worker_outputs
else:
driver_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
all_worker_outputs = [driver_output
] + ray.get(ray_worker_outputs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = [
ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
]

# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
driver_output = self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs)
all_worker_outputs = ray.get([driver_output] +
ray_worker_outputs)
else:
all_worker_outputs = ray.get(ray_worker_outputs)

return driver_worker_output + ray_worker_outputs
return all_worker_outputs

def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
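A self-contained toy, with no Ray dependency and hypothetical slow_driver / failing_worker functions, showing the property the new code path relies on: futures.as_completed yields whichever task finishes first, and .result() re-raises its exception, so a fast failure is not hidden behind a slow or blocked peer.

import time
from concurrent import futures


def slow_driver():
    # Stands in for a driver call blocked in a collective init.
    time.sleep(5)
    return "driver done"


def failing_worker():
    raise RuntimeError("worker failed during init")


with futures.ThreadPoolExecutor(max_workers=2) as pool:
    tasks = [pool.submit(slow_driver), pool.submit(failing_worker)]
    for fut in futures.as_completed(tasks):
        # The worker's RuntimeError is re-raised here right away,
        # before slow_driver has finished sleeping.
        fut.result()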