Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hardware][Intel GPU] Add intel GPU pipeline parallel support. #7810

Merged
merged 4 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,11 @@ def _get_executor_cls(
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
else:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
Expand Down
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,13 @@ def _get_executor_cls(cls,
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutor
executor_class = RayXPUExecutor
elif distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead.")
else:
from vllm.executor.xpu_executor import XPUExecutor
executor_class = XPUExecutor
Expand Down
38 changes: 22 additions & 16 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
uses_ray: bool = False

def _init_executor(self) -> None:
self._check_executor_parameters()

# Create the parallel GPU workers.
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

Expand Down Expand Up @@ -68,16 +64,6 @@ def _init_executor(self) -> None:
if world_size > 1:
maybe_set_triton_cache_manager()

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")

assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")

# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
Expand Down Expand Up @@ -139,6 +125,26 @@ def shutdown(signum, frame):
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

def _check_executor_parameters(self):
world_size = self.parallel_config.tensor_parallel_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size

# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
update_environment_variables({
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})

cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")

assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")

def shutdown(self):
if (worker_monitor := getattr(self, "worker_monitor",
None)) is not None:
Expand Down
26 changes: 26 additions & 0 deletions vllm/executor/multiproc_xpu_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import vllm.envs as envs
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
from vllm.executor.xpu_executor import XPUExecutor
from vllm.logger import init_logger
from vllm.utils import make_async

logger = init_logger(__name__)


class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
"""Python multiprocessing-based multi-XPU executor"""

def _check_executor_parameters(self):
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
if mp_method != "spawn":
raise RuntimeError(
"XPU multiprocess executor only support spawn as mp method")
Comment on lines +15 to +18
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this the case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you initialize the gpu somewhere? usually this needs to be avoided, and should already be avoided in vllm.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"spawn" will not work when users run LLM class directly, without if __name__ == "__main__"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when you import intel_extension_for_pytorch, it will call xpu initialization implicitly. and will fall into native runtime. I guess it will detect whether the process is started via fork or spawn.
What do you mean "spawn" will not work when users run LLM class directly here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see #5637 for example.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, got your point. I just tried with offline_inference.py with spawn + mp backend, it will throw same error in this issue. While this works fine with api_server(using _AsyncLLMEngine).
I think ipex& xpu is following earlier CUDA implementation (CUDA also have similar issue long time ago, see pytorch/pytorch#40403) and I believe this(using fork as start method) can be fixed in the future.

So how about change to this way:
if user use LLMEngine on xpu, we will not support use mp as distributed backend.(spawn needs main function call, fork are not supported by torch xpu support yet)
if user use _AsyncLLMEngine, and use mp as backend, please use spawn as start method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated, thanks!



class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
MultiprocessingGPUExecutorAsync):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
19 changes: 16 additions & 3 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
Expand Down Expand Up @@ -439,9 +440,11 @@ def profile_run(self) -> None:
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1

batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len

seq_data, dummy_multi_modal_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
Expand All @@ -465,7 +468,13 @@ def profile_run(self) -> None:
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
self.execute_model(model_input, kv_caches)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.xpu.synchronize()
return

Expand Down Expand Up @@ -537,20 +546,24 @@ def execute_model(
and self.observability_config.collect_model_forward_time):
model_forward_start_time = time.time()

hidden_states = model_executable(
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device))
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end_time = time.time()

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)

# Only perform sampling in the driver worker.
Expand Down
6 changes: 6 additions & 0 deletions vllm/worker/xpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import is_xpu
Expand Down Expand Up @@ -198,3 +199,8 @@ def init_worker_distributed_environment(self) -> None:
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)

if parallel_config.pipeline_parallel_size > 1:
# torch-ccl xpu need a collective API warm up
# before calling send/recv API
get_pp_group().all_reduce(torch.zeros(1).xpu())
Loading