[Hardware][Intel GPU] Add intel GPU pipeline parallel support. (vllm-…
jikunshang authored and triple-Mu committed Sep 4, 2024
1 parent 546cf34 commit b166414
Showing 6 changed files with 82 additions and 19 deletions.
5 changes: 5 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -666,6 +666,11 @@ def _get_executor_cls(
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
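For orientation, a minimal sketch of how the new branch above could be exercised, assuming an XPU build of vLLM with Ray available; the model name and parallel sizes are placeholders and not part of this commit:

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Placeholder model and sizes; with device="xpu" and the "mp" backend,
# _get_executor_cls (patched above) should resolve to the new async
# multiprocessing XPU executor.
engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                              device="xpu",
                              distributed_executor_backend="mp",
                              pipeline_parallel_size=2)
executor_cls = AsyncLLMEngine._get_executor_cls(
    engine_args.create_engine_config())
print(executor_cls.__name__)  # expected: MultiprocessingXPUExecutorAsync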
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
@@ -472,6 +472,13 @@ def _get_executor_cls(cls,
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutor
executor_class = RayXPUExecutor
elif distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn requires the entry point to be guarded by
# `if __name__ == '__main__':`, and fork is not supported for
# starting new processes on XPU.
logger.error(
"Both start methods (spawn and fork) have issues "
"on XPU when using the mp backend. Please try ray instead.")
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
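Because the synchronous engine only logs an error for "mp" on XPU, the multi-device path it actually supports here is the ray backend. A hedged usage sketch, assuming Ray and an XPU build are installed; the model name and tensor_parallel_size are placeholders:

from vllm import LLM

# Placeholder model and size; requesting "mp" here would only trigger the
# logger.error above, so the synchronous engine uses ray for multi-XPU runs.
llm = LLM(model="facebook/opt-125m",
          device="xpu",
          distributed_executor_backend="ray",
          tensor_parallel_size=2)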
38 changes: 22 additions & 16 deletions vllm/executor/multiproc_gpu_executor.py
@@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
uses_ray: bool = False

def _init_executor(self) -> None:
self._check_executor_parameters()

# Create the parallel GPU workers.
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

@@ -68,16 +64,6 @@ def _init_executor(self) -> None:
if world_size > 1:
maybe_set_triton_cache_manager()

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")

assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")

# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
@@ -139,6 +125,26 @@ def shutdown(signum, frame):
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

def _check_executor_parameters(self):
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")

assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")

def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
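The net effect of this hunk is a template-method hook: the CUDA-specific environment and device-count checks move out of _init_executor into an overridable _check_executor_parameters, so device-specific subclasses can substitute their own validation. A toy sketch of the pattern, with purely illustrative class names:

class ToyGPUExecutor:

    def _init_executor(self) -> None:
        # Device-specific validation is now a hook subclasses may override.
        self._check_executor_parameters()
        # ... shared worker setup would continue here ...

    def _check_executor_parameters(self) -> None:
        print("default: compare world_size against the local CUDA device count")


class ToyXPUExecutor(ToyGPUExecutor):

    def _check_executor_parameters(self) -> None:
        print("XPU: require the 'spawn' multiprocessing start method")


ToyXPUExecutor()._init_executor()  # runs only the XPU-specific check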
26 changes: 26 additions & 0 deletions vllm/executor/multiproc_xpu_executor.py
@@ -0,0 +1,26 @@
import vllm.envs as envs
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.utils import make_async

logger = init_logger(__name__)


class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
"""Python multiprocessing-based multi-XPU executor"""

def _check_executor_parameters(self):
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
if mp_method != "spawn":
raise RuntimeError(
"XPU multiprocess executor only support spawn as mp method")


class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
MultiprocessingGPUExecutorAsync):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
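Since this executor rejects anything but spawn, and spawn re-imports the launching module in each worker, a driver script has to keep engine construction under the __main__ guard (the same point as the FIXME in llm_engine.py). A hypothetical launch sketch; the model name and parallel sizes are placeholders:

import os

# The executor checks envs.VLLM_WORKER_MULTIPROC_METHOD == "spawn".
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def main() -> None:
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        device="xpu",
                        distributed_executor_backend="mp",
                        pipeline_parallel_size=2))
    # ... submit generate() requests here ...


if __name__ == "__main__":  # required: spawn re-imports this module in workers
    main()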
19 changes: 16 additions & 3 deletions vllm/worker/xpu_model_runner.py
@@ -12,6 +12,7 @@
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
@@ -439,9 +440,11 @@ def profile_run(self) -> None:
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1

batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len

seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
@@ -465,7 +468,13 @@ def profile_run(self) -> None:
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
self.execute_model(model_input, kv_caches)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.xpu.synchronize()
return
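Why profile_run changes: on every pipeline stage except the first, execute_model consumes the previous stage's hidden states instead of token ids, so the dummy profiling batch needs an empty IntermediateTensors sized to the accumulated batch_size. A standalone sketch of that gating, with a hypothetical model argument standing in for self.model:

import torch

from vllm.distributed import get_pp_group


def make_profiling_intermediates(model, batch_size: int, dtype: torch.dtype,
                                 device: torch.device):
    """Return empty intermediate tensors for non-first pipeline ranks."""
    if get_pp_group().is_first_rank:
        # The first stage consumes real token ids, so nothing extra is needed.
        return None
    # Later stages consume the previous stage's activations; for memory
    # profiling an empty, correctly shaped placeholder is enough.
    return model.make_empty_intermediate_tensors(batch_size=batch_size,
                                                 dtype=dtype,
                                                 device=device)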

@@ -537,20 +546,24 @@ def execute_model(
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()

hidden_states = model_executable(
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device))
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end_time = time.time()

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)

# Only perform sampling in the driver worker.
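The execute_model change mirrors the standard pipeline-parallel split: only the last rank computes logits and samples, while earlier ranks hand their activations to the next stage. A schematic of that control flow; the function and argument names are illustrative and do not match the exact vLLM signatures:

from vllm.distributed import get_pp_group


def pp_model_step(model_executable, model_inputs, intermediate_tensors,
                  compute_logits, sample):
    hidden_or_intermediate_states = model_executable(
        intermediate_tensors=intermediate_tensors, **model_inputs)
    if not get_pp_group().is_last_rank:
        # Intermediate stages return their activations for the next stage
        # and never touch logits or sampling.
        return hidden_or_intermediate_states
    logits = compute_logits(hidden_or_intermediate_states)
    return sample(logits)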
6 changes: 6 additions & 0 deletions vllm/worker/xpu_worker.py
@@ -14,6 +14,7 @@
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import is_xpu
@@ -198,3 +199,8 @@ def init_worker_distributed_environment(self) -> None:
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)

if parallel_config.pipeline_parallel_size > 1:
# torch-ccl on XPU needs a collective API warm-up
# before the send/recv APIs can be called.
get_pp_group().all_reduce(torch.zeros(1).xpu())
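The warm-up works around a torch-ccl (oneCCL) quirk: point-to-point send/recv between pipeline stages can fail unless a collective has already initialized the communicator. The same idea written with plain torch.distributed calls, assuming the process group has already been initialized with the ccl backend:

import torch
import torch.distributed as dist


def warm_up_before_p2p(device: torch.device) -> None:
    # A throwaway all_reduce establishes the communicator state that
    # torch-ccl needs before the first send/recv between pipeline stages.
    dist.all_reduce(torch.zeros(1, device=device))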
