From 61593f1b884ee5553bded692572b2f6565f711fb Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 28 Aug 2024 08:30:05 +0000
Subject: [PATCH 1/2] [TPU] Implement async output processing for TPU

---
 vllm/config.py                  | 2 +-
 vllm/worker/tpu_model_runner.py | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 4e014e43d849a..c1efd05bf2946 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -341,7 +341,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return
 
-        if device_config.device_type != "cuda":
+        if device_config.device_type not in {"cuda", "tpu"}:
             logger.warning(
                 "Async output processing is only supported for CUDA."
                 " Disabling it for other platforms.")
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 01daa64b5a32f..de1f24fa7b52b 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -1,6 +1,7 @@
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Type, Union)
 from unittest.mock import patch
 
 import numpy as np
@@ -50,6 +51,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
     best_of: List[int]
     seq_groups: List[List[int]]
     virtual_engine: int = 0
+    async_callback: Optional[Callable] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -558,6 +560,8 @@ def _execute_model(*args):
                     model_input.attn_metadata, model_input.input_lens[i:i + 1],
                     model_input.t[i:i + 1], model_input.p[i:i + 1],
                     model_input.num_samples, kv_caches)
+                if i == 0 and model_input.async_callback is not None:
+                    model_input.async_callback()
                 # Retrieve the outputs to CPU.
                 next_token_ids += output_token_ids.cpu().tolist()
                 start_idx = end_idx
@@ -568,6 +572,8 @@ def _execute_model(*args):
                 model_input.attn_metadata, model_input.input_lens,
                 model_input.t, model_input.p, model_input.num_samples,
                 kv_caches)
+            if model_input.async_callback is not None:
+                model_input.async_callback()
             # Retrieve the outputs to CPU.
             next_token_ids = output_token_ids.cpu().tolist()
 

From c11eccfaad175196840764a49166c1620f256f78 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 30 Aug 2024 02:24:10 +0000
Subject: [PATCH 2/2] Fix

---
 vllm/config.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index ce1bf5b88b215..7e0b75eceae5b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -347,10 +347,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return
 
-        if device_config.device_type not in {"cuda", "tpu"}:
+        if device_config.device_type not in ("cuda", "tpu"):
             logger.warning(
-                "Async output processing is only supported for CUDA."
-                " Disabling it for other platforms.")
+                "Async output processing is only supported for CUDA or TPU. "
+                "Disabling it for other platforms.")
             self.use_async_output_proc = False
             return
 
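
For illustration only, a minimal runnable sketch of the callback pattern the
hunks above wire in: the callback fires once, right after the first execution
step is issued, so host-side output processing can overlap with the remaining
device work. ToyModelInput and run_steps are hypothetical stand-ins for this
sketch, not vLLM APIs.

from typing import Callable, List, Optional


class ToyModelInput:
    """Mirrors ModelInputForTPU's optional async_callback field (sketch)."""

    def __init__(self, async_callback: Optional[Callable] = None) -> None:
        self.async_callback = async_callback


def run_steps(model_input: ToyModelInput, num_steps: int) -> List[int]:
    tokens: List[int] = []
    for i in range(num_steps):
        tokens.append(i)  # stand-in for one device execution step
        # As in the patch: invoke the callback once, after the first step
        # is issued, rather than waiting for all steps to finish.
        if i == 0 and model_input.async_callback is not None:
            model_input.async_callback()
    return tokens


fired: List[bool] = []
print(run_steps(ToyModelInput(async_callback=lambda: fired.append(True)), 4))
print(fired)  # [True] -- the callback ran exactly once, after step 0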