From 76fc072b3a486df50408909f4e1fd6a3f63ff4ce Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 01:53:03 +0000 Subject: [PATCH 01/58] Add & warnings --- vllm/worker/tpu_model_runner.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 5003d3b0ca440..d790dbcfef5b2 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -329,10 +329,11 @@ def _prepare_sample( self, seq_group_metadata_list: List[SequenceGroupMetadata], padded_batch_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert len(seq_group_metadata_list) > 0 t = [] p = [] + n = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.sampling_params is not None sampling_params = seq_group_metadata.sampling_params @@ -340,13 +341,25 @@ def _prepare_sample( t.append(sampling_params.temperature if sampling_params.temperature >= 1e-5 else 1e-5) p.append(sampling_params.top_p) + n.append(sampling_params.n) num_paddings = padded_batch_size - len(seq_group_metadata_list) t += [1.0] * num_paddings p += [1.0] * num_paddings + n += [1] * num_paddings + + if any(top_p != 1 for top_p in p): + raise NotImplementedError( + "Top-p sampling is currently not supported by the TPU " + "backend due to performance issues.") + if any(num_samples != 1 for num_samples in n): + raise NotImplementedError( + "Parallel sampling (n > 1) is currently not supported by the " + "TPU backend due to performance issues.") t = torch.tensor(t, dtype=torch.float32, device=self.device) p = torch.tensor(p, dtype=torch.float32, device=self.device) - return t, p + n = torch.tensor(n, dtype=torch.int32, device=self.device) + return t, p, n def prepare_inputs( self, @@ -429,6 +442,7 @@ def forward( input_lens: torch.Tensor, t: torch.Tensor, p: torch.Tensor, + n: torch.Tensor, ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. From 27a5ad8fb1ca3a204a14cd9247c9b090ca8fd68d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 09:00:48 +0000 Subject: [PATCH 02/58] Add in dummy_run --- vllm/worker/tpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index d790dbcfef5b2..59739b7121595 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -137,10 +137,11 @@ def _dummy_run( ) t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + n = torch.ones((batch_size, ), dtype=torch.int32, device=self.device) # Dummy run. 
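        # Running the model once on dummy inputs lets XLA trace and compile a
        # graph for this (batch_size, seq_len) shape ahead of time, so real
        # requests padded to the same shape reuse the cached compilation.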
self.model(token_ids, position_ids, kv_caches, attn_metadata, - input_lens, t, p) + input_lens, t, p, n) def warmup_model( self, From 5ab6f65f2f36c5654c2115d257628c524e1a0f79 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 09:01:00 +0000 Subject: [PATCH 03/58] Add is_driver_worker --- vllm/worker/tpu_worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 04576015dadbd..0d7b3fb7ec6f0 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -34,6 +34,7 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, + is_driver_worker: bool, ) -> None: self.model_config = model_config self.parallel_config = parallel_config @@ -45,6 +46,7 @@ def __init__( self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker assert self.device_config.device_type == "tpu" if self.cache_config.cache_dtype == "auto": From c4e79a04f6314683c7d373653c8f7a3466dfab88 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 09:02:54 +0000 Subject: [PATCH 04/58] Make TPUExecutor similar to GPUExecutor --- vllm/executor/tpu_executor.py | 58 ++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 5ed00e1374100..7fe5349c987ad 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -1,4 +1,4 @@ -from typing import List, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import torch @@ -26,29 +26,45 @@ def _init_executor(self) -> None: self.model_config.dtype = torch.bfloat16 # Instantiate the worker and load the model to the device. 
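        # The single-chip executor owns exactly one worker; building it through
        # _create_worker()/_get_worker_kwargs() mirrors GPUExecutor and keeps
        # worker construction in one reusable helper.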
- self._init_worker() - - def _init_worker(self): - from vllm.worker.tpu_worker import TPUWorker + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() - assert self.parallel_config.world_size == 1, ( - "TPUExecutor currently only supports a single TPU chip.") - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = TPUWorker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - self.cache_config, - self.load_config, - self.vision_language_config, - local_rank=0, - rank=0, + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None, + ) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, distributed_init_method=distributed_init_method, + vision_language_config=self.vision_language_config, + is_driver_worker=rank == 0, ) - self.driver_worker.init_device() - self.driver_worker.load_model() + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None, + ): + from vllm.worker.tpu_worker import TPUWorker + + worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank, + distributed_init_method)) + return worker def initialize_cache( self, From ff8199332dcb0473d9ba092192d8cc64970ad9e7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 09:40:18 +0000 Subject: [PATCH 05/58] Add multiprocessing-based TPU executor --- vllm/executor/multiproc_tpu_executor.py | 182 ++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 vllm/executor/multiproc_tpu_executor.py diff --git a/vllm/executor/multiproc_tpu_executor.py b/vllm/executor/multiproc_tpu_executor.py new file mode 100644 index 0000000000000..c24d3da6209f0 --- /dev/null +++ b/vllm/executor/multiproc_tpu_executor.py @@ -0,0 +1,182 @@ +import asyncio +import os +from functools import partial +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union + +from vllm import envs +from vllm.executor.tpu_executor import TPUExecutor +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_open_port, + get_vllm_instance_id, make_async) + +logger = init_logger(__name__) + + +class MultiprocessingTPUExecutor(TPUExecutor): + """Python multiprocessing-based multi-chip TPU executor""" + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + + def _init_executor(self) -> None: + # Create the parallel TPU workers. 
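+        # One process per TPU chip: rank 0 stays in this process as the driver
+        # worker, while ranks 1..world_size-1 are spawned below as
+        # ProcessWorkerWrapper subprocesses supervised by WorkerMonitor.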
+ world_size = self.parallel_config.tensor_parallel_size + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + # Disable torch async compiling which won't work with daemonic processes + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # TODO(woosuk) + # assert world_size <= cuda_device_count_stateless(), ( + # "please set tensor_parallel_size to less than max local gpu count") + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + if world_size == 1: + self.workers = [] + self.worker_monitor = None + else: + result_handler = ResultHandler() + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + )) for rank in range(1, world_size) + ] + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_model( + execute_model_req=execute_model_req) + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if async_run_remote_workers_only: + # Just return futures + return worker_outputs + + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*args, **kwargs) + + # Get the results of the workers. 
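+        # Block on each subprocess future; the driver's output is placed first
+        # so callers can always read the driver result at index 0.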
+ return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_blocks = self._run_workers("determine_num_available_blocks", ) + num_tpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + return num_tpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) + + # Only the driver worker returns the sampling results. + return self._driver_execute_model(execute_model_req) + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) From 16e80b2dce8df2132913209636ccf7753b10a0d1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 10:24:26 +0000 Subject: [PATCH 06/58] Use TPU to initialize Ray cluster --- vllm/executor/ray_utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 495fddd175dd4..8141332edba4f 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -3,7 +3,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_ip, is_hip, is_xpu +from vllm.utils import get_ip, is_hip, is_xpu, is_tpu from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -82,6 +82,7 @@ def initialize_ray_cluster( # Placement group is already set. return + device_str = "GPU" if not is_tpu() else "TPU" # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() if current_placement_group: @@ -90,24 +91,27 @@ def initialize_ray_cluster( # Verify that we can use the placement group. 
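        # Count the bundles that reserve a device of the selected type; each
        # worker needs exactly one GPU/TPU bundle, so the count must cover
        # the configured world size.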
gpu_bundles = 0 for bundle in bundles: - bundle_gpus = bundle.get("GPU", 0) + bundle_gpus = bundle.get(device_str, 0) if bundle_gpus > 1: raise ValueError( - "Placement group bundle cannot have more than 1 GPU.") + "Placement group bundle cannot have more than 1 " + f"{device_str}.") if bundle_gpus: gpu_bundles += 1 if parallel_config.world_size > gpu_bundles: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the placement group.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") else: - num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) + num_gpus_in_cluster = ray.cluster_resources().get(device_str, 0) if parallel_config.world_size > num_gpus_in_cluster: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the cluster.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") # Create a new placement group - placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size) + placement_group_specs = ([{ + device_str: 1 + }] * parallel_config.world_size) current_placement_group = ray.util.placement_group( placement_group_specs) # Wait until PG is ready - this will block until all From 05884ce4699de5372a5ad0aaad5610d5d6524598 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 10:25:03 +0000 Subject: [PATCH 07/58] Add pjrt proc init --- vllm/worker/tpu_worker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 0d7b3fb7ec6f0..07aa54b5e1612 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -4,6 +4,7 @@ import torch import torch_xla.core.xla_model as xm import torch_xla.runtime as xr +from torch_xla._internal import pjrt import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, @@ -62,6 +63,10 @@ def __init__( def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" + if self.parallel_config.world_size > 1: + pjrt.initialize_multiprocess(self.local_rank, + self.parallel_config.world_size) + self.device = xm.xla_device() self.device_config.device = self.device torch.set_grad_enabled(False) From 20d23eb4a2017c6a96be7fbac31d6b92ef95de4f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 10:27:19 +0000 Subject: [PATCH 08/58] Add Ray TPU executor --- vllm/executor/ray_tpu_executor.py | 313 ++++++++++++++++++++++++++++++ 1 file changed, 313 insertions(+) create mode 100644 vllm/executor/ray_tpu_executor.py diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py new file mode 100644 index 0000000000000..7d91136de86a7 --- /dev/null +++ b/vllm/executor/ray_tpu_executor.py @@ -0,0 +1,313 @@ +import asyncio +import os +import pickle +from collections import defaultdict +from itertools import islice, repeat +from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, Union + +import vllm.envs as envs +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.tpu_executor import TPUExecutor +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) + +if ray is not None: + from ray.util.scheduling_strategies import 
PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +class RayTPUExecutor(TPUExecutor): + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + + def _init_executor(self) -> None: + assert self.parallel_config.distributed_executor_backend == "ray" + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel TPU workers. + self._init_workers_ray(placement_group) + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Create the workers. + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("TPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + assert self.speculative_config is None + worker_module_name = "vllm.worker.tpu_worker" + worker_class_name = "TPUWorker" + + worker = ray.remote( + num_cpus=0, + resources={"TPU": 1}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any TPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "TPU node.") + + # Get the set of TPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + + node_workers = defaultdict(list) + for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. 
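+        # Propagate the same VLLM_INSTANCE_ID and VLLM_TRACE_FUNCTION to every
+        # Ray actor so all processes of this engine instance share one
+        # instance id for logging and debugging.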
+ all_args_to_update_environment_variables = [({ + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for _ in worker_node_and_gpu_ids] + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) + + if len(node_workers) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, + max_concurrent_workers: Optional[int] = None, + use_ray_compiled_dag: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers. Can be used in the following + ways: + + - async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than blocking + on the results. + - args/kwargs: All workers share the same args/kwargs + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) + + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, + **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + + if async_run_remote_workers_only: + # Just return futures + return ray_worker_outputs + + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. 
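+        # The remote calls above are already in flight; the driver now runs
+        # its own copy of the method in this process (or via the dummy worker
+        # actor when use_dummy_driver is set) and blocks until it finishes.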
+ if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_blocks = self._run_workers("determine_num_available_blocks", ) + num_tpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + return num_tpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) + + # Only the driver worker returns the sampling results. + return self._driver_execute_model(execute_model_req) + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + +class RayGPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_method = make_async(self.driver_worker.execute_method) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. 
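+        # The remote workers stay inside start_worker_execution_loop, driving
+        # their model replicas; their per-step outputs are discarded, so only
+        # the driver's SamplerOutput list is awaited and returned.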
+ return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_method("execute_model", + execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) From 5d4df21a718cf9b8fefbe4ef9f948c9c218b66dd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 10:27:38 +0000 Subject: [PATCH 09/58] Use Ray TPU executor for tp --- vllm/engine/llm_engine.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f7eae257fdd16..029cd44a2d4e6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -355,8 +355,14 @@ def from_engine_args( from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor elif engine_config.device_config.device_type == "tpu": - from vllm.executor.tpu_executor import TPUExecutor - executor_class = TPUExecutor + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_tpu_executor import (RayTPUExecutor) + executor_class = RayTPUExecutor + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutor + executor_class = TPUExecutor elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor From 6b2c76c5d0af36148455564da5c59cf144eb2ec7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 19:42:35 +0000 Subject: [PATCH 10/58] Minor --- vllm/executor/ray_tpu_executor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 7d91136de86a7..f6b402472ff27 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -247,8 +247,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", From d91446b8f21132d9918c4e6232e26f9ab4af0115 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 19:44:54 +0000 Subject: [PATCH 11/58] Fix TPUWorker.execute_model --- vllm/worker/tpu_worker.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 07aa54b5e1612..d583198579f6d 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -182,16 +182,13 @@ def get_cache_block_size_bytes(self) -> int: def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> 
List[SamplerOutput]: - if execute_model_req is None: - return [] - - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - num_seq_groups = len(seq_group_metadata_list) - if num_seq_groups == 0: + if not self.is_driver_worker: + self._execute_model_non_driver() return [] + assert execute_model_req is not None # Currently, TPUWorker does not support swapping. # TODO(woosuk): Support block copying. assert len(execute_model_req.blocks_to_swap_in) == 0, ( @@ -200,6 +197,16 @@ def execute_model( "Swapping is not supported for the TPU backend.") assert len(execute_model_req.blocks_to_copy) == 0 + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + assert len(seq_group_metadata_list) > 0 output = self.model_runner.execute_model(seq_group_metadata_list, self.tpu_cache) return [output] + + def start_worker_execution_loop(self) -> None: + while self._execute_model_non_driver(): + pass + + def _execute_model_non_driver(self) -> bool: + self.model_runner.execute_model(None, self.tpu_cache) + return True From ab1595d8533732a5602ef4267023b2983058e699 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 20:36:46 +0000 Subject: [PATCH 12/58] Add is_driver_worker & input broadcast --- vllm/worker/tpu_model_runner.py | 14 +++++++++++++- vllm/worker/tpu_worker.py | 5 ++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 59739b7121595..68ee568c01916 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -9,6 +9,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -33,6 +34,7 @@ def __init__( cache_config: CacheConfig, load_config: LoadConfig, vision_language_config: Optional[VisionLanguageConfig] = None, + is_driver_worker: bool = False, ): self.model_config = model_config self.parallel_config = parallel_config @@ -41,6 +43,7 @@ def __init__( self.cache_config = cache_config self.load_config = load_config self.vision_language_config = vision_language_config + self.is_driver_worker = is_driver_worker self.block_size = self.cache_config.block_size self.max_num_blocks_per_seq = (self.model_config.max_model_len // @@ -387,6 +390,8 @@ def _execute_model( inputs = self.prepare_inputs(seq_group_metadata_list) next_token_ids = self.model(inputs[0], inputs[1], kv_caches, *inputs[2:]) + if not self.is_driver_worker: + return [] next_token_ids = next_token_ids.cpu().tolist() i = 0 @@ -409,7 +414,14 @@ def execute_model( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> SamplerOutput: - assert seq_group_metadata_list is not None + if self.is_driver_worker: + assert seq_group_metadata_list is not None + metadata_dict = {"seq_group_metadata_list": seq_group_metadata_list} + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + seq_group_metadata_list = metadata_dict.pop("seq_group_metadata_list") + if seq_group_metadata_list[0].is_prompt: # NOTE(woosuk): To reduce the compilation time, we only compile the # prefill inputs with batch size 1. 
Because the scheduler is not diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index d583198579f6d..c931bcf26f12e 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -59,11 +59,14 @@ def __init__( self.model_runner = TPUModelRunner(model_config, parallel_config, scheduler_config, device_config, cache_config, load_config, - vision_language_config) + vision_language_config, + is_driver_worker=is_driver_worker) def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" if self.parallel_config.world_size > 1: + # FIXME(woosuk): local_world_size should be used instead of + # parallel_config.world_size. pjrt.initialize_multiprocess(self.local_rank, self.parallel_config.world_size) From 4b453939287187eba288d945996c31096bb6dffb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 21:20:55 +0000 Subject: [PATCH 13/58] Call xm._init_world_size_ordinal --- vllm/worker/tpu_worker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index c931bcf26f12e..a5a0ef6d12272 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -69,14 +69,17 @@ def init_device(self) -> None: # parallel_config.world_size. pjrt.initialize_multiprocess(self.local_rank, self.parallel_config.world_size) + xm._init_world_size_ordinal() self.device = xm.xla_device() self.device_config.device = self.device torch.set_grad_enabled(False) torch.set_default_dtype(self.model_config.dtype) - # NOTE(woosuk): This is just a hack to initialize the TP group. - # This cannot perform the actual communication ops. + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. init_distributed_environment( world_size=self.parallel_config.world_size, rank=self.rank, From 86451a2e3c6e258ef0cd5266b9b966aa68990ad2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 21:22:05 +0000 Subject: [PATCH 14/58] Bug fix on vocab --- vllm/model_executor/layers/vocab_parallel_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 1a26c5c63fedc..c5b02ab5cdf02 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -310,7 +310,7 @@ def forward(self, input_): output_parallel = F.embedding(masked_input.long(), self.weight) # Mask the output embedding. if self.tp_size > 1: - output_parallel.masked_fill_(input_mask.unsqueeze(1), 0) + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) # Reduce across all the model parallel GPUs. 
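        # Each rank looked up embeddings only for its own vocab shard and
        # zeroed the rows of out-of-shard tokens above, so the all-reduce sum
        # reconstructs the full embedding for every token.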
output = tensor_model_parallel_all_reduce(output_parallel) return output From 05392999c92a87252a6f89c387369e1838250234 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 21:22:18 +0000 Subject: [PATCH 15/58] Use all gather for TPU --- vllm/model_executor/layers/logits_processor.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 7eee599473a11..36b66eb64e6fe 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -5,8 +5,9 @@ import torch import torch.nn as nn -from vllm.distributed import tensor_model_parallel_gather +from vllm.distributed import (tensor_model_parallel_gather, tensor_model_parallel_all_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.utils import is_tpu class LogitsProcessor(nn.Module): @@ -34,6 +35,7 @@ def __init__(self, self.logits_as_input = logits_as_input # original vocabulary size (without LoRA). self.org_vocab_size = org_vocab_size or vocab_size + self.use_all_gather = is_tpu() def forward( self, @@ -66,7 +68,12 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = torch.matmul(hidden_states, embedding.t()) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_gather(logits) + if self.use_all_gather: + # Gather might not be supported for some devices such as TPUs. + # Use all-gather instead. + logits = tensor_model_parallel_all_gather(logits) + else: + logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[:, :self.org_vocab_size] From b35917cedc824b446dff831e54c05c518788e6e6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Jun 2024 21:22:38 +0000 Subject: [PATCH 16/58] Support TPU in GroupCoordinator --- vllm/distributed/parallel_state.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5188fadbb92a5..34e4b7e6c8e44 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -32,6 +32,10 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.utils import is_tpu + +if is_tpu(): + import torch_xla.core.xla_model as xm @dataclass @@ -99,6 +103,7 @@ class GroupCoordinator: pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator shm_broadcaster: Optional[Any] # shared memory broadcaster + is_tpu: bool def __init__( self, @@ -113,6 +118,7 @@ def __init__( self.local_rank = local_rank self.device_group = None self.cpu_group = None + self.is_tpu = is_tpu() for ranks in group_ranks: device_group = torch.distributed.new_group( @@ -245,6 +251,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return input_ + + # For TPUs, use xm.all_reduce. + if self.is_tpu: + return xm.all_reduce(xm.REDUCE_SUM, input_) + if ca_comm is not None: out = ca_comm.custom_all_reduce(input_) if out is not None: @@ -263,6 +274,11 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return input_ assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + # For TPUs, use xm.all_gather. 
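+        # As noted in TPUWorker.init_device, the torch.distributed group only
+        # handles CPU-side broadcasts; on-device collectives go through XLA's
+        # own replica context via xm.all_gather / xm.all_reduce.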
+ if self.is_tpu: + return xm.all_gather(input_, dim) + if dim < 0: # Convert negative dim to positive. dim += input_.dim() From b9a84bccf6c6c4ccfb02fe29963a1df19680d983 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 25 Jun 2024 00:43:48 +0000 Subject: [PATCH 17/58] Delete multiproc TPU executor --- vllm/executor/multiproc_tpu_executor.py | 182 ------------------------ 1 file changed, 182 deletions(-) delete mode 100644 vllm/executor/multiproc_tpu_executor.py diff --git a/vllm/executor/multiproc_tpu_executor.py b/vllm/executor/multiproc_tpu_executor.py deleted file mode 100644 index c24d3da6209f0..0000000000000 --- a/vllm/executor/multiproc_tpu_executor.py +++ /dev/null @@ -1,182 +0,0 @@ -import asyncio -import os -from functools import partial -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union - -from vllm import envs -from vllm.executor.tpu_executor import TPUExecutor -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) -from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_open_port, - get_vllm_instance_id, make_async) - -logger = init_logger(__name__) - - -class MultiprocessingTPUExecutor(TPUExecutor): - """Python multiprocessing-based multi-chip TPU executor""" - - def __init__(self, *args, **kwargs): - # This is non-None when the execute model loop is running - # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. - self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None - # Updated by implementations that require additional args to be passed - # to the _run_workers execute_model call - self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} - - super().__init__(*args, **kwargs) - - def _init_executor(self) -> None: - # Create the parallel TPU workers. - world_size = self.parallel_config.tensor_parallel_size - - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - - # Disable torch async compiling which won't work with daemonic processes - os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - - # TODO(woosuk) - # assert world_size <= cuda_device_count_stateless(), ( - # "please set tensor_parallel_size to less than max local gpu count") - - # Multiprocessing-based executor does not support multi-node setting. - # Since it only works for single node, we can use the loopback address - # 127.0.0.1 for communication. - distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) - - if world_size == 1: - self.workers = [] - self.worker_monitor = None - else: - result_handler = ResultHandler() - self.workers = [ - ProcessWorkerWrapper( - result_handler, - partial( - self._create_worker, - rank=rank, - local_rank=rank, - distributed_init_method=distributed_init_method, - )) for rank in range(1, world_size) - ] - - self.worker_monitor = WorkerMonitor(self.workers, result_handler) - result_handler.start() - self.worker_monitor.start() - - self.driver_worker = self._create_worker( - distributed_init_method=distributed_init_method) - self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. 
- max_parallel_loading_workers) - - def shutdown(self): - if (worker_monitor := getattr(self, "worker_monitor", - None)) is not None: - worker_monitor.close() - - def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - return self.driver_worker.execute_model( - execute_model_req=execute_model_req) - - def _run_workers( - self, - method: str, - *args, - async_run_remote_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. - - Args: - async_run_remote_workers_only: If True the method will be run only - in the remote workers, not the driver worker. It will also be - run asynchronously and return a list of futures rather than - blocking on the results. - """ - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - # Start the workers first. - worker_outputs = [ - worker.execute_method(method, *args, **kwargs) - for worker in self.workers - ] - - if async_run_remote_workers_only: - # Just return futures - return worker_outputs - - driver_worker_method = getattr(self.driver_worker, method) - driver_worker_output = driver_worker_method(*args, **kwargs) - - # Get the results of the workers. - return [driver_worker_output - ] + [output.get() for output in worker_outputs] - - def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - if self.worker_monitor is not None and not self.worker_monitor.is_alive( - ): - raise RuntimeError("Worker processes are not running") - - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - for result in parallel_worker_tasks: - result.get() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - num_blocks = self._run_workers("determine_num_available_blocks", ) - num_tpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - return num_tpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - if self.parallel_worker_tasks is None: - self.parallel_worker_tasks = self._run_workers( - "start_worker_execution_loop", - async_run_remote_workers_only=True, - **self.extra_execute_model_run_workers_kwargs) - - # Only the driver worker returns the sampling results. 
- return self._driver_execute_model(execute_model_req) - - def stop_remote_worker_execution_loop(self) -> None: - if self.parallel_worker_tasks is None: - return - - self._driver_execute_model() - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - self._wait_for_tasks_completion(parallel_worker_tasks) From c756b763d877acb25680986ed96a29f84194cfa6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 25 Jun 2024 00:45:59 +0000 Subject: [PATCH 18/58] Minor --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 029cd44a2d4e6..7110a87c89205 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -357,7 +357,7 @@ def from_engine_args( elif engine_config.device_config.device_type == "tpu": if distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_tpu_executor import (RayTPUExecutor) + from vllm.executor.ray_tpu_executor import RayTPUExecutor executor_class = RayTPUExecutor else: assert distributed_executor_backend is None From 16e9934ce70f808fe2d896a2e36b26df9c650588 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 26 Jun 2024 19:39:38 +0000 Subject: [PATCH 19/58] [Bugfix][TPU] Fix CPU cache allocation & swapping --- vllm/attention/backends/pallas.py | 6 +----- vllm/worker/tpu_worker.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 121ca9ec45205..a44980468c1d5 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -28,7 +28,6 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (num_kv_heads, num_blocks, block_size, head_size) - @torch.compile(backend="openxla") @staticmethod def swap_blocks( src_kv_cache: Tuple[torch.Tensor, torch.Tensor], @@ -37,11 +36,8 @@ def swap_blocks( ) -> None: src_k_cache, src_v_cache = src_kv_cache dst_k_cache, dst_v_cache = dst_kv_cache - torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True) - torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True) - - device = dst_k_cache.device src_indices, dst_indices = src_to_dst + device = dst_k_cache.device dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device) dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index c85bf6892fb28..e81e7cca0e711 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -156,14 +156,18 @@ def initialize_cache( self.tpu_cache = [] tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( num_gpu_blocks, self.block_size, num_kv_heads, head_size) + cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_cpu_blocks, self.block_size, num_kv_heads, head_size) for _ in range(num_layers): tpu_k_cache = torch.zeros(tpu_cache_shape, dtype=dtype, device=self.device) tpu_v_cache = torch.zeros_like(tpu_k_cache) self.tpu_cache.append((tpu_k_cache, tpu_v_cache)) - cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu") - cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu") + cpu_k_cache = torch.zeros(cpu_cache_shape, + dtype=dtype, + device="cpu") + cpu_v_cache = torch.zeros_like(cpu_k_cache) self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) self._warmup_model() @@ -228,6 +232,7 @@ def cache_swap( for i in 
range(num_layers): attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i], src_to_dst) + xm.mark_step() if blocks_to_swap_out: # Swap from TPU to CPU. src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device, @@ -235,6 +240,7 @@ def cache_swap( for i in range(num_layers): attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i], src_to_dst) + xm.mark_step() if blocks_to_copy: src_to_dst = _make_src_to_dst(blocks_to_copy, self.device, self.device) From ca6d1d6f4f388afb646caf2c869a7aaa86a3e4ca Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 26 Jun 2024 20:15:13 +0000 Subject: [PATCH 20/58] yapf --- vllm/executor/ray_tpu_executor.py | 11 +++++------ vllm/executor/ray_utils.py | 2 +- vllm/model_executor/layers/logits_processor.py | 3 ++- vllm/worker/tpu_model_runner.py | 18 +++++------------- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index f6b402472ff27..a5eac493c63a9 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -1,14 +1,14 @@ import asyncio import os -import pickle from collections import defaultdict from itertools import islice, repeat -from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, + Union) import vllm.envs as envs from vllm.executor.executor_base import ExecutorAsyncBase -from vllm.executor.tpu_executor import TPUExecutor from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.executor.tpu_executor import TPUExecutor from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, @@ -197,10 +197,9 @@ def _run_workers( # Start the ray workers first. 
ray_worker_outputs = [ - worker.execute_method.remote(method, *worker_args, - **worker_kwargs) + worker.execute_method.remote(method, *worker_args, **worker_kwargs) for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ) in zip(self.workers, all_worker_args, all_worker_kwargs) ] if async_run_remote_workers_only: diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 8141332edba4f..415ded12aaaef 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -3,7 +3,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_ip, is_hip, is_xpu, is_tpu +from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 36b66eb64e6fe..1e0f0f7df9739 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn -from vllm.distributed import (tensor_model_parallel_gather, tensor_model_parallel_all_gather) +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import is_tpu diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index cc3730b066574..f3a5efbec4d85 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -145,7 +145,6 @@ def _dummy_run( ) t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - n = torch.ones((batch_size, ), dtype=torch.int32, device=self.device) # Dummy run. 
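        # num_samples is a plain Python int and is effectively baked into the
        # compiled graph, which is why prefill warms up with _MAX_NUM_SAMPLES
        # while decode warms up with a single sample.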
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 @@ -383,16 +382,6 @@ def _prepare_sample( num_paddings = padded_batch_size - len(t) t += [1.0] * num_paddings p += [1.0] * num_paddings - n += [1] * num_paddings - - if any(top_p != 1 for top_p in p): - raise NotImplementedError( - "Top-p sampling is currently not supported by the TPU " - "backend due to performance issues.") - if any(num_samples != 1 for num_samples in n): - raise NotImplementedError( - "Parallel sampling (n > 1) is currently not supported by the " - "TPU backend due to performance issues.") t = torch.tensor(t, dtype=torch.float32, device=self.device) p = torch.tensor(p, dtype=torch.float32, device=self.device) @@ -462,11 +451,14 @@ def execute_model( if self.is_driver_worker: assert seq_group_metadata_list is not None assert len(seq_group_metadata_list) > 0 - metadata_dict = {"seq_group_metadata_list": seq_group_metadata_list} + metadata_dict = { + "seq_group_metadata_list": seq_group_metadata_list + } broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) - seq_group_metadata_list = metadata_dict.pop("seq_group_metadata_list") + seq_group_metadata_list = metadata_dict.pop( + "seq_group_metadata_list") if seq_group_metadata_list[0].is_prompt: # NOTE(woosuk): To reduce the compilation time, we only compile the From cd4f68d567fb3863594c932f5ae2f3b08a445787 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 26 Jun 2024 20:15:29 +0000 Subject: [PATCH 21/58] Add Ray to TPU dependency --- requirements-tpu.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 22487f5524dd7..c2140fbffec9f 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,5 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. +ray triton # To avoid import errors From 546987ad3d0c69868ad300812a81c4a388a7ac4c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 26 Jun 2024 20:45:26 +0000 Subject: [PATCH 22/58] Fix --- vllm/worker/tpu_worker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 76098dea7595d..37c722a1ec2cf 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -242,7 +242,6 @@ def cache_swap( for i in range(num_layers): attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i], src_to_dst) - xm.mark_step() if blocks_to_swap_out: # Swap from TPU to CPU. 
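            # Mirror of the swap-in path above: copy the selected KV-cache
            # blocks from the TPU cache back into the CPU cache, one layer at
            # a time.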
src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device, @@ -250,7 +249,6 @@ def cache_swap( for i in range(num_layers): attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i], src_to_dst) - xm.mark_step() if blocks_to_copy: src_to_dst = _make_src_to_dst(blocks_to_copy, self.device, self.device) From 330be6e43d82abdd42797c12f58e532fc963a0fd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 26 Jun 2024 20:54:51 +0000 Subject: [PATCH 23/58] Fix --- vllm/attention/backends/pallas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index a0b7d53440a6d..5dec11e2eede7 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -28,6 +28,7 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (num_kv_heads, num_blocks, block_size, head_size) + @torch.compile(backend="openxla") @staticmethod def swap_blocks( src_kv_cache: Tuple[torch.Tensor, torch.Tensor], From 8fab9fd441d5721c35d255040c8d8b0631deea70 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 29 Jun 2024 05:45:36 +0000 Subject: [PATCH 24/58] Add use_all_gather to LoRA --- vllm/lora/layers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 2fddfccaf1e4c..ebc94a2136009 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1073,6 +1073,10 @@ def scale(self): def soft_cap(self): return self.base_layer.soft_cap + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + @property def org_vocab_size(self): return self.base_layer.org_vocab_size From c4cbe9f20ee92899af20aa4e58ef73dff2d6af3e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 29 Jun 2024 20:00:08 +0000 Subject: [PATCH 25/58] Fix --- vllm/executor/ray_tpu_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index a5eac493c63a9..7048d47980723 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -270,7 +270,7 @@ def stop_remote_worker_execution_loop(self) -> None: self._wait_for_tasks_completion(parallel_worker_tasks) -class RayGPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): +class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) From db7adc7f189d968657fe737b60379cae5cf9a2d3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 30 Jun 2024 04:15:33 +0000 Subject: [PATCH 26/58] Add an assert for dim == -1 --- vllm/distributed/parallel_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5bed5c9b2882c..2cf4eca013f27 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -313,6 +313,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: # For TPUs, use xm.all_gather. if self.is_tpu: + assert dim == -1, "TPUs only support dim=-1 for all-gather." 
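+            # e.g. the vocab-sharded logits gather in LogitsProcessor
+            # concatenates shards along the last dim, so dim=-1 covers the
+            # current use.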
return xm.all_gather(input_, dim) if dim < 0: From 696790d839b6b4b4eb46feef9f87cac21401d0e5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 30 Jun 2024 04:26:15 +0000 Subject: [PATCH 27/58] is_tpu -> use_xla --- vllm/distributed/parallel_state.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2cf4eca013f27..8da511255b880 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -129,7 +129,7 @@ class GroupCoordinator: pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator shm_broadcaster: Optional[Any] # shared memory broadcaster - is_tpu: bool + use_xla: bool # Whether to use PyTorch XLA communicator def __init__( self, @@ -144,7 +144,7 @@ def __init__( self.local_rank = local_rank self.device_group = None self.cpu_group = None - self.is_tpu = is_tpu() + self.use_xla = is_tpu() for ranks in group_ranks: device_group = torch.distributed.new_group( @@ -289,7 +289,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return input_ # For TPUs, use xm.all_reduce. - if self.is_tpu: + if self.use_xla: return xm.all_reduce(xm.REDUCE_SUM, input_) if ca_comm is not None: @@ -312,7 +312,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") # For TPUs, use xm.all_gather. - if self.is_tpu: + if self.use_xla: assert dim == -1, "TPUs only support dim=-1 for all-gather." return xm.all_gather(input_, dim) From 28afe568c8ec8f2f8b2b1672b28e20eb9877355b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 2 Jul 2024 00:48:51 +0000 Subject: [PATCH 28/58] yapf --- vllm/distributed/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 8da511255b880..fed7b2b6f08f8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -129,7 +129,7 @@ class GroupCoordinator: pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator shm_broadcaster: Optional[Any] # shared memory broadcaster - use_xla: bool # Whether to use PyTorch XLA communicator + use_xla: bool # Whether to use PyTorch XLA communicator def __init__( self, From 60bf64dc773874716cde522f24725cd5c2665858 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 2 Jul 2024 01:39:44 +0000 Subject: [PATCH 29/58] Add hack in vocab --- vllm/model_executor/layers/vocab_parallel_embedding.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 4650b2c2458d0..61c429ac97d28 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -9,6 +9,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_tpu DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -169,6 +170,12 @@ def __init__(self, # Keep the input dimensions. tp_rank = get_tensor_model_parallel_rank() + if is_tpu(): + import torch_xla.core.xla_model as xm + + # FIXME(woosuk): This is a temporary hack. 
+ tp_rank = xm.get_ordinal() + self.tp_size = get_tensor_model_parallel_world_size() self.num_embeddings = num_embeddings self.padding_size = padding_size From cd4842d8e91c917a25da7844520c391e5f16fe7b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 9 Jul 2024 02:29:58 +0000 Subject: [PATCH 30/58] Fix multi-modal support --- vllm/worker/tpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a106aba13f003..e36b936c8b9b4 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -155,7 +155,7 @@ def _dummy_run( # Dummy run. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 self.model(token_ids, position_ids, kv_caches, attn_metadata, - input_lens, t, p, num_samples) + input_lens, None, t, p, num_samples) def warmup_model( self, @@ -544,6 +544,7 @@ def forward( pass to the model. t: The sampling temperature of shape [batch_size]. p: The top-p probability of shape [batch_size]. + num_samples: The number of samples to draw for each sequence. """ batch_size, seq_len = token_ids.shape # Calculate the positions to sample from. From 106864df45801bee444f4dccdff35fbef2bddc72 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 18 Jul 2024 08:45:37 +0000 Subject: [PATCH 31/58] Remove unused --- vllm/worker/tpu_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 17fab7f721775..8a8b412db6731 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -10,7 +10,6 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SamplingMetadata From 223661fdf74b1142f0f3bfac4c585be98d702111 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 18 Jul 2024 08:52:28 +0000 Subject: [PATCH 32/58] Minor --- vllm/worker/tpu_worker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 95dec2c4473e3..04d49e0bb1dce 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -210,8 +210,7 @@ def get_cache_block_size_bytes(self) -> int: @property def do_metadata_broadcast(self) -> bool: - # TODO(woosuk): Support TP. - return False + return self.parallel_config.tensor_parallel_size > 1 @property def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: From ab7cccfa051d447e5488263c6b3092f7f5f4a0c1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 08:55:39 +0000 Subject: [PATCH 33/58] Fix comm error --- vllm/attention/backends/pallas.py | 4 ++-- vllm/worker/tpu_model_runner.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b83a83bb177d4..c53a2f91b89d7 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -55,8 +55,8 @@ class PallasMetadata(AttentionMetadata): # Currently, input sequences can only contain all prefills # or all decoding. 
- block_tables: Optional[torch.Tensor] - context_lens: Optional[torch.Tensor] + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None @property def prefill_metadata(self) -> Optional["PallasMetadata"]: diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 8a8b412db6731..a953192069fe6 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -45,6 +45,7 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int best_of: List[int] seq_groups: List[List[int]] + virtual_engine: int = 0 def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -55,6 +56,9 @@ def as_broadcastable_tensor_dict( "t": self.t, "p": self.p, "num_samples": self.num_samples, + "best_of": self.best_of, + "seq_groups": self.seq_groups, + "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -466,7 +470,7 @@ def make_model_input_from_broadcasted_tensor_dict( def execute_model( self, model_input: ModelInputForTPU, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: Optional[List[Any]], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> List[SamplerOutput]: From 4e0c90a023859e94bf5f00814215d60f44bfc4cb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 09:13:34 +0000 Subject: [PATCH 34/58] Use custom inference_mode --- vllm/utils.py | 25 +++++++++++++++++++++++++ vllm/worker/worker_base.py | 4 ++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 9e222772eb5b9..b68d918390a67 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -897,6 +897,31 @@ def error_on_invalid_device_count_status(): "CUDA_VISIBLE_DEVICES to the GPUs you want to use.") +class inference_mode(torch.inference_mode): + + def __init__(self, mode: bool = True) -> None: + self.inference_mode = not is_tpu() + if self.inference_mode: + super().__init__(mode) + else: + # No grad. + self.prev = False + self.mode = mode + + def __enter__(self) -> None: + if self.inference_mode: + super().__enter__() + else: + self.prev = torch.is_grad_enabled() + torch.set_grad_enabled(False) + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + if self.inference_mode: + super().__init__(exc_type, exc_value, traceback) + else: + torch.set_grad_enabled(self.prev) + + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8e5c0ededba15..436c028689a09 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput) -from vllm.utils import (enable_trace_function_call_for_thread, +from vllm.utils import (enable_trace_function_call_for_thread, inference_mode, update_environment_variables) from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -53,7 +53,7 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError - @torch.inference_mode() + @inference_mode() def start_worker_execution_loop(self) -> None: """Execute model loop in parallel worker. 
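
Patch 34/58 above wraps `torch.inference_mode` in a device-agnostic class because, as its docstring says, some backends such as TPU do not support inference mode, so on TPU the wrapper has to degrade to plain gradient disabling while other devices keep the cheaper inference-mode path. Below is a minimal functional sketch of the same fallback idea; vLLM's class instead subclasses `torch.inference_mode` so that `@inference_mode()` still works as a decorator on `start_worker_execution_loop`, and the `on_tpu` flag here merely stands in for vLLM's `is_tpu()`:

    import torch

    def grad_free_ctx(on_tpu: bool):
        # torch.inference_mode() is the cheaper option, but the TPU/XLA
        # backend does not support it, so fall back to torch.no_grad() there.
        return torch.no_grad() if on_tpu else torch.inference_mode()

    # GPU/CPU worker: inference mode, lowest autograd overhead.
    with grad_free_ctx(on_tpu=False):
        x = torch.randn(4)
        assert not torch.is_grad_enabled()
        assert torch.is_inference_mode_enabled()

    # TPU worker: plain no_grad still disables autograd without creating
    # inference tensors that the XLA backend would reject.
    with grad_free_ctx(on_tpu=True):
        y = torch.randn(4)
        assert not torch.is_grad_enabled()
        assert not torch.is_inference_mode_enabled()
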
From a2358ed9243020fcfe5ce5f75f68d1cac9077c31 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 09:35:28 +0000 Subject: [PATCH 35/58] Remove hack in vocab embedding --- vllm/model_executor/layers/vocab_parallel_embedding.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index a70f7eb0ef251..74aeb964274b0 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -12,7 +12,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_tpu DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -177,12 +176,6 @@ def __init__(self, # Keep the input dimensions. tp_rank = get_tensor_model_parallel_rank() - if is_tpu(): - import torch_xla.core.xla_model as xm - - # FIXME(woosuk): This is a temporary hack. - tp_rank = xm.get_ordinal() - self.tp_size = get_tensor_model_parallel_world_size() self.num_embeddings = num_embeddings self.padding_size = padding_size From ac21351f2f98cd4f211498f352a5f5e64ab17c1f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 09:35:38 +0000 Subject: [PATCH 36/58] Use patch --- vllm/worker/tpu_model_runner.py | 35 +++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a953192069fe6..1ec470923ff9e 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,6 +1,7 @@ import time from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from unittest.mock import patch import numpy as np import torch @@ -117,16 +118,30 @@ def __init__( def load_model(self) -> None: self.device = self.device_config.device - model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - parallel_config=self.parallel_config, - cache_config=self.cache_config, - scheduler_config=self.scheduler_config, - multimodal_config=self.multimodal_config, - lora_config=None, - ) + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the weights for vocab embedding. + xm_tp_rank = xm.get_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + parallel_config=self.parallel_config, + cache_config=self.cache_config, + scheduler_config=self.scheduler_config, + multimodal_config=self.multimodal_config, + lora_config=None, + ) model = model.eval() xm.wait_device_ops() From ba76d9ee13c085449075250cf8c411b2c2067217 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 09:46:37 +0000 Subject: [PATCH 37/58] Update inference_mode --- vllm/utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index b68d918390a67..92f9644782365 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -898,25 +898,31 @@ def error_on_invalid_device_count_status(): class inference_mode(torch.inference_mode): + """A device-agnostic wrapper of `torch.inference_mode`. + + This wrapper is recommended because some hardware backends such as TPU + do not support `torch.inference_mode`. In such as case, this class falls + back to `torch.no_grad`. + """ def __init__(self, mode: bool = True) -> None: - self.inference_mode = not is_tpu() - if self.inference_mode: + self.use_inference_mode = not is_tpu() + if self.use_inference_mode: super().__init__(mode) else: - # No grad. + # Fall back to torch.no_grad(). self.prev = False self.mode = mode def __enter__(self) -> None: - if self.inference_mode: + if self.use_inference_mode: super().__enter__() else: self.prev = torch.is_grad_enabled() torch.set_grad_enabled(False) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - if self.inference_mode: + if self.use_inference_mode: super().__init__(exc_type, exc_value, traceback) else: torch.set_grad_enabled(self.prev) From 452c3217d4dc726631ea9d9b4338fd27b25d3af3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 09:52:52 +0000 Subject: [PATCH 38/58] use_all_gather -> use_gather --- vllm/lora/layers.py | 4 ++-- vllm/model_executor/layers/logits_processor.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index edf3ba3d0094b..87de285a373a2 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1068,8 +1068,8 @@ def soft_cap(self): return self.base_layer.soft_cap @property - def use_all_gather(self): - return self.base_layer.use_all_gather + def use_gather(self): + return self.base_layer.use_gather @property def org_vocab_size(self): diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 0b5828e2dfa51..03567f5323304 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -41,7 +41,8 @@ def __init__(self, self.org_vocab_size = org_vocab_size or vocab_size # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap - self.use_all_gather = is_tpu() + # Whether to use gather or all-gather to gather the logits. + self.use_gather = not is_tpu() def forward( self, @@ -79,12 +80,12 @@ def _get_logits(self, hidden_states: torch.Tensor, logits = lm_head.linear_method.apply(lm_head, hidden_states, bias=embedding_bias) - if self.use_all_gather: - # Gather might not be supported for some devices such as TPUs. + if self.use_gather: + logits = tensor_model_parallel_gather(logits) + else: + # Gather is not supported for some devices such as TPUs. # Use all-gather instead. 
logits = tensor_model_parallel_all_gather(logits) - else: - logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[:, :self.org_vocab_size] From dcb63b7a517234499d2407b0d9a2c688f511ffa0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 10:26:41 +0000 Subject: [PATCH 39/58] Fix patch --- vllm/worker/tpu_model_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 1ec470923ff9e..85f54f0f825d7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -126,11 +126,10 @@ def load_model(self) -> None: # rank-agnostic. However, it matters for all-gather as the ranks # determine the order of concatenating the output tensors. # As a workaround, we use the xm's rank assignment only when loading - # the weights for vocab embedding. + # the weights. xm_tp_rank = xm.get_ordinal() with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", + "vllm.distributed.parallel_state.get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model( model_config=self.model_config, From 825cc44f376f3646b12f748df4e619ecace8f041 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Jul 2024 17:43:31 +0000 Subject: [PATCH 40/58] Fix typo --- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 92f9644782365..fca0db2431612 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -923,7 +923,7 @@ def __enter__(self) -> None: def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.use_inference_mode: - super().__init__(exc_type, exc_value, traceback) + super().__exit__(exc_type, exc_value, traceback) else: torch.set_grad_enabled(self.prev) From 973028809eec864c7d1accd2e2cdb4d6893e3be3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Jul 2024 02:06:43 +0000 Subject: [PATCH 41/58] Remove inference_mode --- vllm/utils.py | 31 ------------------------------- vllm/worker/worker_base.py | 2 +- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index fca0db2431612..9e222772eb5b9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -897,37 +897,6 @@ def error_on_invalid_device_count_status(): "CUDA_VISIBLE_DEVICES to the GPUs you want to use.") -class inference_mode(torch.inference_mode): - """A device-agnostic wrapper of `torch.inference_mode`. - - This wrapper is recommended because some hardware backends such as TPU - do not support `torch.inference_mode`. In such as case, this class falls - back to `torch.no_grad`. - """ - - def __init__(self, mode: bool = True) -> None: - self.use_inference_mode = not is_tpu() - if self.use_inference_mode: - super().__init__(mode) - else: - # Fall back to torch.no_grad(). - self.prev = False - self.mode = mode - - def __enter__(self) -> None: - if self.use_inference_mode: - super().__enter__() - else: - self.prev = torch.is_grad_enabled() - torch.set_grad_enabled(False) - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - if self.use_inference_mode: - super().__exit__(exc_type, exc_value, traceback) - else: - torch.set_grad_enabled(self.prev) - - # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. 
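
The logits-processor change in patch 38/58 (and again in patch 43/58 below) routes the TPU path through `tensor_model_parallel_all_gather` because a plain gather only materializes the full-vocabulary logits on the destination rank, while XLA's SPMD execution needs every device to keep running the same program on real tensors. The toy script below illustrates that difference with bare `torch.distributed` on the gloo backend; it is not vLLM code, just the underlying collective semantics (run with two processes, e.g. `torchrun --nproc_per_node=2 gather_demo.py`, where the file name is arbitrary):

    # Toy illustration of gather vs. all-gather semantics; not vLLM's wrappers.
    import torch
    import torch.distributed as dist

    def main() -> None:
        dist.init_process_group("gloo")
        rank, world = dist.get_rank(), dist.get_world_size()
        shard = torch.full((2,), float(rank))  # this rank's logits shard

        # gather: only the destination rank ends up with the concatenated
        # logits; the other ranks get nothing back (vLLM's tensor-parallel
        # gather returns None there), so the ranks diverge afterwards.
        out = [torch.empty(2) for _ in range(world)] if rank == 0 else None
        dist.gather(shard, gather_list=out, dst=0)

        # all-gather: every rank receives the full logits and can keep
        # executing the identical program, which is what the XLA/TPU path
        # requires.
        chunks = [torch.empty(2) for _ in range(world)]
        dist.all_gather(chunks, shard)
        print(f"rank {rank}: {torch.cat(chunks).tolist()}")

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()
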
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 4487178228a53..03e3857e23c4b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -12,7 +12,7 @@ from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput) -from vllm.utils import (enable_trace_function_call_for_thread, inference_mode, +from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase From 631b08b4b1b9c09fe241ebacc120047b19fd6aec Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Jul 2024 06:58:20 +0000 Subject: [PATCH 42/58] Add no_grad --- vllm/worker/tpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 85f54f0f825d7..7f401f6f39d48 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -481,6 +481,7 @@ def make_model_input_from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend) return model_input + @torch.no_grad() def execute_model( self, model_input: ModelInputForTPU, From af3a259fae778cd1511c36d438a00bfad8013db4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 01:48:22 +0000 Subject: [PATCH 43/58] [TPU] Support collective communications in XLA devices --- vllm/distributed/parallel_state.py | 17 +++++++++++++++++ vllm/lora/layers.py | 4 ++++ vllm/model_executor/layers/logits_processor.py | 16 ++++++++++++++-- 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e9c6fc3a255e4..58a388be03a11 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -34,6 +34,10 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.utils import is_tpu + +if is_tpu(): + import torch_xla.core.xla_model as xm @dataclass @@ -125,6 +129,7 @@ class GroupCoordinator: pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator mq_broadcaster: Optional[Any] # shared memory broadcaster + use_xla: bool # Whether to use PyTorch XLA communicator def __init__( self, @@ -140,6 +145,7 @@ def __init__( self.local_rank = local_rank self.device_group = None self.cpu_group = None + self.use_xla = is_tpu() for ranks in group_ranks: device_group = torch.distributed.new_group( @@ -289,6 +295,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return input_ + + # For TPUs, use xm.all_reduce. + if self.use_xla: + return xm.all_reduce(xm.REDUCE_SUM, input_) + if ca_comm is not None: out = ca_comm.custom_all_reduce(input_) if out is not None: @@ -307,6 +318,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return input_ assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + # For TPUs, use xm.all_gather. + if self.use_xla: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(input_, dim) + if dim < 0: # Convert negative dim to positive. 
dim += input_.dim() diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 40de134c0a5ee..87de285a373a2 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1067,6 +1067,10 @@ def scale(self): def soft_cap(self): return self.base_layer.soft_cap + @property + def use_gather(self): + return self.base_layer.use_gather + @property def org_vocab_size(self): return self.base_layer.org_vocab_size diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index f6fcf49ef464b..fdb5fbcb5ec7d 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -5,10 +5,12 @@ import torch import torch.nn as nn -from vllm.distributed import tensor_model_parallel_gather +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.utils import is_tpu class LogitsProcessor(nn.Module): @@ -39,6 +41,8 @@ def __init__(self, self.org_vocab_size = org_vocab_size or vocab_size # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = not is_tpu() def forward( self, @@ -76,7 +80,15 @@ def _get_logits(self, hidden_states: torch.Tensor, logits = lm_head.linear_method.apply(lm_head, hidden_states, bias=embedding_bias) - logits = tensor_model_parallel_gather(logits) + if self.use_gather: + logits = tensor_model_parallel_gather(logits) + else: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) # Remove paddings in vocab (if any). 
if logits is not None: logits = logits[:, :self.org_vocab_size] From 0f2abea4009a8725915b4a7594e64a520d8baf2d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 02:25:04 -0700 Subject: [PATCH 44/58] Use current_platform --- vllm/distributed/parallel_state.py | 6 +++--- vllm/model_executor/layers/logits_processor.py | 4 ++-- vllm/platforms/interface.py | 3 +++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 58a388be03a11..c7dca8c520bd9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -34,9 +34,9 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import is_tpu +from vllm.platforms import current_platform -if is_tpu(): +if current_platform.is_xla(): import torch_xla.core.xla_model as xm @@ -145,7 +145,7 @@ def __init__( self.local_rank = local_rank self.device_group = None self.cpu_group = None - self.use_xla = is_tpu() + self.use_xla = current_platform.is_xla() for ranks in group_ranks: device_group = torch.distributed.new_group( diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index fdb5fbcb5ec7d..8448d8cc8f6fe 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.utils import is_tpu +from vllm.platforms import current_platform class LogitsProcessor(nn.Module): @@ -42,7 +42,7 @@ def __init__(self, # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. 
- self.use_gather = not is_tpu() + self.use_gather = not current_platform.is_xla() def forward( self, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0760f9554fb78..2c153b1e95616 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -23,6 +23,9 @@ def is_rocm(self) -> bool: def is_tpu(self) -> bool: return self._enum == PlatformEnum.TPU + def is_xla(self) -> bool: + return self.is_tpu() + @staticmethod def get_device_capability(device_id: int = 0) -> Tuple[int, int]: raise NotImplementedError From 8ebea7ea0985f76dd0c58fa39bb223c78ff519a7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 13:01:59 -0700 Subject: [PATCH 45/58] is_xla -> is_tpu --- vllm/distributed/parallel_state.py | 4 ++-- vllm/model_executor/layers/logits_processor.py | 2 +- vllm/platforms/interface.py | 3 --- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c7dca8c520bd9..2be65470d1346 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -36,7 +36,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -if current_platform.is_xla(): +if current_platform.is_tpu(): import torch_xla.core.xla_model as xm @@ -145,7 +145,7 @@ def __init__( self.local_rank = local_rank self.device_group = None self.cpu_group = None - self.use_xla = current_platform.is_xla() + self.use_xla = current_platform.is_tpu() for ranks in group_ranks: device_group = torch.distributed.new_group( diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 8448d8cc8f6fe..bd3e7e114204f 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -42,7 +42,7 @@ def __init__(self, # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. 
- self.use_gather = not current_platform.is_xla() + self.use_gather = not current_platform.is_tpu() def forward( self, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 2c153b1e95616..0760f9554fb78 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -23,9 +23,6 @@ def is_rocm(self) -> bool: def is_tpu(self) -> bool: return self._enum == PlatformEnum.TPU - def is_xla(self) -> bool: - return self.is_tpu() - @staticmethod def get_device_capability(device_id: int = 0) -> Tuple[int, int]: raise NotImplementedError From 782b1828cc4ba87a47ea9309a1d4ffec8f6cb90e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 14:54:15 -0700 Subject: [PATCH 46/58] Define TPU communicator --- .../device_communicators/tpu_communicator.py | 35 +++++++++++++++++++ vllm/distributed/parallel_state.py | 35 ++++++++++++------- 2 files changed, 57 insertions(+), 13 deletions(-) create mode 100644 vllm/distributed/device_communicators/tpu_communicator.py diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py new file mode 100644 index 0000000000000..24963fa655287 --- /dev/null +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -0,0 +1,35 @@ +from typing import Union + +import torch +from torch.distributed import ProcessGroup + +from vllm.platforms import current_platform + +if current_platform.is_tpu(): + import torch_xla.core.xla_model as xm + from torch_xla._internal import pjrt + + +class TpuCommunicator: + + def __init__( + self, + group: ProcessGroup, + local_rank: int, + world_size: int, + ): + del group # Unused. + if not current_platform.is_tpu(): + self.disabled = True + return + self.disabled = False + + pjrt.initialize_multiprocess(local_rank, world_size) + xm._init_world_size_ordinal() + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return xm.all_reduce(xm.REDUCE_SUM, x) + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." 
+ return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2be65470d1346..edc6207dee866 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -34,10 +34,6 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.platforms import current_platform - -if current_platform.is_tpu(): - import torch_xla.core.xla_model as xm @dataclass @@ -129,7 +125,6 @@ class GroupCoordinator: pynccl_comm: Optional[Any] # PyNccl communicator ca_comm: Optional[Any] # Custom allreduce communicator mq_broadcaster: Optional[Any] # shared memory broadcaster - use_xla: bool # Whether to use PyTorch XLA communicator def __init__( self, @@ -138,6 +133,7 @@ def __init__( torch_distributed_backend: Union[str, Backend], use_pynccl: bool, use_custom_allreduce: bool, + use_tpu_communicator: bool, use_message_queue_broadcaster: bool = False, ): @@ -145,7 +141,6 @@ def __init__( self.local_rank = local_rank self.device_group = None self.cpu_group = None - self.use_xla = current_platform.is_tpu() for ranks in group_ranks: device_group = torch.distributed.new_group( @@ -170,6 +165,7 @@ def __init__( self.use_pynccl = use_pynccl self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( @@ -196,6 +192,16 @@ def __init__( else: self.ca_comm = None + from vllm.distributed.device_communicators.tpu_communicator import ( + TpuCommunicator) + self.tpu_communicator: Optional[TpuCommunicator] + if use_tpu_communicator and self.world_size > 1: + self.tpu_communicator = TpuCommunicator( + group=self.cpu_group, + local_rank=local_rank, + world_size=self.world_size, + ) + from vllm.distributed.device_communicators.shm_broadcast import ( MessageQueue) self.mq_broadcaster: Optional[MessageQueue] = None @@ -296,9 +302,10 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ - # For TPUs, use xm.all_reduce. - if self.use_xla: - return xm.all_reduce(xm.REDUCE_SUM, input_) + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_reduce(input_) if ca_comm is not None: out = ca_comm.custom_all_reduce(input_) @@ -319,10 +326,10 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - # For TPUs, use xm.all_gather. - if self.use_xla: - assert dim == -1, "TPUs only support dim=-1 for all-gather." - return xm.all_gather(input_, dim) + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) if dim < 0: # Convert negative dim to positive. 
@@ -741,6 +748,7 @@ def init_world_group(ranks: List[int], local_rank: int, torch_distributed_backend=backend, use_pynccl=False, use_custom_allreduce=False, + use_tpu_communicator=False, ) @@ -759,6 +767,7 @@ def init_model_parallel_group( torch_distributed_backend=backend, use_pynccl=True, use_custom_allreduce=use_custom_allreduce, + use_tpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, ) From 80872273ede13e05a917931f898cabafb325a73f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 15:26:23 -0700 Subject: [PATCH 47/58] Fix --- vllm/distributed/device_communicators/tpu_communicator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 24963fa655287..579bd466bda44 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,5 +1,3 @@ -from typing import Union - import torch from torch.distributed import ProcessGroup From f04e1792adbf96c2a385cb3f5efa96f470587dbf Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 15:30:54 -0700 Subject: [PATCH 48/58] Address comments --- .../device_communicators/tpu_communicator.py | 11 ++++------- vllm/distributed/parallel_state.py | 6 +----- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 579bd466bda44..69a9a516f3ebe 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,4 +1,5 @@ import torch +import torch.distributed as dist from torch.distributed import ProcessGroup from vllm.platforms import current_platform @@ -10,18 +11,14 @@ class TpuCommunicator: - def __init__( - self, - group: ProcessGroup, - local_rank: int, - world_size: int, - ): - del group # Unused. 
+ def __init__(self, group: ProcessGroup): if not current_platform.is_tpu(): self.disabled = True return self.disabled = False + local_rank = dist.get_rank(group) + world_size = dist.get_world_size(group) pjrt.initialize_multiprocess(local_rank, world_size) xm._init_world_size_ordinal() diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index edc6207dee866..98c1a1fb64ec3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -196,11 +196,7 @@ def __init__( TpuCommunicator) self.tpu_communicator: Optional[TpuCommunicator] if use_tpu_communicator and self.world_size > 1: - self.tpu_communicator = TpuCommunicator( - group=self.cpu_group, - local_rank=local_rank, - world_size=self.world_size, - ) + self.tpu_communicator = TpuCommunicator(group=self.cpu_group) from vllm.distributed.device_communicators.shm_broadcast import ( MessageQueue) From f493c898c6448c10cbcb81bab22ab65a3af3921d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 22:44:49 +0000 Subject: [PATCH 49/58] Device init --- vllm/worker/tpu_worker.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 04d49e0bb1dce..c88aba7ae08cd 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -5,7 +5,6 @@ import torch_xla.core.xla_model as xm import torch_xla.experimental.dynamo_set_buffer_donor # noqa: F401 import torch_xla.runtime as xr -from torch_xla._internal import pjrt import vllm.envs as envs from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, @@ -71,15 +70,6 @@ def __init__( def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" - if self.parallel_config.world_size > 1: - # FIXME(woosuk): local_world_size should be used instead of - # parallel_config.world_size. - pjrt.initialize_multiprocess(self.local_rank, - self.parallel_config.world_size) - xm._init_world_size_ordinal() - - self.device = xm.xla_device() - self.device_config.device = self.device torch.set_grad_enabled(False) torch.set_default_dtype(self.model_config.dtype) @@ -98,6 +88,11 @@ def init_device(self) -> None: self.parallel_config.tensor_parallel_size, self.parallel_config.pipeline_parallel_size) + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + # Set random seed. set_random_seed(self.model_config.seed) xm.set_rng_state(self.model_config.seed, self.device) From f14b085345c6efd1b93a1791f7377a41c8e7da95 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 22:45:47 +0000 Subject: [PATCH 50/58] Fix patch --- vllm/worker/tpu_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7f401f6f39d48..e5bb101fc7df4 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -126,10 +126,11 @@ def load_model(self) -> None: # rank-agnostic. However, it matters for all-gather as the ranks # determine the order of concatenating the output tensors. # As a workaround, we use the xm's rank assignment only when loading - # the weights. + # the embedding weights. xm_tp_rank = xm.get_ordinal() with patch( - "vllm.distributed.parallel_state.get_tensor_model_parallel_rank", + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model( model_config=self.model_config, From f9df97da3ce60231a1755fcf37af7955b4b5a7fd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 23:14:24 +0000 Subject: [PATCH 51/58] 0726 --- Dockerfile.tpu | 2 +- docs/source/getting_started/tpu-installation.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile.tpu b/Dockerfile.tpu index be7dbe63cb237..4fc14d6bd186c 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,4 +1,4 @@ -ARG NIGHTLY_DATE="20240713" +ARG NIGHTLY_DATE="20240726" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index 5e2f514a4a509..d43b7bf460561 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -56,7 +56,7 @@ First, install the dependencies: $ pip uninstall torch torch-xla -y $ # Install PyTorch and PyTorch XLA. - $ export DATE="+20240713" + $ export DATE="+20240726" $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl From 9994742cf012e1a45e1b50a82569e54bc11ecc61 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 23:14:41 +0000 Subject: [PATCH 52/58] xr --- vllm/distributed/device_communicators/tpu_communicator.py | 3 ++- vllm/worker/tpu_model_runner.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 69a9a516f3ebe..16525887cf4eb 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -6,6 +6,7 @@ if current_platform.is_tpu(): import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr from torch_xla._internal import pjrt @@ -20,7 +21,7 @@ def __init__(self, group: ProcessGroup): local_rank = dist.get_rank(group) world_size = dist.get_world_size(group) pjrt.initialize_multiprocess(local_rank, world_size) - xm._init_world_size_ordinal() + xr._init_world_size_ordinal() def all_reduce(self, x: torch.Tensor) -> torch.Tensor: return xm.all_reduce(xm.REDUCE_SUM, x) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index e5bb101fc7df4..5cccc580f0607 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, @@ -127,7 +128,7 @@ def load_model(self) -> None: # determine the order of concatenating the output tensors. # As a workaround, we use the xm's rank assignment only when loading # the embedding weights. - xm_tp_rank = xm.get_ordinal() + xm_tp_rank = xr.global_ordinal() with patch( "vllm.model_executor.layers.vocab_parallel_embedding." 
"get_tensor_model_parallel_rank", From e0d3232e5ffcca5cbda602746dc5f666521bbf54 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jul 2024 02:39:23 +0000 Subject: [PATCH 53/58] Add dynamic=True --- vllm/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 5cccc580f0607..80bfe4d2e0191 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -147,7 +147,7 @@ def load_model(self) -> None: xm.wait_device_ops() model = ModelWrapper(model) - self.model = torch.compile(model, backend="openxla", fullgraph=True) + self.model = torch.compile(model, backend="openxla", fullgraph=True, dynamic=True) def _dummy_run( self, From 2f6f54f2ed6079d00900f1040006c8d6a7b24562 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jul 2024 04:21:10 +0000 Subject: [PATCH 54/58] Remove import --- vllm/attention/backends/pallas.py | 1 - vllm/worker/tpu_worker.py | 1 - 2 files changed, 2 deletions(-) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index c53a2f91b89d7..2269ac2606e89 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -3,7 +3,6 @@ import torch import torch_xla.experimental.custom_kernel # Required to register custom ops. -import torch_xla.experimental.dynamo_set_buffer_donor from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index c88aba7ae08cd..17fa5c35457c2 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -3,7 +3,6 @@ import torch import torch_xla.core.xla_model as xm -import torch_xla.experimental.dynamo_set_buffer_donor # noqa: F401 import torch_xla.runtime as xr import vllm.envs as envs From 8bb115939fcb829efba17b748cb544a861c578d7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jul 2024 04:21:44 +0000 Subject: [PATCH 55/58] yapf --- vllm/worker/tpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 80bfe4d2e0191..2bf4f45370ad4 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -147,7 +147,10 @@ def load_model(self) -> None: xm.wait_device_ops() model = ModelWrapper(model) - self.model = torch.compile(model, backend="openxla", fullgraph=True, dynamic=True) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=True) def _dummy_run( self, From fafda57ee614734347ed8b72c74771356970f55b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jul 2024 04:45:58 +0000 Subject: [PATCH 56/58] Add comment & doc --- docs/source/getting_started/tpu-installation.rst | 7 +++++++ vllm/worker/tpu_model_runner.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index d43b7bf460561..eb46ac3ed93a9 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds: $ VLLM_TARGET_DEVICE="tpu" python setup.py develop +.. note:: + + Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape. 
+ The compilation time may take 20~30 minutes in the first run. + However, the compilation time reduces to ~5 minutes in the second run because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default). + + .. tip:: If you encounter the following error: diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2bf4f45370ad4..d6cc1d9cf8d22 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -147,6 +147,13 @@ def load_model(self) -> None: xm.wait_device_ops() model = ModelWrapper(model) + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Setting dynamic=True can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and need to be + # re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk. self.model = torch.compile(model, backend="openxla", fullgraph=True, From 79c45d5121eafb6d7ec9044735932584816480e9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jul 2024 04:46:59 +0000 Subject: [PATCH 57/58] Minor --- docs/source/getting_started/tpu-installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index eb46ac3ed93a9..2e6c522422c22 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -79,7 +79,7 @@ Next, build vLLM from source. This will only take a few seconds: Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape. The compilation time may take 20~30 minutes in the first run. - However, the compilation time reduces to ~5 minutes in the second run because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default). + However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default). .. tip:: From 4f0a23ca9bd0ba13551f101b4bfda47287bd7a4b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 27 Jul 2024 05:00:43 +0000 Subject: [PATCH 58/58] Minor --- vllm/worker/tpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index d6cc1d9cf8d22..1692094af8c41 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -150,10 +150,10 @@ def load_model(self) -> None: # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Setting dynamic=True can reduce the torch.compile # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and need to be - # re-compiled for every different shapes. This overhead is inevitable + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk. + # graphs in the disk (VLLM_XLA_CACHE_PATH). self.model = torch.compile(model, backend="openxla", fullgraph=True,
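
Taken together, the last few patches compile the wrapped model once with the openxla backend and rely on input padding plus XLA's on-disk graph cache (VLLM_XLA_CACHE_PATH) to keep recompilation bounded, as the installation note above explains. The sketch below shows that pattern end to end; it assumes torch_xla and a TPU runtime are available, and the power-of-two `bucket()` helper is only an illustration of "bucketizes the possible input shapes", not vLLM's actual padding policy:

    import torch
    import torch_xla.core.xla_model as xm

    def bucket(n: int, minimum: int = 16) -> int:
        # Round a batch dimension up to the next power of two so only a small,
        # fixed set of shapes (and hence XLA graphs) ever reaches the compiler.
        size = minimum
        while size < n:
            size *= 2
        return size

    class TinyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = torch.nn.Linear(128, 128)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    device = xm.xla_device()
    model = TinyModel().to(device)
    # Same flags as the patch: dynamic=True lets torch.compile reuse one FX
    # graph across shapes, while XLA still compiles one graph per distinct
    # padded shape.
    compiled = torch.compile(model, backend="openxla", fullgraph=True,
                             dynamic=True)

    for batch_size in (3, 20, 33):
        padded = bucket(batch_size)            # 16, 32, 64 -> at most 3 graphs
        x = torch.zeros(padded, 128, device=device)
        out = compiled(x)[:batch_size]         # drop the padding afterwards
        xm.mark_step()                         # materialize the lazy XLA graph
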