From 5d42800ddd1882564a1edc39ffb3142fd525dbbb Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 24 Apr 2024 08:59:54 +0000 Subject: [PATCH 01/43] Combine prompt inputs --- benchmarks/benchmark_latency.py | 4 +- examples/llava_example.py | 10 +- tests/conftest.py | 18 +- tests/engine/test_skip_tokenizer_init.py | 2 +- tests/test_sequence.py | 10 +- tests/tokenization/test_detokenize.py | 7 +- vllm/engine/async_llm_engine.py | 108 ++++++------ vllm/engine/llm_engine.py | 161 +++++++++++------- vllm/entrypoints/llm.py | 66 ++----- vllm/entrypoints/openai/serving_chat.py | 12 +- vllm/entrypoints/openai/serving_completion.py | 17 +- vllm/inputs.py | 48 ++++++ vllm/outputs.py | 2 +- vllm/sequence.py | 38 +++-- 14 files changed, 295 insertions(+), 208 deletions(-) create mode 100644 vllm/inputs.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 44da3bad8d840..8b376a379c450 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -55,13 +55,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/examples/llava_example.py b/examples/llava_example.py index 3d22b492654bf..31853d7e6ae79 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -25,9 +25,13 @@ def run_llava_pixel_values(): # This should be provided by another online or offline component. 
images = torch.load("images/stop_sign_pixel_values.pt") - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=images), + }) + for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/tests/conftest.py b/tests/conftest.py index 5c50fc2d1bab6..653bc69b174c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.inputs import PromptInputs from vllm.sequence import MultiModalData from vllm.transformers_utils.tokenizer import get_tokenizer @@ -320,12 +321,17 @@ def generate( ) -> List[Tuple[List[int], str]]: if images is not None: assert len(prompts) == images.shape[0] - req_outputs = self.model.generate( - prompts, - sampling_params=sampling_params, - multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) - if images is not None else None) + + prompt_inputs: List[PromptInputs] = [{ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=images) + if images is not None else None + } for prompt in prompts] + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index baa463a316902..338b208723ba9 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str): with pytest.raises(ValueError) as err: llm.generate("abc", sampling_params) assert "prompts must be None if" in str(err.value) - outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 completions = outputs[0].outputs diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b16bdc141e57c..655fb388d95c2 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -24,7 +24,15 @@ def create_dummy_prompt( # and prompt "0 ... block_size". 
prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + prompt = Sequence( + int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size, + ) seq_group = SequenceGroup( request_id, [prompt], SamplingParams(use_beam_search=use_beam_search, best_of=best_of), diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9bc9becb2a6f1..5b43578aad1f2 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": None, + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3a2f7db679358..dd31a2c4e26fb 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -11,11 +11,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.engine.ray_utils import initialize_ray_cluster, ray +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -230,46 +230,51 @@ async def step_async(self) -> List[RequestOutput]: async def encode_request_async( self, request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = await self.tokenizer.encode_async( + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self._require_tokenizer("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = await tokenizer.encode_async( request_id=request_id, - prompt=prompt, + prompt=inputs["prompt"], lora_request=lora_request) - return prompt_token_ids + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) async def add_request_async( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = await self.encode_request_async( + + processed_inputs = await self.encode_request_async( + request_id=request_id, inputs=inputs, lora_request=lora_request) + + return self._add_request( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - return 
self.add_request(request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + processed_inputs=processed_inputs, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) async def check_health_async(self) -> None: self.model_executor.check_health() @@ -505,22 +510,26 @@ async def run_engine_loop(self): async def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: - shortened_prompt = prompt - shortened_token_ids = prompt_token_ids - if self.max_log_len is not None: + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + else: + shortened_prompt = inputs.get("prompt") + shortened_token_ids = inputs.get("prompt_token_ids") + + max_log_len = self.max_log_len + if max_log_len is not None: if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:self.max_log_len] + shortened_prompt = shortened_prompt[:max_log_len] if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:self. - max_log_len] + shortened_token_ids = shortened_token_ids[:max_log_len] + logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " f"sampling_params: {sampling_params}, " @@ -541,39 +550,32 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await ( - self.engine.encode_request_async.remote( # type: ignore - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request)) + processed_inputs = await self.engine.encode_request_async.remote( # type: ignore + request_id=request_id, + inputs=inputs, + lora_request=lora_request) else: - prompt_token_ids = await self.engine.encode_request_async( + processed_inputs = await self.engine.encode_request_async( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, + inputs=inputs, lora_request=lora_request) stream = self._request_tracker.add_request( request_id, - prompt=prompt, + inputs=processed_inputs, sampling_params=sampling_params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) return stream async def generate( self, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -582,14 +584,10 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. 
Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -644,12 +642,10 @@ async def generate( try: stream = await self.add_request( request_id, - prompt, + inputs, sampling_params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) async for request_output in stream: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 19e58fb1722cf..b981dc84cf164 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -17,12 +17,12 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceStage) +from vllm.sequence import SamplerOutput, Sequence, SequenceGroup, SequenceStage from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -78,6 +78,7 @@ class LLMEngine: log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection """ + tokenizer: Optional[BaseTokenizerGroup] def __init__( self, @@ -134,9 +135,8 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - self.tokenizer: BaseTokenizerGroup - self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) + tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(tokenizer) else: self.detokenizer = None self.tokenizer = None @@ -287,12 +287,23 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") + def _require_tokenizer(self, fail_msg: Optional[str] = None): + if self.tokenizer is None: + if fail_msg is None: + fail_msg = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(None) + return self._require_tokenizer().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + return self._require_tokenizer().get_lora_tokenizer( + sequence.lora_request) def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( @@ -304,9 +315,12 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) + self.tokenizer = get_tokenizer_group( self.parallel_config.tokenizer_pool_config, **init_kwargs) + return self.tokenizer + def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -315,29 +329,81 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) + def _add_request( + self, + request_id: str, + processed_inputs: LLMInputs, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> None: + max_logprobs 
= self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = None + if self.tokenizer: + eos_token_id = self.tokenizer.get_lora_tokenizer( + lora_request).eos_token_id + else: + logger.warning("Use None for EOS token id because tokenizer is " + "not initialized") + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + # inject the eos token id into the sampling_params to support min_tokens + # processing + sampling_params.eos_token_id = seq.eos_token_id + sampling_params.update_from_generation_config( + self.generation_config_fields) + + # Create the sequence group. + seq_group = SequenceGroup(request_id, [seq], sampling_params, + arrival_time, lora_request) + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + def encode_request( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - return prompt_token_ids + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self._require_tokenizer("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: """Add a request to the engine's request pool. @@ -347,14 +413,10 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. sampling_params: The sampling parameters for text generation. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - multi_modal_data: Multi modal data per request. Details: - Set arrival_time to the current time if it is None. 
@@ -383,49 +445,20 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = self.encode_request( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = None - if self.tokenizer: - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - else: - logger.warning("Use None for EOS token id because tokenizer is " - "not initialized") - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - eos_token_id, lora_request) - - # Defensive copy of SamplingParams, which are used by the sampler, - # this doesn't deep-copy LogitsProcessor objects - sampling_params = sampling_params.clone() - # inject the eos token id into the sampling_params to support min_tokens - # processing - sampling_params.eos_token_id = seq.eos_token_id - sampling_params.update_from_generation_config( - self.generation_config_fields) - - # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, multi_modal_data) + processed_inputs = self.encode_request(request_id=request_id, + inputs=inputs, + lora_request=lora_request) - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + return self._add_request( + request_id=request_id, + processed_inputs=processed_inputs, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b022707794a78..b31b28a15fa4a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,15 +1,14 @@ -from typing import List, Optional, Union +from typing import List, Optional, Sequence, Union -import torch from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.inputs import PromptStrictInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter @@ -131,13 +130,11 @@ def set_tokenizer( def generate( self, - prompts: Optional[Union[str, List[str]]] = None, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, - prompt_token_ids: Optional[List[List[int]]] = None, + Sequence[SamplingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -146,42 +143,24 @@ def generate( into a single list and pass it to this method. 
Args: - prompts: A list of prompts to generate completions for. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - if self.llm_engine.model_config.skip_tokenizer_init \ - and prompts is not None: - raise ValueError("prompts must be None if skip_tokenizer_init " - "is True") - if isinstance(prompts, str): + if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + inputs = [inputs] - if prompts is not None: - num_requests = len(prompts) - else: - assert prompt_token_ids is not None - num_requests = len(prompt_token_ids) + num_requests = len(inputs) if sampling_params is None: # Use default sampling params. @@ -191,43 +170,28 @@ def generate( list) and len(sampling_params) != num_requests: raise ValueError("The lengths of prompts and sampling_params " "must be the same.") - if multi_modal_data: - multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - for i in range(num_requests): - prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] + for i, request_inputs in enumerate(inputs): self._add_request( - prompt, + request_inputs, sampling_params[i] - if isinstance(sampling_params, list) else sampling_params, - token_ids, + if isinstance(sampling_params, Sequence) else sampling_params, lora_request=lora_request, - # Get ith image while maintaining the batch dim. - multi_modal_data=MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0)) - if multi_modal_data else None, ) return self._run_engine(use_tqdm) def _add_request( self, - prompt: Optional[str], + inputs: PromptStrictInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, - prompt, + inputs, sampling_params, - prompt_token_ids, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. @@ -251,4 +215,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # This is necessary because some requests may be finished earlier than # its previous requests. 
outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs \ No newline at end of file + return outputs diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2ff335eb71073..5eb7ad51b64a3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -83,9 +83,15 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt_text, sampling_params, - request_id, prompt_ids, - lora_request) + result_generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + request_id, + lora_request, + ) # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 211b2e0424c3e..5786170e2f2a5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -117,12 +117,17 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - sampling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids, - lora_request=lora_request)) + generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + f"{request_id}-{i}", + lora_request=lora_request, + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/inputs.py b/vllm/inputs.py new file mode 100644 index 0000000000000..bd61f959eeb6e --- /dev/null +++ b/vllm/inputs.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING, List, Optional, TypedDict, Union + +if TYPE_CHECKING: + from vllm.sequence import MultiModalData + + +class MultiModalPrompt(TypedDict, total=False): + multi_modal_data: Optional["MultiModalData"] + """Multi modal data.""" + + +class StringPrompt(MultiModalPrompt, TypedDict): + prompt: str + """The prompt string.""" + + +class TokensPrompt(MultiModalPrompt, TypedDict): + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + +class StringTokensPrompt(MultiModalPrompt, TypedDict): + """It is assumed that :attr:`prompt` is consistent with + :attr:`prompt_token_ids`. This is currently used in + :class:`AsyncLLMEngine` for logging both the text and token IDs.""" + + prompt: str + """The prompt string.""" + + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + +PromptStrictInputs = Union[str, StringPrompt, TokensPrompt] +"""The prompt string. 
More complex inputs should be represented by +:class:`StringPrompt` or :class:`TokensPrompt`.""" + +PromptInputs = Union[str, StringPrompt, TokensPrompt, StringTokensPrompt] +"""As :const:`PromptStrictInputs` but additionally accepts +:class:`StringTokensPrompt`.""" + + +class LLMInputs(TypedDict): + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalData"] diff --git a/vllm/outputs.py b/vllm/outputs.py index d01be0eb0efd2..78b70dfe107e3 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -74,7 +74,7 @@ class RequestOutput: def __init__( self, request_id: str, - prompt: str, + prompt: Optional[str], prompt_token_ids: List[int], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], diff --git a/vllm/sequence.py b/vllm/sequence.py index b296b37a84f15..3ea3af0f7cba7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union from vllm.block import LogicalTokenBlock +from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -193,8 +194,7 @@ class Sequence: Args: seq_id: The ID of the sequence. - prompt: The prompt of the sequence. - prompt_token_ids: The token IDs of the prompt. + inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. @@ -203,25 +203,24 @@ class Sequence: def __init__( self, seq_id: int, - prompt: str, - prompt_token_ids: List[int], + inputs: LLMInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id - self.prompt = prompt + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data = SequenceData(prompt_token_ids) + self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(self.prompt_token_ids) self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None @@ -231,6 +230,18 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def prompt(self) -> Optional[str]: + return self.inputs["prompt"] + + @property + def prompt_token_ids(self) -> List[int]: + return self.inputs["prompt_token_ids"] + + @property + def multi_modal_data(self) -> Optional["MultiModalData"]: + return self.inputs["multi_modal_data"] + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -398,7 +409,6 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - multi_modal_data: Multi modal data associated with the request. 
""" def __init__( @@ -408,7 +418,6 @@ def __init__( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -421,10 +430,9 @@ def __init__( self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() - self.multi_modal_data = multi_modal_data @property - def prompt(self) -> str: + def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt @@ -433,7 +441,13 @@ def prompt(self) -> str: def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return next(iter(self.seqs_dict.values())).data.prompt_token_ids + return next(iter(self.seqs_dict.values())).prompt_token_ids + + @property + def multi_modal_data(self) -> Optional[MultiModalData]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).multi_modal_data @property def lora_int_id(self) -> int: From 5db2c5e03ed0b84cf55c26fd9c60f11e4f1bd4b0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 25 Apr 2024 01:51:08 +0000 Subject: [PATCH 02/43] Fix a bunch of tests --- benchmarks/benchmark_throughput.py | 6 +----- tests/conftest.py | 13 ++++++++----- tests/core/test_block_manager.py | 15 ++++++++++++--- tests/core/utils.py | 15 ++++++++++++--- tests/samplers/test_logits_processor.py | 9 +++------ tests/samplers/test_seeded_generate.py | 6 +----- tests/test_cache_block_hashing.py | 11 +++++++++-- tests/tokenization/test_detokenize.py | 2 +- 8 files changed, 47 insertions(+), 30 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 6bb889d1eceba..ae05c3bf742f4 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -113,11 +113,7 @@ def run_vllm( max_tokens=output_len, ) # FIXME(woosuk): Do not use internal method. - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=sampling_params, - ) + llm._add_request(prompt, sampling_params=sampling_params) start = time.perf_counter() # FIXME(woosuk): Do not use internal method. diff --git a/tests/conftest.py b/tests/conftest.py index 653bc69b174c5..e187a178d10c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -322,12 +322,15 @@ def generate( if images is not None: assert len(prompts) == images.shape[0] + if images is None: + mm_data = None + else: + mm_data = MultiModalData(type=MultiModalData.Type.IMAGE, + data=images) + prompt_inputs: List[PromptInputs] = [{ - "prompt": - prompt, - "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=images) - if images is not None else None + "prompt": prompt, + "multi_modal_data": mm_data } for prompt in prompts] req_outputs = self.model.generate(prompt_inputs, diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62984ef4caabb..62da6c4850f65 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -132,8 +132,11 @@ def test_append_slot_cow(): # Allocate prompt to gpu block. There is one slot left in the block. 
prompt = Sequence(seq_id=1, - prompt="one two three", - prompt_token_ids=[1, 2, 3], + inputs={ + "prompt": "one two three", + "prompt_token_ids": [1, 2, 3], + "multi_modal_data": None + }, block_size=block_size) # Fork the sequence, such that a COW will be required when we append a new @@ -298,7 +301,13 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - parent = Sequence(1, "one two three", [0, 1, 2], block_size) + parent = Sequence(seq_id=1, + inputs={ + "prompt": "one two three", + "prompt_token_ids": [0, 1, 2], + "multi_modal_data": None + }, + block_size=block_size) seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(), None) block_manager.allocate(seq_group) diff --git a/tests/core/utils.py b/tests/core/utils.py index 22c1d3826dff4..fd0edce6524f5 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -21,7 +21,13 @@ def create_dummy_prompt( # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + prompt = Sequence(int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) seq_group = SequenceGroup( request_id, [prompt], SamplingParams(use_beam_search=use_beam_search, best_of=best_of), @@ -48,8 +54,11 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 3788e9e9752ff..8c877265e71a0 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -35,26 +35,23 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[0], + example_prompts[0], sampling_params=params_with_logprobs, - prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[1], + example_prompts[1], sampling_params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), - prompt_token_ids=None, ) # test grouped requests vllm_model.model._add_request( - prompt=example_prompts[2], + example_prompts[2], sampling_params=SamplingParams(max_tokens=max_tokens), - prompt_token_ids=None, ) outputs = vllm_model.model._run_engine(False) diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index 3cd659cef58da..ba8070cd16dbc 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -57,11 +57,7 @@ def test_random_sample_with_seed( sampling_params_seed_1, sampling_params_seed_2, ): - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=params, - ) + llm._add_request(prompt, sampling_params=params) results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3b257ac062f56..97864af88e40a 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, 
max_num_seqs: int, for prompt in prompts: hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - tokenizer.tokenizer.eos_token_id, lora_request) + seq = Sequence(seq_id, + inputs={ + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=block_size, + eos_token_id=tokenizer.tokenizer.eos_token_id, + lora_request=lora_request) num_blocks = len(prompt_token_ids) // block_size for idx in range(num_blocks): diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 5b43578aad1f2..1d4c74d6bd8da 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -124,7 +124,7 @@ def create_sequence(prompt_token_ids=None): return Sequence( seq_id=0, inputs={ - "prompt": None, + "prompt": "", "prompt_token_ids": prompt_token_ids, "multi_modal_data": None, }, From 74c5905d96707ef32a7846e32ceb0bfa4f27c1fa Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 25 Apr 2024 05:13:45 +0000 Subject: [PATCH 03/43] Fix LLaVA test --- examples/llava_example.py | 15 +++++++++------ tests/conftest.py | 20 +++++++++++--------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/examples/llava_example.py b/examples/llava_example.py index 31853d7e6ae79..60250c4303fbf 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -23,13 +23,13 @@ def run_llava_pixel_values(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_pixel_values.pt") + image = torch.load("images/stop_sign_pixel_values.pt") outputs = llm.generate({ "prompt": prompt, "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=images), + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), }) for o in outputs: @@ -50,11 +50,14 @@ def run_llava_image_features(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. 
- images = torch.load("images/stop_sign_image_features.pt") + image = torch.load("images/stop_sign_image_features.pt") - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/tests/conftest.py b/tests/conftest.py index e187a178d10c1..b86e67c6c4da7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -322,16 +322,18 @@ def generate( if images is not None: assert len(prompts) == images.shape[0] - if images is None: - mm_data = None - else: - mm_data = MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) + prompt_inputs: List[PromptInputs] = [] + for i, prompt in enumerate(prompts): + image = None if images is None else images[i:i + 1] + mm_data = None if image is None else MultiModalData( + type=MultiModalData.Type.IMAGE, + data=image, + ) - prompt_inputs: List[PromptInputs] = [{ - "prompt": prompt, - "multi_modal_data": mm_data - } for prompt in prompts] + prompt_inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) req_outputs = self.model.generate(prompt_inputs, sampling_params=sampling_params) From b49aba766d0c349e2bf2a2c186200ecfad751ce7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 25 Apr 2024 05:19:57 +0000 Subject: [PATCH 04/43] Fix `benchmark_latency` test --- benchmarks/benchmark_latency.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8b376a379c450..8932788cac119 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -2,13 +2,14 @@ import argparse import time from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.inputs import PromptStrictInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -44,7 +45,9 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() + dummy_inputs: List[PromptStrictInputs] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -55,13 +58,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() From c4f35401a58ab4e400131f2165fcb8474e63b459 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 03:42:47 +0000 Subject: [PATCH 05/43] Clarify tokenizer usage --- vllm/engine/async_llm_engine.py | 4 ++-- vllm/engine/llm_engine.py | 29 +++++++++++++++-------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1d52e9265ff48..332dfd64ba37d 100644 --- 
a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -239,8 +239,8 @@ async def encode_request_async( inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: - tokenizer = self._require_tokenizer("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") prompt_token_ids = await tokenizer.encode_async( request_id=request_id, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c9ae99c4a78c..622b6d2819695 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -145,6 +145,8 @@ def __init__( self.decoding_config = decoding_config or DecodingConfig() self.log_stats = log_stats + self.tokenizer: Optional[BaseTokenizerGroup] + if not self.model_config.skip_tokenizer_init: tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(tokenizer) @@ -301,28 +303,27 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") - def _require_tokenizer(self, fail_msg: Optional[str] = None): - if self.tokenizer is None: - if fail_msg is None: - fail_msg = ("Unable to get tokenizer because " - "skip_tokenizer_init is True") - - raise ValueError(fail_msg) - - return self.tokenizer - def __del__(self): # Shutdown model executor when engine is garbage collected # Use getattr since __init__ can fail before the field is set if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() + MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + def get_tokenizer_group(self, fail_msg: str = MISSING_TOKENIZER_GROUP_MSG): + if self.tokenizer is None: + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self._require_tokenizer().get_lora_tokenizer(None) + return self.get_tokenizer_group().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self._require_tokenizer().get_lora_tokenizer( + return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) def _init_tokenizer(self, **tokenizer_init_kwargs): @@ -405,8 +406,8 @@ def encode_request( inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: - tokenizer = self._require_tokenizer("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], From ab8182ce45c0897b22da7fa23244eb00bdddfc92 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 03:55:21 +0000 Subject: [PATCH 06/43] Rename `encode_request -> process_model_inputs` --- tests/async_engine/test_async_llm_engine.py | 2 +- vllm/engine/async_llm_engine.py | 17 +++++++++-------- vllm/engine/llm_engine.py | 12 ++++++------ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index b69cdc0a21409..10a46422887e3 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,7 +25,7 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] - async def encode_request_async(self, *args, **kwargs): + async def process_model_inputs_async(self, *args, **kwargs): pass def 
generate(self, request_id): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 332dfd64ba37d..7d0ee745813d9 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -229,7 +229,7 @@ async def step_async(self) -> List[RequestOutput]: return request_outputs - async def encode_request_async( + async def process_model_inputs_async( self, request_id: str, # pylint: disable=unused-argument inputs: PromptInputs, @@ -267,10 +267,10 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - processed_inputs = await self.encode_request_async( + processed_inputs = await self.process_model_inputs_async( request_id=request_id, inputs=inputs, lora_request=lora_request) - return self._add_request( + return self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, sampling_params=sampling_params, @@ -552,12 +552,13 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - processed_inputs = await self.engine.encode_request_async.remote( # type: ignore - request_id=request_id, - inputs=inputs, - lora_request=lora_request) + processed_inputs = await self.engine.process_model_inputs_async \ + .remote( # type: ignore + request_id=request_id, + inputs=inputs, + lora_request=lora_request) else: - processed_inputs = await self.engine.encode_request_async( + processed_inputs = await self.engine.process_model_inputs_async( request_id=request_id, inputs=inputs, lora_request=lora_request) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 622b6d2819695..6a937347e0bec 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -350,7 +350,7 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) - def _add_request( + def _add_processed_request( self, request_id: str, processed_inputs: LLMInputs, @@ -396,7 +396,7 @@ def _add_request( # Add the sequence group to the scheduler. 
self.scheduler.add_seq_group(seq_group) - def encode_request( + def process_model_inputs( self, request_id: str, inputs: PromptInputs, @@ -470,11 +470,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.encode_request(request_id=request_id, - inputs=inputs, - lora_request=lora_request) + processed_inputs = self.process_model_inputs(request_id=request_id, + inputs=inputs, + lora_request=lora_request) - return self._add_request( + return self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, sampling_params=sampling_params, From eac33e1f7d1747dd4147423d45ef3c9d65f14b95 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 06:35:09 +0000 Subject: [PATCH 07/43] Support old API in `LLM.generate` --- .buildkite/test-pipeline.yaml | 1 + tests/test_inputs.py | 53 ++++++++++ vllm/entrypoints/llm.py | 183 +++++++++++++++++++++++++++++++++- vllm/inputs.py | 62 +++++++++++- vllm/utils.py | 28 +++++- 5 files changed, 321 insertions(+), 6 deletions(-) create mode 100644 tests/test_inputs.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e49a5650c44ea..60b57ea7a5f78 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -53,6 +53,7 @@ steps: - label: Entrypoints Test commands: + - pytest -v -s test_inputs.py # these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - pytest -v -s entrypoints/test_server_oot_registration.py diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000000000..887c7101decda --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,53 @@ +from typing import List + +import pytest + +from vllm.inputs import parse_and_batch_prompt + +STRING_INPUTS = [ + '', + 'foo', + 'foo bar', + 'foo baz bar', + 'foo bar qux baz', +] + +TOKEN_INPUTS = [ + [-1], + [1], + [1, 2], + [1, 3, 4], + [1, 2, 4, 3], +] + +INPUTS_SLICES = [ + slice(None, None, -1), + slice(None, None, 2), + slice(None, None, -2), +] + + +def test_parse_single_batch_empty(): + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([]) + + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([[]]) + + +@pytest.mark.parametrize('string_input', STRING_INPUTS) +def test_parse_single_batch_string_consistent(string_input: str): + assert parse_and_batch_prompt(string_input) \ + == parse_and_batch_prompt([string_input]) + + +@pytest.mark.parametrize('token_input', TOKEN_INPUTS) +def test_parse_single_batch_token_consistent(token_input: List[int]): + assert parse_and_batch_prompt(token_input) \ + == parse_and_batch_prompt([token_input]) + + +@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) +def test_parse_single_batch_string_slice(inputs_slice: slice): + assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ + == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b31b28a15fa4a..90b0423698ac6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,16 +1,18 @@ -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, overload from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import 
PromptStrictInputs +from vllm.inputs import (PromptInputs, PromptStrictInputs, + parse_and_batch_prompt) from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter +from vllm.utils import Counter, deprecate_kwargs class LLM: @@ -128,13 +130,96 @@ def set_tokenizer( ) -> None: self.llm_engine.tokenizer.tokenizer = tokenizer + @overload # DEPRECATED: single (prompt + optional token ids) + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: multi (prompt + optional token ids) + def generate( + self, + prompts: List[str], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: single (token ids + optional prompt) + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: multi (token ids + optional prompt) + def generate( + self, + prompts: Optional[List[str]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: single or multi token ids [pos-only] + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload def generate( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, + ) -> List[RequestOutput]: + ... + + @deprecate_kwargs('prompts', 'prompt_token_ids', 'multi_modal_data') + def generate( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -156,6 +241,96 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. 
""" + if prompt_token_ids is not None or multi_modal_data is not None: + return self._generate_v1( + prompts=prompts, # type: ignore + sampling_params=sampling_params, + prompt_token_ids=prompt_token_ids, + use_tqdm=use_tqdm, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + ) + + return self._generate_v2( + inputs=prompts, # type: ignore + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + + # DEPRECATED + def _generate_v1( + self, + prompts: Optional[Union[str, List[str]]], + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]], + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + use_tqdm: bool, + lora_request: Optional[LoRARequest], + multi_modal_data: Optional[MultiModalData], + ) -> List[RequestOutput]: + # skip_tokenizer_init is now checked in engine + + if prompts is not None: + prompts = [p["text"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["text"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + + num_requests = None + if prompts is not None: + num_requests = len(prompts) + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + inputs: List[PromptInputs] = [] + for i in range(num_requests): + if prompts is not None: + if prompt_token_ids is not None: + inputs.append({ + "prompt": prompts[i], + "prompt_token_ids": prompt_token_ids[i], + "multi_modal_data": multi_modal_data, + }) + else: + inputs.append({ + "prompt": prompts[i], + "multi_modal_data": multi_modal_data, + }) + else: + if prompt_token_ids is not None: + inputs.append({ + "prompt_token_ids": prompt_token_ids[i], + "multi_modal_data": multi_modal_data, + }) + else: + raise AssertionError + + # sampling_params is now checked in _generate_v2 + return self._generate_v2( + inputs, + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + + def _generate_v2( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]], + use_tqdm: bool, + lora_request: Optional[LoRARequest], + ) -> List[RequestOutput]: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. inputs = [inputs] @@ -183,7 +358,7 @@ def generate( def _add_request( self, - inputs: PromptStrictInputs, + inputs: PromptInputs, sampling_params: SamplingParams, lora_request: Optional[LoRARequest] = None, ) -> None: diff --git a/vllm/inputs.py b/vllm/inputs.py index bd61f959eeb6e..2b5ea1c0f3828 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -1,9 +1,69 @@ -from typing import TYPE_CHECKING, List, Optional, TypedDict, Union +from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, + TypedDict, Union, cast, overload) if TYPE_CHECKING: from vllm.sequence import MultiModalData +class ParsedString(TypedDict): + text: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + text: List[int] + is_tokens: Literal[True] + + +# https://github.com/vllm-project/vllm/pull/4028 +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedString]: + ... 
+ + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedString], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedString(text=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0], str): + # case 2: array of strings + return [ + ParsedString(text=elem, is_tokens=False) + for elem in cast(List[str], prompt) + ] + if isinstance(prompt[0], int): + # case 3: array of tokens + elem = cast(List[int], prompt) + return [ParsedTokens(text=elem, is_tokens=True)] + if isinstance(prompt[0], list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0][0], int): + # case 4: array of token arrays + return [ + ParsedTokens(text=elem, is_tokens=True) + for elem in cast(List[List[int]], prompt) + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + class MultiModalPrompt(TypedDict, total=False): multi_modal_data: Optional["MultiModalData"] """Multi modal data.""" diff --git a/vllm/utils.py b/vllm/utils.py index ce55253ce2199..784b9bb29db8e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -11,7 +11,7 @@ import uuid import warnings from collections import defaultdict -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Tuple, TypeVar, @@ -638,3 +638,29 @@ def enable_trace_function_call_for_thread() -> None: filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs(*kws: str) -> Callable[[F], F]: + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + deprecated_kws = {k for k in kwargs if k in kws} + if deprecated_kws: + warnings.warn( + DeprecationWarning( + f"The keyword arguments {deprecated_kws}" + " are deprecated and will be removed in " + "a future update."), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper From 703d318dea3479cd5f1edc8ecbe8beec2860cfa0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 15:16:32 +0000 Subject: [PATCH 08/43] Add tests to ensure old API still works - To facilitate equality tests, `CompletionOutput` is now a dataclass --- tests/entrypoints/__init__.py | 0 tests/entrypoints/test_llm_generate.py | 124 +++++++++++++++++++++---- vllm/outputs.py | 29 ++---- vllm/utils.py | 10 +- 4 files changed, 123 insertions(+), 40 deletions(-) create mode 100644 tests/entrypoints/__init__.py diff --git a/tests/entrypoints/__init__.py b/tests/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 5e8b7ca4d9977..42bd6ef39c440 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -1,21 +1,113 @@ +from typing import List + import pytest -from vllm import LLM, SamplingParams +from vllm import LLM, RequestOutput, SamplingParams + +from ..conftest import 
cleanup + +MODEL_NAME = "facebook/opt-125m" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +TOKEN_IDS = [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], +] -def test_multiple_sampling_params(): - llm = LLM(model="facebook/opt-125m", +@pytest.fixture(scope="module") +def llm(): + yield LLM(model="facebook/opt-125m", max_num_batched_tokens=4096, + enforce_eager=True, tensor_parallel_size=1) - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] + cleanup() + + +def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=prompt, + sampling_params=sampling_params) + + v2_output = llm.generate(prompt, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate({"prompt": prompt}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) + + v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=PROMPTS, + sampling_params=sampling_params) + + v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate( + [{ + "prompt": p + } for p in PROMPTS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, + sampling_params=sampling_params) + + v2_output = llm.generate( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + +@pytest.mark.skip_global_cleanup +def test_multiple_sampling_params(llm: LLM): sampling_params = [ SamplingParams(temperature=0.01, top_p=0.95), SamplingParams(temperature=0.3, top_p=0.95), @@ -24,18 +116,18 @@ def test_multiple_sampling_params(): ] # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params) + assert len(PROMPTS) == len(outputs) # Exception raised, if 
the size of params does not match the size of prompts with pytest.raises(ValueError): - outputs = llm.generate(prompts, sampling_params=sampling_params[:3]) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3]) # Single SamplingParams should be applied to every prompt single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) - outputs = llm.generate(prompts, sampling_params=single_sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params) + assert len(PROMPTS) == len(outputs) # sampling_params is None, default params should be applied - outputs = llm.generate(prompts, sampling_params=None) - assert len(prompts) == len(outputs) \ No newline at end of file + outputs = llm.generate(PROMPTS, sampling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/vllm/outputs.py b/vllm/outputs.py index 78b70dfe107e3..f137c9e89a673 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from typing import List, Optional, Union from vllm.lora.request import LoRARequest @@ -6,6 +7,7 @@ SequenceGroup, SequenceStatus) +@dataclass class CompletionOutput: """The output data of one completion output of a request. @@ -24,25 +26,14 @@ class CompletionOutput: lora_request: The LoRA request that was used to generate the output. """ - def __init__( - self, - index: int, - text: str, - token_ids: List[int], - cumulative_logprob: float, - logprobs: Optional[SampleLogprobs], - finish_reason: Optional[str] = None, - stop_reason: Union[int, str, None] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.index = index - self.text = text - self.token_ids = token_ids - self.cumulative_logprob = cumulative_logprob - self.logprobs = logprobs - self.finish_reason = finish_reason - self.stop_reason = stop_reason - self.lora_request = lora_request + index: int + text: str + token_ids: List[int] + cumulative_logprob: float + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None def finished(self) -> bool: return self.finish_reason is not None diff --git a/vllm/utils.py b/vllm/utils.py index 784b9bb29db8e..a3fb89612e54b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -644,18 +644,18 @@ def enable_trace_function_call_for_thread() -> None: def deprecate_kwargs(*kws: str) -> Callable[[F], F]: + deprecated_kws = set(kws) def wrapper(fn: F) -> F: @wraps(fn) def inner(*args, **kwargs): - deprecated_kws = {k for k in kwargs if k in kws} - if deprecated_kws: + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: warnings.warn( DeprecationWarning( - f"The keyword arguments {deprecated_kws}" - " are deprecated and will be removed in " - "a future update."), + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update."), stacklevel=3, # The inner function takes up one level ) From 19d85f990bd0878cc6ad151d76bad3d282d0c674 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 17:21:02 +0000 Subject: [PATCH 09/43] Let all entrypoints tests be run at the same time --- .buildkite/test-pipeline.yaml | 4 +--- pyproject.toml | 5 +++++ tests/async_engine/test_openapi_server_ray.py | 4 ++-- tests/entrypoints/test_llm_generate.py | 5 +++-- tests/entrypoints/test_openai_server.py | 10 ++++------ tests/entrypoints/test_server_oot_registration.py | 7 ++++--- 6 files changed, 19 
insertions(+), 16 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 60b57ea7a5f78..5f569693e0af6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -54,9 +54,7 @@ steps: - label: Entrypoints Test commands: - pytest -v -s test_inputs.py - # these tests have to be separated, because each one will allocate all posible GPU memory - - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - - pytest -v -s entrypoints/test_server_oot_registration.py + - pytest -v -s entrypoints - label: Examples Test working_dir: "/vllm-workspace/examples" diff --git a/pyproject.toml b/pyproject.toml index 6a448defc16e1..ead64b7436121 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,3 +65,8 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt" [tool.isort] use_parentheses = true skip_gitignore = true + +[tool.pytest.ini_options] +markers = [ + "skip_global_cleanup" +] diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 4b97af88012b9..2a754f5c4ccab 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -55,7 +55,7 @@ def __del__(self): self.proc.terminate() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(): ray.init() server_runner = ServerRunner.remote([ @@ -74,7 +74,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 42bd6ef39c440..fe3d3fdf9a93d 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -27,8 +27,9 @@ def llm(): yield LLM(model="facebook/opt-125m", max_num_batched_tokens=4096, - enforce_eager=True, - tensor_parallel_size=1) + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) cleanup() diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1323dba469117..60411c9e767d1 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -121,7 +121,7 @@ def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(zephyr_lora_files): ray.init() server_runner = ServerRunner.remote([ @@ -133,6 +133,8 @@ def server(zephyr_lora_files): "--max-model-len", "8192", "--enforce-eager", + "--gpu-memory-utilization", + "0.75", # lora config below "--enable-lora", "--lora-modules", @@ -150,7 +152,7 @@ def server(zephyr_lora_files): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", @@ -888,7 +890,3 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): assert ("greater_than_equal" in exc_info.value.message or "less_than_equal" in exc_info.value.message) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 22e65bf7e7da1..dd43a5cf0a248 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -26,15 +26,16 @@ def server_function(port): # register our dummy model 
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ - ("--model facebook/opt-125m --dtype" - f" float32 --api-key token-abc123 --port {port}").split() + ("--model facebook/opt-125m --gpu-memory-utilization 0.10 " + f"--dtype float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') def test_oot_registration_for_api_server(): port = get_open_port() - server = multiprocessing.Process(target=server_function, args=(port, )) + ctx = multiprocessing.get_context("spawn") + server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( base_url=f"http://localhost:{port}/v1", From 5759dfa619f0310b5ccb2957a2d259a90254f87e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:01:39 +0000 Subject: [PATCH 10/43] Add tests for LLM.encode and fix corresponding bugs --- tests/entrypoints/test_llm_encode.py | 137 +++++++++++++++++++++++++ tests/entrypoints/test_llm_generate.py | 2 +- vllm/outputs.py | 9 +- 3 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 tests/entrypoints/test_llm_encode.py diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py new file mode 100644 index 0000000000000..b6ae4bc498dc6 --- /dev/null +++ b/tests/entrypoints/test_llm_encode.py @@ -0,0 +1,137 @@ +from typing import List + +import pytest + +from vllm import LLM, EmbeddingRequestOutput, PoolingParams + +from ..conftest import cleanup + +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +TOKEN_IDS = [ + # Using ID={0, 1, 2, 3} results in NaN values, + # so we add this offset of 1000 + [1000], + [1000, 1001], + [1000, 1002, 1001], + [1000, 1003, 1001, 1002], +] + + +@pytest.fixture(scope="module") +def llm(): + yield LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.90, + enforce_eager=True) + + cleanup() + + +def assert_outputs_equal(o1: List[EmbeddingRequestOutput], + o2: List[EmbeddingRequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=prompt, + pooling_params=pooling_params) + + v2_output = llm.encode(prompt, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode({"prompt": prompt}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=prompt_token_ids, + pooling_params=pooling_params) + + v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, 
match="'prompts'"): + v1_output = llm.encode(prompts=PROMPTS, + pooling_params=pooling_params) + + v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode( + [{ + "prompt": p + } for p in PROMPTS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, + pooling_params=pooling_params) + + v2_output = llm.encode( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_multiple_pooling_params(llm: LLM): + pooling_params = [ + PoolingParams(), + PoolingParams(), + PoolingParams(), + PoolingParams(), + ] + + # Multiple PoolingParams should be matched with each prompt + outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + + # Single PoolingParams should be applied to every prompt + single_pooling_params = PoolingParams() + outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + assert len(PROMPTS) == len(outputs) + + # pooling_params is None, default params should be applied + outputs = llm.encode(PROMPTS, pooling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index fe3d3fdf9a93d..8ee08f8e83961 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -25,7 +25,7 @@ @pytest.fixture(scope="module") def llm(): - yield LLM(model="facebook/opt-125m", + yield LLM(model=MODEL_NAME, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, diff --git a/vllm/outputs.py b/vllm/outputs.py index 8bf3e236d532d..49f526b5f9300 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -48,6 +48,7 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +@dataclass class EmbeddingOutput: """The output data of one completion output of a request. @@ -56,15 +57,11 @@ class EmbeddingOutput: length of vector depends on the model as listed in the embedding guide. 
""" - def __init__( - self, - embedding: List[float], - ) -> None: - self.embedding = embedding + embedding: List[float] def __repr__(self) -> str: return (f"EmbeddingOutput(" - f"embedding={len(self.embedding)}") + f"embedding={len(self.embedding)})") class RequestOutput: From cc4bfb5416957336833ecd439d4da51d95b084e5 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:03:41 +0000 Subject: [PATCH 11/43] Apply formatter --- tests/entrypoints/test_llm_encode.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index b6ae4bc498dc6..4cf8b7fbafa8e 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -47,14 +47,12 @@ def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): pooling_params = PoolingParams() with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=prompt, - pooling_params=pooling_params) + v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) v2_output = llm.encode(prompt, pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) - v2_output = llm.encode({"prompt": prompt}, - pooling_params=pooling_params) + v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) @@ -66,10 +64,10 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): v1_output = llm.encode(prompt_token_ids=prompt_token_ids, - pooling_params=pooling_params) + pooling_params=pooling_params) v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, - pooling_params=pooling_params) + pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) @@ -78,8 +76,7 @@ def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): pooling_params = PoolingParams() with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=PROMPTS, - pooling_params=pooling_params) + v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) @@ -99,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, - pooling_params=pooling_params) + pooling_params=pooling_params) v2_output = llm.encode( [{ From d5c9731f0d64b3bdd0c03df621a1f45b954a9eae Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:10:22 +0000 Subject: [PATCH 12/43] Rename `_add_requests` to `_validate_and_add_requests` to be more similar to the original `_validate_and_prepare_requests` --- vllm/entrypoints/llm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a6175ff4485d1..59709e050325b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -265,7 +265,7 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() - return self._add_requests( + return self._validate_and_add_requests( inputs=inputs, params=sampling_params, use_tqdm=use_tqdm, @@ -395,7 +395,7 @@ def encode( # Use default pooling params. 
pooling_params = PoolingParams() - return self._add_requests( + return self._validate_and_add_requests( inputs=inputs, params=pooling_params, use_tqdm=use_tqdm, @@ -458,7 +458,7 @@ def _convert_v1_inputs( return inputs @overload - def _add_requests( + def _validate_and_add_requests( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[SamplingParams, Sequence[SamplingParams]], @@ -468,7 +468,7 @@ def _add_requests( ... @overload - def _add_requests( # type: ignore[misc] + def _validate_and_add_requests( # type: ignore[misc] self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[PoolingParams, Sequence[PoolingParams]], @@ -477,7 +477,7 @@ def _add_requests( # type: ignore[misc] ) -> List[EmbeddingRequestOutput]: ... - def _add_requests( + def _validate_and_add_requests( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, From 4f218a52ec830b6d39436ba84d7e2404851b1c18 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:15:52 +0000 Subject: [PATCH 13/43] Separate `entrypoints` tests into two groups --- .buildkite/test-pipeline.yaml | 3 ++- tests/entrypoints/openai/test_serving_chat.py | 4 +++ tests/entrypoints/test_guided_processors.py | 2 ++ tests/entrypoints/test_llm_encode.py | 2 ++ tests/entrypoints/test_llm_generate.py | 2 ++ tests/entrypoints/test_openai_server.py | 27 ++++++++++++++++++- .../test_server_oot_registration.py | 3 +++ 7 files changed, 41 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 176fa1d39db46..f0bab7c87ad0f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -52,7 +52,8 @@ steps: - label: Entrypoints Test commands: - pytest -v -s test_inputs.py - - pytest -v -s entrypoints + - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m openai - label: Examples Test working_dir: "/vllm-workspace/examples" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 74b49726734b5..c45f02fe564a3 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,11 +1,15 @@ import asyncio from dataclasses import dataclass +import pytest + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" +pytestmark = pytest.mark.openai + @dataclass class MockModelConfig: diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 41c871ca40bc8..5d4163e96fd87 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -52,6 +52,8 @@ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +pytestmark = pytest.mark.openai + def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 4cf8b7fbafa8e..c9833b0c315bf 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -24,6 +24,8 @@ [1000, 1003, 1001, 1002], ] +pytestmark = pytest.mark.llm + @pytest.fixture(scope="module") def llm(): diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 8ee08f8e83961..e21e4b8136746 
100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -22,6 +22,8 @@ [0, 3, 1, 2], ] +pytestmark = pytest.mark.llm + @pytest.fixture(scope="module") def llm(): diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 75f0fb0be4f08..7050cc8ebe3d1 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -71,7 +71,7 @@ "Swift", "Kotlin" ] -pytestmark = pytest.mark.asyncio +pytestmark = pytest.mark.openai @pytest.fixture(scope="session") @@ -138,6 +138,7 @@ def client(): yield client +@pytest.mark.asyncio async def test_check_models(server, client: openai.AsyncOpenAI): models = await client.models.list() models = models.data @@ -149,6 +150,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI): assert lora_models[1].id == "zephyr-lora2" +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -180,6 +182,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -201,6 +204,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, assert choice.logprobs.top_logprobs is None +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -245,6 +249,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, model_name: str): @@ -300,6 +305,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -337,6 +343,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -387,6 +394,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -440,6 +448,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, assert texts[0] == texts[1] +@pytest.mark.asyncio async def test_logits_bias(server, client: openai.AsyncOpenAI): prompt = "Hello, my name is" max_tokens = 5 @@ -487,6 +496,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_completion(server, client: openai.AsyncOpenAI, @@ -509,6 +519,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI, jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_chat(server, client: openai.AsyncOpenAI, @@ -555,6 +566,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI, assert json1["age"] != json2["age"] +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", 
"lm-format-enforcer"]) async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, @@ -575,6 +587,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, @@ -612,6 +625,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, assert ip1 != ip2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, @@ -631,6 +645,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, assert completion.choices[i].text in TEST_CHOICE +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, @@ -669,6 +684,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, assert choice1 != choice2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, @@ -704,6 +720,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, @@ -734,6 +751,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, for token, logprob in token_dict.items()) +@pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( @@ -751,6 +769,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +@pytest.mark.asyncio async def test_extra_fields(server, client: openai.AsyncOpenAI): with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( @@ -766,6 +785,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): assert "extra_forbidden" in exc_info.value.message +@pytest.mark.asyncio async def test_complex_message_content(server, client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, @@ -785,6 +805,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +@pytest.mark.asyncio async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement @@ -819,6 +840,7 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -850,6 +872,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, assert len(logprobs.tokens) > 5 +@pytest.mark.asyncio async def test_long_seed(server, client: openai.AsyncOpenAI): for seed in [ torch.iinfo(torch.long).min - 1, @@ -869,6 +892,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.asyncio 
@pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], @@ -907,6 +931,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 5 +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index dd43a5cf0a248..52dc1a0b898de 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -2,6 +2,7 @@ import sys import time +import pytest import torch from openai import OpenAI, OpenAIError @@ -10,6 +11,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import get_open_port +pytestmark = pytest.mark.openai + class MyOPTForCausalLM(OPTForCausalLM): From a9201d0251790b0ebdb6981d4f7b39ab149ce2f0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 07:17:59 +0000 Subject: [PATCH 14/43] Fix memory profiling error --- .buildkite/test-pipeline.yaml | 3 ++- pyproject.toml | 5 ++++- tests/entrypoints/test_llm_encode.py | 9 +++------ tests/entrypoints/test_llm_generate.py | 7 ++----- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f0bab7c87ad0f..cb4e1ba935880 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -52,7 +52,8 @@ steps: - label: Entrypoints Test commands: - pytest -v -s test_inputs.py - - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m llm_generate + - pytest -v -s entrypoints -m llm_encode - pytest -v -s entrypoints -m openai - label: Examples Test diff --git a/pyproject.toml b/pyproject.toml index ead64b7436121..97ff04f854ad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,5 +68,8 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ - "skip_global_cleanup" + "skip_global_cleanup", + "llm_encode: run tests for vLLM embedding API only", + "llm_generate: run tests for vLLM generate API only", + "openai: run tests for OpenAI API only", ] diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index c9833b0c315bf..fd1995a71f7dd 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -4,8 +4,6 @@ from vllm import LLM, EmbeddingRequestOutput, PoolingParams -from ..conftest import cleanup - MODEL_NAME = "intfloat/e5-mistral-7b-instruct" PROMPTS = [ @@ -24,19 +22,18 @@ [1000, 1003, 1001, 1002], ] -pytestmark = pytest.mark.llm +pytestmark = pytest.mark.llm_encode @pytest.fixture(scope="module") def llm(): + # pytest caches the fixture so we cannot GC it yield LLM(model=MODEL_NAME, max_num_batched_tokens=32768, tensor_parallel_size=1, - gpu_memory_utilization=0.90, + gpu_memory_utilization=0.75, enforce_eager=True) - cleanup() - def assert_outputs_equal(o1: List[EmbeddingRequestOutput], o2: List[EmbeddingRequestOutput]): diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index e21e4b8136746..b973cebea4e71 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -4,8 +4,6 @@ from vllm import LLM, RequestOutput, SamplingParams -from ..conftest import cleanup - MODEL_NAME = "facebook/opt-125m" PROMPTS = [ @@ -22,19 +20,18 @@ [0, 3, 1, 2], ] -pytestmark = pytest.mark.llm +pytestmark = pytest.mark.llm_generate @pytest.fixture(scope="module") def llm(): + # pytest caches the fixture so 
we cannot GC it yield LLM(model=MODEL_NAME, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, enforce_eager=True) - cleanup() - def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] From ceebfa684c9bbb1401a5d1d042471b23635e5798 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 02:04:01 +0000 Subject: [PATCH 15/43] Fix memory usage for embedding server --- tests/entrypoints/test_openai_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 7050cc8ebe3d1..2944d887d8896 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -120,9 +120,11 @@ def embedding_server(zephyr_lora_files): # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", + "--enforce-eager", + "--gpu-memory-utilization", + "0.75", "--max-model-len", "8192", - "--enforce-eager", ]) ray.get(server_runner.ready.remote()) yield server_runner From 7d991cde83a872641cf58862a57a0f3697866e42 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 02:43:16 +0000 Subject: [PATCH 16/43] Update embeddings API to use new imputs --- vllm/engine/async_llm_engine.py | 8 ++-- vllm/entrypoints/openai/serving_embedding.py | 42 ++++++++++++-------- vllm/entrypoints/openai/serving_engine.py | 6 ++- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d4991df122325..f4ae3fe64e85b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -584,7 +584,7 @@ async def add_request( async def generate( self, inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], + sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> AsyncIterator[RequestOutput]: @@ -596,9 +596,7 @@ async def generate( Args: inputs: The inputs to the LLM. - params: Parameters for sampling or pooling. - :class:`~vllm.SamplingParams` for text generation. - :class:`~vllm.PoolingParams` for pooling. + sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. 
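The async path accepts the same dict-based inputs. A rough end-to-end sketch follows; the engine construction helpers, model name, and request id are placeholders assumed for illustration and are not part of this diff:

    import asyncio

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def main():
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
        sampling_params = SamplingParams(max_tokens=16)

        final_output = None
        # `inputs` follows the PromptInputs schema shared with the offline LLM API.
        async for request_output in engine.generate(
                {"prompt": "Hello, my name is"},
                sampling_params,
                request_id="example-request-0"):
            final_output = request_output

        print(final_output.outputs[0].text)

    asyncio.run(main())

The pooling counterpart goes through AsyncLLMEngine.encode with PoolingParams, mirroring the generate/encode split introduced for the offline API.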
@@ -652,7 +650,7 @@ async def generate( async for output in self.process_request( request_id, inputs, - params, + sampling_params, lora_request=lora_request, ): yield output diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7a57be0c88915..5a3448de3d7a4 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,5 +1,5 @@ import time -from typing import AsyncIterator, List, Tuple +from typing import AsyncIterator, List, Optional, Tuple from fastapi import Request @@ -100,11 +100,16 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - pooling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids)) + generator = self.engine.encode( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + pooling_params, + f"{request_id}-{i}", + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -113,16 +118,21 @@ async def create_embedding(self, request: EmbeddingRequest, int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) - async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") - # TODO: Use a vllm-specific Validation Error - return self.create_error_response("Client disconnected") - final_res_batch[i] = res - response = request_output_to_embedding_response( - final_res_batch, request_id, created_time, model_name) + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * len(prompts) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) return response diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 58a1c2f7e73fe..a50d91e8d4fd4 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -142,7 +142,8 @@ def create_streaming_error_response( return json_str async def _check_model( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None @@ -154,7 +155,8 @@ async def _check_model( status_code=HTTPStatus.NOT_FOUND) def _maybe_get_lora( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[LoRARequest]: if request.model in self.served_model_names: return None From 30975825f3d1259277aade493d12e3e081d625ca Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 20 May 2024 03:46:41 +0000 Subject: [PATCH 17/43] Merge `llm` groups back into one by enabling gc --- .buildkite/test-pipeline.yaml | 3 +-- pyproject.toml | 3 +-- tests/entrypoints/test_llm_encode.py | 15 ++++++++++++--- tests/entrypoints/test_llm_generate.py | 15 ++++++++++++--- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f1deb7fdcf698..206fb814abf5a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -63,8 +63,7 @@ steps: #mirror_hardwares: [amd] commands: - pytest -v -s test_inputs.py - - pytest -v -s entrypoints -m llm_generate - - pytest -v -s entrypoints -m llm_encode + - pytest -v -s entrypoints -m llm - pytest -v -s entrypoints -m openai - label: Examples Test diff --git a/pyproject.toml b/pyproject.toml index c529f51ab93de..ab3fbfc92642c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,6 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ "skip_global_cleanup", - "llm_encode: run tests for vLLM embedding API only", - "llm_generate: run tests for vLLM generate API only", + "llm: run tests for vLLM API only", "openai: run tests for OpenAI API only", ] diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index fd1995a71f7dd..24da218b5adb2 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -1,9 +1,12 @@ +import weakref from typing import List import pytest from vllm import LLM, EmbeddingRequestOutput, PoolingParams +from ..conftest import cleanup + MODEL_NAME = "intfloat/e5-mistral-7b-instruct" PROMPTS = [ @@ -22,18 +25,24 @@ [1000, 1003, 1001, 1002], ] -pytestmark = pytest.mark.llm_encode +pytestmark = pytest.mark.llm @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we cannot GC it - yield LLM(model=MODEL_NAME, + # pytest caches the fixture so we use weakref for garbage collection to work + llm = LLM(model=MODEL_NAME, max_num_batched_tokens=32768, tensor_parallel_size=1, gpu_memory_utilization=0.75, enforce_eager=True) + yield 
weakref.proxy(llm) + + del llm + + cleanup() + def assert_outputs_equal(o1: List[EmbeddingRequestOutput], o2: List[EmbeddingRequestOutput]): diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index b973cebea4e71..4c2e52e64d54c 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -1,9 +1,12 @@ +import weakref from typing import List import pytest from vllm import LLM, RequestOutput, SamplingParams +from ..conftest import cleanup + MODEL_NAME = "facebook/opt-125m" PROMPTS = [ @@ -20,18 +23,24 @@ [0, 3, 1, 2], ] -pytestmark = pytest.mark.llm_generate +pytestmark = pytest.mark.llm @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we cannot GC it - yield LLM(model=MODEL_NAME, + # pytest caches the fixture so we use weakref for garbage collection to work + llm = LLM(model=MODEL_NAME, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, enforce_eager=True) + yield weakref.proxy(llm) + + del llm + + cleanup() + def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] From 7bbd123dd23fd6c0266a24331f404244462ed7af Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 20 May 2024 10:11:50 +0000 Subject: [PATCH 18/43] Improve documentation for LLM/engine --- vllm/engine/async_llm_engine.py | 18 +++++++++--------- vllm/engine/llm_engine.py | 10 +++++----- vllm/entrypoints/llm.py | 6 ++++-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3e684c883e1eb..b2f2fa02642d1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -291,15 +291,15 @@ async def check_health_async(self) -> None: class AsyncLLMEngine: - """An asynchronous wrapper for LLMEngine. + """An asynchronous wrapper for :class:`LLMEngine`. - This class is used to wrap the LLMEngine class to make it asynchronous. It - uses asyncio to create a background loop that keeps processing incoming - requests. The LLMEngine is kicked by the generate method when there - are requests in the waiting queue. The generate method yields the outputs - from the LLMEngine to the caller. + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. - NOTE: For the comprehensive list of arguments, see `LLMEngine`. + NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`. Args: worker_use_ray: Whether to use Ray for model workers. Required for @@ -313,8 +313,8 @@ class AsyncLLMEngine: being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args: Arguments for LLMEngine. - *kwargs: Arguments for LLMEngine. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. 
""" _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5e64117ae6c5a..850c8096f58d3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -63,11 +63,11 @@ class LLMEngine: iteration-level scheduling and efficient memory management to maximize the serving throughput. - The `LLM` class wraps this class for offline batched inference and the - `AsyncLLMEngine` class wraps this class for online serving. + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the `EngineArgs` class. For the - comprehensive list of arguments, see `EngineArgs`. + NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs` + class. For the comprehensive list of arguments, see :ref:`engine_args`. Args: model_config: The configuration related to the LLM model. @@ -84,7 +84,7 @@ class LLMEngine: executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection + usage_context: Specified entry point, used for usage info collection. """ tokenizer: Optional[BaseTokenizerGroup] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 59709e050325b..89d49b4741ccc 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -29,8 +29,10 @@ class LLM: mechanism and efficient memory management. NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + + NOTE: For the comprehensive list of arguments, see + :class:`~vllm.EngineArgs`. Args: model: The name or path of a HuggingFace Transformers model. From 056eb6168989fc61fe625cde6cf2de1cd65765c1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 07:59:34 +0000 Subject: [PATCH 19/43] Direct readers to the `PromptInputs` class --- docs/source/index.rst | 1 + docs/source/offline_inference/llm.rst | 2 +- docs/source/offline_inference/llm_inputs.rst | 14 +++++ vllm/__init__.py | 4 ++ vllm/engine/async_llm_engine.py | 8 ++- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/llm.py | 4 +- vllm/inputs.py | 64 +++++++++++++------- 8 files changed, 74 insertions(+), 29 deletions(-) create mode 100644 docs/source/offline_inference/llm_inputs.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index bab00e28e4018..6383680f2b512 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -73,6 +73,7 @@ Documentation :caption: Offline Inference offline_inference/llm + offline_inference/llm_inputs offline_inference/sampling_params .. toctree:: diff --git a/docs/source/offline_inference/llm.rst b/docs/source/offline_inference/llm.rst index 1a443ea406994..83ba1b6987c6d 100644 --- a/docs/source/offline_inference/llm.rst +++ b/docs/source/offline_inference/llm.rst @@ -1,5 +1,5 @@ LLM Class -========== +========= .. autoclass:: vllm.LLM :members: diff --git a/docs/source/offline_inference/llm_inputs.rst b/docs/source/offline_inference/llm_inputs.rst new file mode 100644 index 0000000000000..31c3d16a3c8eb --- /dev/null +++ b/docs/source/offline_inference/llm_inputs.rst @@ -0,0 +1,14 @@ +LLM Inputs +========== + +.. autodata:: vllm.inputs.PromptStrictInputs + +.. 
autoclass:: vllm.inputs.TextPrompt + :show-inheritance: + :members: + :member-order: bysource + +.. autoclass:: vllm.inputs.TokensPrompt + :show-inheritance: + :members: + :member-order: bysource diff --git a/vllm/__init__.py b/vllm/__init__.py index 74674ca0d12af..a0e154d24087c 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,6 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -16,6 +17,9 @@ __all__ = [ "LLM", "ModelRegistry", + "PromptStrictInputs", + "TextPrompt", + "TokensPrompt", "SamplingParams", "RequestOutput", "CompletionOutput", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b2f2fa02642d1..8212b9c6e2027 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -599,7 +599,9 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. @@ -673,7 +675,9 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 850c8096f58d3..89f99aa8f7098 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -152,8 +152,6 @@ def __init__( self.decoding_config = decoding_config or DecodingConfig() self.log_stats = log_stats - self.tokenizer: Optional[BaseTokenizerGroup] - if not self.model_config.skip_tokenizer_init: tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(tokenizer) @@ -446,7 +444,9 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. :class:`~vllm.PoolingParams` for pooling. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 89d49b4741ccc..f84de86e4617b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -372,7 +372,9 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. 
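A minimal sketch of the embedding path with the input schemas documented above; the model name and token IDs mirror tests/entrypoints/test_llm_encode.py, and the embedding length depends on the chosen model:

    from vllm import LLM, PoolingParams

    llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)

    # Each element conforms to the TextPrompt / TokensPrompt schema
    # documented in docs/source/offline_inference/llm_inputs.rst.
    outputs = llm.encode(
        [
            {"prompt": "The capital of France is"},
            {"prompt_token_ids": [1000, 1001]},
        ],
        pooling_params=PoolingParams(),
    )
    for out in outputs:
        # out.outputs is an EmbeddingOutput dataclass holding the vector.
        print(len(out.outputs.embedding))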
diff --git a/vllm/inputs.py b/vllm/inputs.py index 2b5ea1c0f3828..e4bdb18c2f49a 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -5,7 +5,7 @@ from vllm.sequence import MultiModalData -class ParsedString(TypedDict): +class ParsedText(TypedDict): text: str is_tokens: Literal[False] @@ -18,7 +18,7 @@ class ParsedTokens(TypedDict): # https://github.com/vllm-project/vllm/pull/4028 @overload def parse_and_batch_prompt( - prompt: Union[str, List[str]]) -> Sequence[ParsedString]: + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: ... @@ -30,10 +30,10 @@ def parse_and_batch_prompt( def parse_and_batch_prompt( prompt: Union[str, List[str], List[int], List[List[int]]], -) -> Union[Sequence[ParsedString], Sequence[ParsedTokens]]: +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: if isinstance(prompt, str): # case 1: a string - return [ParsedString(text=prompt, is_tokens=False)] + return [ParsedText(text=prompt, is_tokens=False)] if isinstance(prompt, list): if len(prompt) == 0: @@ -42,7 +42,7 @@ def parse_and_batch_prompt( if isinstance(prompt[0], str): # case 2: array of strings return [ - ParsedString(text=elem, is_tokens=False) + ParsedText(text=elem, is_tokens=False) for elem in cast(List[str], prompt) ] if isinstance(prompt[0], int): @@ -64,42 +64,62 @@ def parse_and_batch_prompt( "array of tokens, or array of token arrays") -class MultiModalPrompt(TypedDict, total=False): - multi_modal_data: Optional["MultiModalData"] - """Multi modal data.""" - +class TextPrompt(TypedDict): + """Schema for a text prompt.""" -class StringPrompt(MultiModalPrompt, TypedDict): prompt: str - """The prompt string.""" + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: Optional["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" -class TokensPrompt(MultiModalPrompt, TypedDict): prompt_token_ids: List[int] - """The token IDs of the prompt. If None, we use the - tokenizer to convert the prompts to token IDs.""" + """A list of token IDs to pass to the model.""" + + multi_modal_data: Optional["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ -class StringTokensPrompt(MultiModalPrompt, TypedDict): +class TextTokensPrompt(TypedDict): """It is assumed that :attr:`prompt` is consistent with :attr:`prompt_token_ids`. This is currently used in :class:`AsyncLLMEngine` for logging both the text and token IDs.""" prompt: str - """The prompt string.""" + """The prompt text.""" prompt_token_ids: List[int] """The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs.""" + multi_modal_data: Optional["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +""" +The inputs to the LLM, which can take one of the following forms: -PromptStrictInputs = Union[str, StringPrompt, TokensPrompt] -"""The prompt string. 
More complex inputs should be represented by -:class:`StringPrompt` or :class:`TokensPrompt`.""" +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +""" -PromptInputs = Union[str, StringPrompt, TokensPrompt, StringTokensPrompt] -"""As :const:`PromptStrictInputs` but additionally accepts -:class:`StringTokensPrompt`.""" +PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +"""Same as :const:`PromptStrictInputs` but additionally accepts +:class:`TextTokensPrompt`.""" class LLMInputs(TypedDict): From b3b990a7ac93c08e7734b02a7bee7cae14711c81 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 08:18:55 +0000 Subject: [PATCH 20/43] Separate `_run_engine` from `_validate_and_add_requests` --- tests/lora/test_long_context.py | 4 +- tests/samplers/test_logits_processor.py | 4 +- tests/samplers/test_seeded_generate.py | 4 +- vllm/entrypoints/llm.py | 51 +++++++++---------------- 4 files changed, 23 insertions(+), 40 deletions(-) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 15189f421a539..3dd9b98ed911b 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -5,7 +5,7 @@ import pytest import vllm -from vllm import SamplingParams +from vllm import RequestOutput, SamplingParams from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora from vllm.lora.request import LoRARequest from vllm.model_executor.layers.rotary_embedding import ( @@ -100,7 +100,7 @@ def batched_generate( # Add requests to the engine and run the engine for request_data in requests_data: llm._add_request(**request_data) - outputs = llm._run_engine(use_tqdm=True) + outputs = llm._run_engine(RequestOutput, use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 0724622d5f3c7..1b63c1dab98d2 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm import SamplingParams +from vllm import RequestOutput, SamplingParams MODELS = ["facebook/opt-125m"] @@ -54,6 +54,6 @@ def pick_vllm(token_ids, logits): params=SamplingParams(max_tokens=max_tokens), ) - outputs = vllm_model.model._run_engine(False) + outputs = vllm_model.model._run_engine(RequestOutput, use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index fef5ff3fb9e8e..fca2b0e05c335 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -8,7 +8,7 @@ import pytest -from vllm import SamplingParams +from vllm import RequestOutput, SamplingParams from vllm.model_executor.utils import set_random_seed MODEL = "facebook/opt-125m" @@ -59,7 +59,7 @@ def test_random_sample_with_seed( ): llm._add_request(prompt, params=params) - results = llm._run_engine(use_tqdm=False) + results = llm._run_engine(RequestOutput, use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] for output in results] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f84de86e4617b..d21d54783b139 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Sequence, Union, cast, overload +from typing import (List, Optional, Sequence, Type, TypeVar, Union, cast, + overload) from tqdm import 
tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -18,6 +19,8 @@ logger = init_logger(__name__) +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -267,13 +270,15 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() - return self._validate_and_add_requests( + self._validate_and_add_requests( inputs=inputs, params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, ) + return self._run_engine(RequestOutput, use_tqdm=use_tqdm) + @overload # DEPRECATED: single (prompt + optional token ids) def encode( self, @@ -399,13 +404,15 @@ def encode( # Use default pooling params. pooling_params = PoolingParams() - return self._validate_and_add_requests( + self._validate_and_add_requests( inputs=inputs, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, ) + return self._run_engine(EmbeddingRequestOutput, use_tqdm=use_tqdm) + # DEPRECATED def _convert_v1_inputs( self, @@ -461,26 +468,6 @@ def _convert_v1_inputs( return inputs - @overload - def _validate_and_add_requests( - self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - params: Union[SamplingParams, Sequence[SamplingParams]], - use_tqdm: bool, - lora_request: Optional[LoRARequest], - ) -> List[RequestOutput]: - ... - - @overload - def _validate_and_add_requests( # type: ignore[misc] - self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - params: Union[PoolingParams, Sequence[PoolingParams]], - use_tqdm: bool, - lora_request: Optional[LoRARequest], - ) -> List[EmbeddingRequestOutput]: - ... - def _validate_and_add_requests( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -488,7 +475,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], use_tqdm: bool, lora_request: Optional[LoRARequest], - ) -> Union[List[RequestOutput], List[EmbeddingRequestOutput]]: + ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. inputs = [inputs] @@ -507,8 +494,6 @@ def _validate_and_add_requests( lora_request=lora_request, ) - return self._run_engine(use_tqdm) - def _add_request( self, inputs: PromptInputs, @@ -521,9 +506,8 @@ def _add_request( params, lora_request=lora_request) - def _run_engine( - self, use_tqdm: bool - ) -> Union[List[RequestOutput], List[EmbeddingRequestOutput]]: + def _run_engine(self, output_type: Type[_O], *, + use_tqdm: bool) -> List[_O]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -556,9 +540,8 @@ def _run_engine( # its previous requests. 
outputs = sorted(outputs, key=lambda x: int(x.request_id)) - if len(outputs) > 0: - first, *rest = outputs - assert all(isinstance(r, type(first)) for r in rest), ( - f"Expected all outputs to be of the same type {type(first)}") + if len(outputs) > 0 and not isinstance(outputs[0], output_type): + raise TypeError(f"Expected output type to be {output_type}, " + f"but found type {type(outputs[0])}") - return outputs # type: ignore + return cast(List[_O], outputs) From 2169defbff298b09a405a1bd124976adf7ac574f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 08:41:31 +0000 Subject: [PATCH 21/43] Add flag for deprecating legacy API --- tests/entrypoints/test_llm_encode.py | 22 ++++++------ tests/entrypoints/test_llm_generate.py | 22 ++++++------ vllm/entrypoints/llm.py | 49 ++++++++++++++++++-------- vllm/utils.py | 30 +++++++++++----- 4 files changed, 79 insertions(+), 44 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 24da218b5adb2..872707b54e3f7 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -30,16 +30,18 @@ @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we use weakref for garbage collection to work - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True) - - yield weakref.proxy(llm) - - del llm + with LLM.deprecate_legacy_ctx(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + yield weakref.proxy(llm) + + del llm cleanup() diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 4c2e52e64d54c..37d1ea7e8745b 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -28,16 +28,18 @@ @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we use weakref for garbage collection to work - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) - - yield weakref.proxy(llm) - - del llm + with LLM.deprecate_legacy_ctx(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + yield weakref.proxy(llm) + + del llm cleanup() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d21d54783b139..2f76979ce5ebe 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ -from typing import (List, Optional, Sequence, Type, TypeVar, Union, cast, - overload) +from contextlib import contextmanager +from typing import (ClassVar, List, Optional, Sequence, Type, TypeVar, Union, + cast, overload) from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -87,6 +88,18 @@ class LLM: disable_custom_all_reduce: See ParallelConfig """ + DEPRECATE_LEGACY: ClassVar[bool] = False + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @staticmethod + @contextmanager + def deprecate_legacy_ctx(): + LLM.DEPRECATE_LEGACY = True + + yield + + LLM.DEPRECATE_LEGACY = False + def __init__( self, model: str, @@ -144,7 +157,7 @@ def set_tokenizer( ) -> None: 
self.llm_engine.tokenizer.tokenizer = tokenizer - @overload # DEPRECATED: single (prompt + optional token ids) + @overload # LEGACY: single (prompt + optional token ids) def generate( self, prompts: str, @@ -157,7 +170,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: multi (prompt + optional token ids) + @overload # LEGACY: multi (prompt + optional token ids) def generate( self, prompts: List[str], @@ -170,7 +183,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: single (token ids + optional prompt) + @overload # LEGACY: single (token ids + optional prompt) def generate( self, prompts: Optional[str] = None, @@ -184,7 +197,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: multi (token ids + optional prompt) + @overload # LEGACY: multi (token ids + optional prompt) def generate( self, prompts: Optional[List[str]] = None, @@ -198,7 +211,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: single or multi token ids [pos-only] + @overload # LEGACY: single or multi token ids [pos-only] def generate( self, prompts: None, @@ -223,7 +236,10 @@ def generate( ) -> List[RequestOutput]: ... - @deprecate_kwargs('prompts', 'prompt_token_ids', 'multi_modal_data') + @deprecate_kwargs('prompts', + 'prompt_token_ids', + 'multi_modal_data', + is_deprecated=lambda: LLM.DEPRECATE_LEGACY) def generate( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -279,7 +295,7 @@ def generate( return self._run_engine(RequestOutput, use_tqdm=use_tqdm) - @overload # DEPRECATED: single (prompt + optional token ids) + @overload # LEGACY: single (prompt + optional token ids) def encode( self, prompts: str, @@ -292,7 +308,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: multi (prompt + optional token ids) + @overload # LEGACY: multi (prompt + optional token ids) def encode( self, prompts: List[str], @@ -305,7 +321,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: single (token ids + optional prompt) + @overload # LEGACY: single (token ids + optional prompt) def encode( self, prompts: Optional[str] = None, @@ -319,7 +335,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: multi (token ids + optional prompt) + @overload # LEGACY: multi (token ids + optional prompt) def encode( self, prompts: Optional[List[str]] = None, @@ -333,7 +349,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: single or multi token ids [pos-only] + @overload # LEGACY: single or multi token ids [pos-only] def encode( self, prompts: None, @@ -358,7 +374,10 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... 
- @deprecate_kwargs('prompts', 'prompt_token_ids', 'multi_modal_data') + @deprecate_kwargs('prompts', + 'prompt_token_ids', + 'multi_modal_data', + is_deprecated=lambda: LLM.DEPRECATE_LEGACY) def encode( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -413,7 +432,7 @@ def encode( return self._run_engine(EmbeddingRequestOutput, use_tqdm=use_tqdm) - # DEPRECATED + # LEGACY def _convert_v1_inputs( self, prompts: Optional[Union[str, List[str]]], diff --git a/vllm/utils.py b/vllm/utils.py index 506f4f3ae2be6..4480c6c6960de 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -655,24 +655,36 @@ def enable_trace_function_call_for_thread() -> None: enable_trace_function_call(log_path) +def identity(value: T) -> T: + return value + + F = TypeVar('F', bound=Callable[..., Any]) -def deprecate_kwargs(*kws: str) -> Callable[[F], F]: +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], + bool]] = True) -> Callable[[F], F]: deprecated_kws = set(kws) + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + def wrapper(fn: F) -> F: @wraps(fn) def inner(*args, **kwargs): - deprecated_kwargs = kwargs.keys() & deprecated_kws - if deprecated_kwargs: - warnings.warn( - DeprecationWarning( - f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update."), - stacklevel=3, # The inner function takes up one level - ) + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + warnings.warn( + DeprecationWarning( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update." + ), + stacklevel=3, # The inner function takes up one level + ) return fn(*args, **kwargs) From 3dbded140243682140d6231fb273ca86495ad104 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 09:01:08 +0000 Subject: [PATCH 22/43] Add tests for `deprecate_kwargs` --- .buildkite/test-pipeline.yaml | 3 +++ tests/test_utils.py | 51 +++++++++++++++++++++++++++++++++++ tests/utils.py | 14 ++++++++++ 3 files changed, 68 insertions(+) create mode 100644 tests/test_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 206fb814abf5a..af22e404361aa 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -109,6 +109,9 @@ steps: mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py +- label: Utils Test + command: pytest -v -s test_utils.py + - label: Worker Test mirror_hardwares: [amd] command: pytest -v -s worker diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000..7f84fc7f6a454 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,51 @@ +import pytest + +from vllm.utils import deprecate_kwargs + +from .utils import error_on_warning + + +def test_deprecate_kwargs_always(): + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_func(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, 
old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) diff --git a/tests/utils.py b/tests/utils.py index 689d8c8c5ba8a..329842911e159 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,8 @@ import subprocess import sys import time +import warnings +from contextlib import contextmanager import ray import requests @@ -87,3 +89,15 @@ def multi_process_tensor_parallel( ray.get(refs) ray.shutdown() + + +@contextmanager +def error_on_warning(): + """ + Within the scope of this context manager, tests will fail if any warning + is emitted. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error") + + yield From 8e20317bbfe9a25dc1b062cf058a68382e7d5f17 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 09:04:34 +0000 Subject: [PATCH 23/43] Apply formatter --- tests/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7f84fc7f6a454..988dc5ba2bf29 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,7 @@ def test_deprecate_kwargs_always(): + @deprecate_kwargs("old_arg", is_deprecated=True) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -18,6 +19,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): def test_deprecate_kwargs_never(): + @deprecate_kwargs("old_arg", is_deprecated=False) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -41,7 +43,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): with error_on_warning(): dummy(new_arg=1) - + is_deprecated = False with error_on_warning(): From fdccaa21066010e7c8d265528e34f21979421109 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 09:04:47 +0000 Subject: [PATCH 24/43] Rename attribute to be less misleading --- vllm/entrypoints/llm.py | 4 ++-- vllm/inputs.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2f76979ce5ebe..8943929371ae9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -442,10 +442,10 @@ def _convert_v1_inputs( # skip_tokenizer_init is now checked in engine if prompts is not None: - prompts = [p["text"] for p in parse_and_batch_prompt(prompts)] + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] if prompt_token_ids is not None: prompt_token_ids = [ - p["text"] for p in parse_and_batch_prompt(prompt_token_ids) + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) ] num_requests = None diff --git a/vllm/inputs.py b/vllm/inputs.py index e4bdb18c2f49a..80011b6dd1d61 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -6,12 +6,12 @@ class ParsedText(TypedDict): - text: str + content: str is_tokens: Literal[False] class ParsedTokens(TypedDict): - text: List[int] + content: List[int] is_tokens: Literal[True] @@ -33,7 +33,7 @@ def parse_and_batch_prompt( ) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: if isinstance(prompt, str): # case 1: a string - return [ParsedText(text=prompt, is_tokens=False)] + return [ParsedText(content=prompt, is_tokens=False)] if isinstance(prompt, list): if len(prompt) == 0: @@ -42,13 +42,13 @@ def parse_and_batch_prompt( if isinstance(prompt[0], str): # case 2: array of strings return [ - ParsedText(text=elem, 
is_tokens=False) + ParsedText(content=elem, is_tokens=False) for elem in cast(List[str], prompt) ] if isinstance(prompt[0], int): # case 3: array of tokens elem = cast(List[int], prompt) - return [ParsedTokens(text=elem, is_tokens=True)] + return [ParsedTokens(content=elem, is_tokens=True)] if isinstance(prompt[0], list): if len(prompt[0]) == 0: raise ValueError("please provide at least one prompt") @@ -56,7 +56,7 @@ def parse_and_batch_prompt( if isinstance(prompt[0][0], int): # case 4: array of token arrays return [ - ParsedTokens(text=elem, is_tokens=True) + ParsedTokens(content=elem, is_tokens=True) for elem in cast(List[List[int]], prompt) ] From 77ee1c87707590ad015eb4636500a02277e30e83 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 06:50:39 +0000 Subject: [PATCH 25/43] Renable using `'fork'` start method and improve speed by using `torch.multiprocessing` wrapper instead of stdlib `multiprocessing` --- tests/entrypoints/test_server_oot_registration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 52dc1a0b898de..3e55d7f4297fb 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -1,4 +1,3 @@ -import multiprocessing import sys import time @@ -37,7 +36,7 @@ def server_function(port): def test_oot_registration_for_api_server(): port = get_open_port() - ctx = multiprocessing.get_context("spawn") + ctx = torch.multiprocessing.get_context() server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( From b1bcdd17002f940a9b64a64925f318090ebfbc0b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 07:52:51 +0000 Subject: [PATCH 26/43] Simplify logic of casting request output --- vllm/engine/async_llm_engine.py | 76 +++++++++++++++------------------ vllm/engine/llm_engine.py | 1 - vllm/entrypoints/llm.py | 23 ++++++---- 3 files changed, 49 insertions(+), 51 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8212b9c6e2027..7fe758c170b0c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,7 +2,7 @@ import time from functools import partial from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, Union, overload) + Set, Tuple, Type, TypeVar, Union) from transformers import PreTrainedTokenizer @@ -290,6 +290,9 @@ async def check_health_async(self) -> None: self.model_executor.check_health() +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class AsyncLLMEngine: """An asynchronous wrapper for :class:`LLMEngine`. @@ -653,10 +656,11 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, inputs, sampling_params, + output_type=RequestOutput, lora_request=lora_request, ): yield output @@ -727,63 +731,53 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, inputs, pooling_params, + output_type=EmbeddingRequestOutput, lora_request=lora_request, ): yield output - @overload - def process_request( - self, - request_id: str, - inputs: PromptInputs, - params: SamplingParams, - lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[RequestOutput]: - ... 
- - @overload - def process_request( # type: ignore[misc] - self, - request_id: str, - inputs: PromptInputs, - params: PoolingParams, - lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[EmbeddingRequestOutput]: - ... - - def process_request( + async def _process_request( self, request_id: str, inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], + *, + output_type: Type[_O], lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> AsyncIterator[_O]: """Common logic to process requests with SamplingParams or PoolingParams.""" + arrival_time = time.time() - async def generator(): - arrival_time = time.time() + stream = await self.add_request( + request_id, + inputs, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) - stream = await self.add_request( - request_id, - inputs, - params, - arrival_time=arrival_time, - lora_request=lora_request, - ) + try: + is_first = True - try: - async for request_output in stream: - yield request_output - except (Exception, asyncio.CancelledError) as e: - self._abort(request_id) - raise e + async for request_output in stream: + # To improve performance, we only check the first result + if is_first: + if not isinstance(request_output, output_type): + raise TypeError( + f"Expected output of type {output_type}, " + f"but found type {type(request_output)}") + + is_first = False - return generator() + yield request_output # type: ignore + except (Exception, asyncio.CancelledError) as e: + self._abort(request_id) + raise e async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 89f99aa8f7098..f4644a999745e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -346,7 +346,6 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer = get_tokenizer_group( self.parallel_config.tokenizer_pool_config, **init_kwargs) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8943929371ae9..eb48df5e4266a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -537,13 +537,24 @@ def _run_engine(self, output_type: Type[_O], *, postfix=f"Generation Speed: {0:.2f} toks/s", ) # Run the engine. - outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + outputs: List[_O] = [] total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() + is_first = True + for output in step_outputs: + # To improve performance, we only check the first result + if is_first: + if not isinstance(outputs[0], output_type): + raise TypeError( + f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + is_first = False + if output.finished: - outputs.append(output) + outputs.append(output) # type: ignore if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput @@ -557,10 +568,4 @@ def _run_engine(self, output_type: Type[_O], *, # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. 
- outputs = sorted(outputs, key=lambda x: int(x.request_id)) - - if len(outputs) > 0 and not isinstance(outputs[0], output_type): - raise TypeError(f"Expected output type to be {output_type}, " - f"but found type {type(outputs[0])}") - - return cast(List[_O], outputs) + return sorted(outputs, key=lambda x: int(x.request_id)) From 44b4681f42c74b4761b5ece6d4e0d7dee5b3c261 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 07:56:31 +0000 Subject: [PATCH 27/43] Improve code readability --- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7fe758c170b0c..11684949e7075 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -278,7 +278,7 @@ async def add_request_async( processed_inputs = await self.process_model_inputs_async( request_id=request_id, inputs=inputs, lora_request=lora_request) - return self._add_processed_request( + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, params=params, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f4644a999745e..65606215b997f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -336,7 +336,7 @@ def get_tokenizer_for_seq(self, return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: init_kwargs = dict( tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), @@ -486,7 +486,7 @@ def add_request( inputs=inputs, lora_request=lora_request) - return self._add_processed_request( + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, params=params, From 50343cb5fe52dde05a558f932938fd292fbb987c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 08:34:17 +0000 Subject: [PATCH 28/43] Fix `multi_modal_data` being a required key --- vllm/engine/async_llm_engine.py | 2 +- vllm/inputs.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 11684949e7075..7869d2f1e3b3e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -239,7 +239,7 @@ async def step_async( async def process_model_inputs_async( self, - request_id: str, # pylint: disable=unused-argument + request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, ) -> LLMInputs: diff --git a/vllm/inputs.py b/vllm/inputs.py index 80011b6dd1d61..b6cd23f0c907f 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -1,6 +1,8 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, TypedDict, Union, cast, overload) +from typing_extensions import NotRequired + if TYPE_CHECKING: from vllm.sequence import MultiModalData @@ -70,7 +72,7 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - multi_modal_data: Optional["MultiModalData"] + multi_modal_data: NotRequired[Optional["MultiModalData"]] """ Optional multi-modal data to pass to the model, if the model supports it. 
@@ -83,7 +85,7 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" - multi_modal_data: Optional["MultiModalData"] + multi_modal_data: NotRequired[Optional["MultiModalData"]] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -102,7 +104,7 @@ class TextTokensPrompt(TypedDict): """The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs.""" - multi_modal_data: Optional["MultiModalData"] + multi_modal_data: NotRequired[Optional["MultiModalData"]] """ Optional multi-modal data to pass to the model, if the model supports it. From 45aa42017eac55cd51377d5ac8b1b26552dd9e88 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 08:39:14 +0000 Subject: [PATCH 29/43] Fix index out of range error --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index eb48df5e4266a..d67a439603d5e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -546,7 +546,7 @@ def _run_engine(self, output_type: Type[_O], *, for output in step_outputs: # To improve performance, we only check the first result if is_first: - if not isinstance(outputs[0], output_type): + if not isinstance(output, output_type): raise TypeError( f"Expected output of type {output_type}, " f"but found type {type(output)}") From d4e2589be107f58276b55ef197bc3f2eb7df0cae Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 16:07:43 +0000 Subject: [PATCH 30/43] Use a flag to control whether to check output types --- vllm/engine/async_llm_engine.py | 20 +++++++------------- vllm/engine/llm_engine.py | 6 +++++- vllm/entrypoints/llm.py | 20 +++++++------------- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7869d2f1e3b3e..a63a94628dc9e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, TypeVar, Union) +from typing import (TYPE_CHECKING, AsyncIterator, Callable, Dict, Iterable, + List, Optional, Set, Tuple, Type, TypeVar, Union) from transformers import PreTrainedTokenizer @@ -762,19 +762,13 @@ async def _process_request( ) try: - is_first = True - async for request_output in stream: - # To improve performance, we only check the first result - if is_first: - if not isinstance(request_output, output_type): - raise TypeError( - f"Expected output of type {output_type}, " - f"but found type {type(request_output)}") - - is_first = False + if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + and not isinstance(request_output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(request_output)}") - yield request_output # type: ignore + yield request_output except (Exception, asyncio.CancelledError) as e: self._abort(request_id) raise e diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 65606215b997f..76c006a05f307 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,5 @@ import time -from typing import Iterable, List, Optional +from typing import ClassVar, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Type, Union @@ -86,6 +86,10 @@ class LLMEngine: log_stats: Whether to log 
statistics. usage_context: Specified entry point, used for usage info collection. """ + + VALIDATE_OUTPUT_TYPES: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + tokenizer: Optional[BaseTokenizerGroup] def __init__( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d67a439603d5e..489dbe1266451 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from typing import (ClassVar, List, Optional, Sequence, Type, TypeVar, Union, - cast, overload) +from typing import (TYPE_CHECKING, ClassVar, List, Optional, Sequence, Type, + TypeVar, Union, cast, overload) from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -541,20 +541,14 @@ def _run_engine(self, output_type: Type[_O], *, total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() - is_first = True - for output in step_outputs: - # To improve performance, we only check the first result - if is_first: - if not isinstance(output, output_type): - raise TypeError( - f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - is_first = False + if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") if output.finished: - outputs.append(output) # type: ignore + outputs.append(output) if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput From c07b5798bd4c2c25e93f666cc4ba84468a1429f8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:05:32 +0000 Subject: [PATCH 31/43] Simplify flags --- tests/entrypoints/test_llm_encode.py | 18 +++++++++--------- tests/entrypoints/test_llm_generate.py | 18 +++++++++--------- vllm/engine/async_llm_engine.py | 4 +++- vllm/engine/llm_engine.py | 10 ++++++++++ vllm/entrypoints/llm.py | 12 +++++++----- 5 files changed, 38 insertions(+), 24 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 872707b54e3f7..39fc7c2e0f0b9 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -30,15 +30,15 @@ @pytest.fixture(scope="module") def llm(): - with LLM.deprecate_legacy_ctx(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True) - + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + with llm.deprecate_legacy_api(): yield weakref.proxy(llm) del llm diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 37d1ea7e8745b..44f5feb1aa0a2 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -28,15 +28,15 @@ @pytest.fixture(scope="module") def llm(): - with LLM.deprecate_legacy_ctx(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) - + # pytest caches the fixture so we use weakref.proxy to + # enable 
garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + with llm.deprecate_legacy_api(): yield weakref.proxy(llm) del llm diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a63a94628dc9e..8e93394bd0d9f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -761,9 +761,11 @@ async def _process_request( lora_request=lora_request, ) + validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES + try: async for request_output in stream: - if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + if ((TYPE_CHECKING or validate_output_types) and not isinstance(request_output, output_type)): raise TypeError(f"Expected output of type {output_type}, " f"but found type {type(request_output)}") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 76c006a05f307..ced5f98820b3e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,4 +1,5 @@ import time +from contextlib import contextmanager from typing import ClassVar, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Type, Union @@ -90,6 +91,15 @@ class LLMEngine: VALIDATE_OUTPUT_TYPES: ClassVar[bool] = False """A flag to toggle whether to validate the type of request output.""" + @classmethod + @contextmanager + def validate_output_types(cls): + cls.VALIDATE_OUTPUT_TYPES = True + + yield + + cls.VALIDATE_OUTPUT_TYPES = False + tokenizer: Optional[BaseTokenizerGroup] def __init__( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 489dbe1266451..c9d2d6eff0497 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -91,14 +91,14 @@ class LLM: DEPRECATE_LEGACY: ClassVar[bool] = False """A flag to toggle whether to deprecate the legacy generate/encode API.""" - @staticmethod + @classmethod @contextmanager - def deprecate_legacy_ctx(): - LLM.DEPRECATE_LEGACY = True + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True yield - LLM.DEPRECATE_LEGACY = False + cls.DEPRECATE_LEGACY = False def __init__( self, @@ -541,8 +541,10 @@ def _run_engine(self, output_type: Type[_O], *, total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() + validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES + for output in step_outputs: - if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + if ((TYPE_CHECKING or validate_output_types) and not isinstance(output, output_type)): raise TypeError(f"Expected output of type {output_type}, " f"but found type {type(output)}") From 9d56eb0667b107af30f89604fb6ff5818b451f49 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:31:18 +0000 Subject: [PATCH 32/43] Move output validation to a more appropriate location --- tests/entrypoints/test_llm_encode.py | 8 ++-- tests/entrypoints/test_llm_generate.py | 8 ++-- tests/lora/test_long_context.py | 4 +- tests/samplers/test_logits_processor.py | 4 +- tests/samplers/test_seeded_generate.py | 4 +- vllm/engine/async_llm_engine.py | 23 +++-------- vllm/engine/llm_engine.py | 52 ++++++++++++++++++++++--- vllm/entrypoints/llm.py | 25 +++++------- 8 files changed, 74 insertions(+), 54 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 39fc7c2e0f0b9..7c3fbe43a8384 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -33,10 +33,10 @@ def llm(): # pytest 
caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True) + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 44f5feb1aa0a2..a00fff91a310e 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -31,10 +31,10 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 3dd9b98ed911b..15189f421a539 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -5,7 +5,7 @@ import pytest import vllm -from vllm import RequestOutput, SamplingParams +from vllm import SamplingParams from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora from vllm.lora.request import LoRARequest from vllm.model_executor.layers.rotary_embedding import ( @@ -100,7 +100,7 @@ def batched_generate( # Add requests to the engine and run the engine for request_data in requests_data: llm._add_request(**request_data) - outputs = llm._run_engine(RequestOutput, use_tqdm=True) + outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 1b63c1dab98d2..0ccbabfff6403 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm import RequestOutput, SamplingParams +from vllm import SamplingParams MODELS = ["facebook/opt-125m"] @@ -54,6 +54,6 @@ def pick_vllm(token_ids, logits): params=SamplingParams(max_tokens=max_tokens), ) - outputs = vllm_model.model._run_engine(RequestOutput, use_tqdm=False) + outputs = vllm_model.model._run_engine(use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index fca2b0e05c335..fef5ff3fb9e8e 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -8,7 +8,7 @@ import pytest -from vllm import RequestOutput, SamplingParams +from vllm import SamplingParams from vllm.model_executor.utils import set_random_seed MODEL = "facebook/opt-125m" @@ -59,7 +59,7 @@ def test_random_sample_with_seed( ): llm._add_request(prompt, params=params) - results = llm._run_engine(RequestOutput, use_tqdm=False) + results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] for output in results] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8e93394bd0d9f..53d8f4421ad72 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (TYPE_CHECKING, AsyncIterator, Callable, 
Dict, Iterable, - List, Optional, Set, Tuple, Type, TypeVar, Union) +from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, + Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -290,9 +290,6 @@ async def check_health_async(self) -> None: self.model_executor.check_health() -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) - - class AsyncLLMEngine: """An asynchronous wrapper for :class:`LLMEngine`. @@ -660,10 +657,9 @@ async def generate( request_id, inputs, sampling_params, - output_type=RequestOutput, lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, RequestOutput) async def encode( self, @@ -735,10 +731,9 @@ async def encode( request_id, inputs, pooling_params, - output_type=EmbeddingRequestOutput, lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, EmbeddingRequestOutput) async def _process_request( self, @@ -746,9 +741,8 @@ async def _process_request( inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], *, - output_type: Type[_O], lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[_O]: + ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" arrival_time = time.time() @@ -761,15 +755,8 @@ async def _process_request( lora_request=lora_request, ) - validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES - try: async for request_output in stream: - if ((TYPE_CHECKING or validate_output_types) - and not isinstance(request_output, output_type)): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(request_output)}") - yield request_output except (Exception, asyncio.CancelledError) as e: self._abort(request_id) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ced5f98820b3e..7520e7eb1c4cc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,8 +1,8 @@ import time from contextlib import contextmanager -from typing import ClassVar, Iterable, List, Optional +from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional from typing import Sequence as GenericSequence -from typing import Type, Union +from typing import Type, TypeVar, Union from transformers import GenerationConfig, PreTrainedTokenizer @@ -54,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig): return {} +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -88,17 +91,54 @@ class LLMEngine: usage_context: Specified entry point, used for usage info collection. 
""" - VALIDATE_OUTPUT_TYPES: ClassVar[bool] = False + DO_VALIDATE_OUTPUT: ClassVar[bool] = False """A flag to toggle whether to validate the type of request output.""" @classmethod @contextmanager - def validate_output_types(cls): - cls.VALIDATE_OUTPUT_TYPES = True + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True yield - cls.VALIDATE_OUTPUT_TYPES = False + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ tokenizer: Optional[BaseTokenizerGroup] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c9d2d6eff0497..4efc6b27f6f49 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,5 @@ from contextlib import contextmanager -from typing import (TYPE_CHECKING, ClassVar, List, Optional, Sequence, Type, - TypeVar, Union, cast, overload) +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -20,8 +19,6 @@ logger = init_logger(__name__) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) - class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -293,7 +290,8 @@ def generate( lora_request=lora_request, ) - return self._run_engine(RequestOutput, use_tqdm=use_tqdm) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) @overload # LEGACY: single (prompt + optional token ids) def encode( @@ -430,7 +428,8 @@ def encode( lora_request=lora_request, ) - return self._run_engine(EmbeddingRequestOutput, use_tqdm=use_tqdm) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) # LEGACY def _convert_v1_inputs( @@ -525,8 +524,9 @@ def _add_request( params, lora_request=lora_request) - def _run_engine(self, output_type: Type[_O], *, - use_tqdm: bool) -> List[_O]: + def _run_engine( + self, *, use_tqdm: bool + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -537,18 +537,11 @@ def _run_engine(self, output_type: Type[_O], *, postfix=f"Generation Speed: {0:.2f} toks/s", ) # Run the engine. 
- outputs: List[_O] = [] + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() - validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES - for output in step_outputs: - if ((TYPE_CHECKING or validate_output_types) - and not isinstance(output, output_type)): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - if output.finished: outputs.append(output) if use_tqdm: From bc05031fe237193f048189fa9be33c0067ed40ba Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:37:39 +0000 Subject: [PATCH 33/43] Add message to deprecation notice --- tests/test_utils.py | 11 ++++++++++- vllm/entrypoints/llm.py | 20 ++++++++++++-------- vllm/utils.py | 14 ++++++++------ 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 988dc5ba2bf29..df993d2665b64 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -31,7 +31,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): dummy(new_arg=1) -def test_deprecate_kwargs_func(): +def test_deprecate_kwargs_dynamic(): is_deprecated = True @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) @@ -51,3 +51,12 @@ def dummy(*, old_arg: object = None, new_arg: object = None): with error_on_warning(): dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4efc6b27f6f49..05aea9aac6456 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -233,10 +233,12 @@ def generate( ) -> List[RequestOutput]: ... - @deprecate_kwargs('prompts', - 'prompt_token_ids', - 'multi_modal_data', - is_deprecated=lambda: LLM.DEPRECATE_LEGACY) + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") def generate( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -372,10 +374,12 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... 
- @deprecate_kwargs('prompts', - 'prompt_token_ids', - 'multi_modal_data', - is_deprecated=lambda: LLM.DEPRECATE_LEGACY) + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") def encode( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], diff --git a/vllm/utils.py b/vllm/utils.py index 4480c6c6960de..979e15568a0dc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -664,8 +664,8 @@ def identity(value: T) -> T: def deprecate_kwargs( *kws: str, - is_deprecated: Union[bool, Callable[[], - bool]] = True) -> Callable[[F], F]: + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: deprecated_kws = set(kws) if not callable(is_deprecated): @@ -678,11 +678,13 @@ def inner(*args, **kwargs): if is_deprecated(): deprecated_kwargs = kwargs.keys() & deprecated_kws if deprecated_kwargs: + msg = (f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + warnings.warn( - DeprecationWarning( - f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update." - ), + DeprecationWarning(msg), stacklevel=3, # The inner function takes up one level ) From 95d41303edd2b086f88afbc78939d03e4fe995c6 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:40:22 +0000 Subject: [PATCH 34/43] Apply formatter --- tests/test_utils.py | 1 + vllm/entrypoints/llm.py | 4 ++-- vllm/utils.py | 5 +++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index df993d2665b64..54dc5c6f5bfba 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -54,6 +54,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): def test_deprecate_kwargs_additional_message(): + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") def dummy(*, old_arg: object = None, new_arg: object = None): pass diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 05aea9aac6456..53091cdc6ee42 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -238,7 +238,7 @@ def generate( "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " - "instead.") + "instead.") def generate( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -379,7 +379,7 @@ def encode( "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " - "instead.") + "instead.") def encode( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], diff --git a/vllm/utils.py b/vllm/utils.py index 979e15568a0dc..1d99f0be8d3be 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -678,8 +678,9 @@ def inner(*args, **kwargs): if is_deprecated(): deprecated_kwargs = kwargs.keys() & deprecated_kws if deprecated_kwargs: - msg = (f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update.") + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") if additional_message is not None: msg += f" {additional_message}" From cc84f65c0ea3b45b77236b836bb4a422eac28066 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 24 May 2024 00:50:30 
+0000 Subject: [PATCH 35/43] Remove unused parameter in `_validate_and_add_requests` and fix test --- tests/lora/test_long_context.py | 8 +++----- vllm/entrypoints/llm.py | 3 --- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 15189f421a539..4361e5452cdff 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -86,20 +86,18 @@ def generate( def batched_generate( - llm, + llm: vllm.LLM, inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], ): for input in inputs: prompt, sampling_param, lora_req = input - requests_data = llm._validate_and_prepare_requests( + # Add requests to the engine and run the engine + llm._validate_and_add_requests( prompt, sampling_param, lora_request=lora_req, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - llm._add_request(**request_data) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 53091cdc6ee42..40ce9b1a992e5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -288,7 +288,6 @@ def generate( self._validate_and_add_requests( inputs=inputs, params=sampling_params, - use_tqdm=use_tqdm, lora_request=lora_request, ) @@ -428,7 +427,6 @@ def encode( self._validate_and_add_requests( inputs=inputs, params=pooling_params, - use_tqdm=use_tqdm, lora_request=lora_request, ) @@ -495,7 +493,6 @@ def _validate_and_add_requests( inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], - use_tqdm: bool, lora_request: Optional[LoRARequest], ) -> None: if isinstance(inputs, (str, dict)): From 6c5d4a6cd70f576a809db6eab95f11a09f92db4f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 02:37:25 +0000 Subject: [PATCH 36/43] Simplify code --- vllm/engine/llm_engine.py | 3 ++- vllm/entrypoints/llm.py | 26 ++++++++++++-------------- vllm/inputs.py | 6 +++--- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7520e7eb1c4cc..7ce8021a205ee 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -424,11 +424,12 @@ def _add_processed_request( # Create the sequences. 
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - eos_token_id = None + if self.tokenizer: eos_token_id = self.tokenizer.get_lora_tokenizer( lora_request).eos_token_id else: + eos_token_id = None logger.warning("Use None for EOS token id because tokenizer is " "not initialized") seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 40ce9b1a992e5..9759d05577796 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,7 +6,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import (PromptInputs, PromptStrictInputs, +from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, + TextTokensPrompt, TokensPrompt, parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -467,25 +468,22 @@ def _convert_v1_inputs( for i in range(num_requests): if prompts is not None: if prompt_token_ids is not None: - inputs.append({ - "prompt": prompts[i], - "prompt_token_ids": prompt_token_ids[i], - "multi_modal_data": multi_modal_data, - }) + item = TextTokensPrompt( + prompt=prompts[i], + prompt_token_ids=prompt_token_ids[i]) else: - inputs.append({ - "prompt": prompts[i], - "multi_modal_data": multi_modal_data, - }) + item = TextPrompt(prompt=prompts[i]) else: if prompt_token_ids is not None: - inputs.append({ - "prompt_token_ids": prompt_token_ids[i], - "multi_modal_data": multi_modal_data, - }) + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) else: raise AssertionError + if multi_modal_data is not None: + item["multi_modal_data"] = multi_modal_data + + inputs.append(item) + return inputs def _validate_and_add_requests( diff --git a/vllm/inputs.py b/vllm/inputs.py index b6cd23f0c907f..f5d99b1b66b70 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -72,7 +72,7 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - multi_modal_data: NotRequired[Optional["MultiModalData"]] + multi_modal_data: NotRequired["MultiModalData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -85,7 +85,7 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" - multi_modal_data: NotRequired[Optional["MultiModalData"]] + multi_modal_data: NotRequired["MultiModalData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -104,7 +104,7 @@ class TextTokensPrompt(TypedDict): """The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs.""" - multi_modal_data: NotRequired[Optional["MultiModalData"]] + multi_modal_data: NotRequired["MultiModalData"] """ Optional multi-modal data to pass to the model, if the model supports it. 
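
A minimal sketch of how the prompt-input TypedDicts from `vllm/inputs.py` are meant to be used after this patch, assuming the API shape shown above ("facebook/opt-125m" is only a placeholder model, and `multi_modal_data` is omitted because it is now `NotRequired`):

```python
from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt, TokensPrompt

# Placeholder model; any small HF model works for this sketch.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, max_tokens=16)

inputs = [
    # Plain text, tokenized by the engine before reaching the model.
    TextPrompt(prompt="Hello, my name is"),
    # Pre-tokenized prompt; the tokenizer is not consulted for this one.
    TokensPrompt(prompt_token_ids=[1, 2, 3]),
]

outputs = llm.generate(inputs, sampling_params=params)
for output in outputs:
    print(output.outputs[0].text)
```

Because these are `TypedDict`s, plain dict literals such as `{"prompt": "Hello"}` are equally valid; the classes only add static type checking.
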
From fd2da125ea7bde3a83cbb65d25460c5abeff5b83 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 02:42:08 +0000 Subject: [PATCH 37/43] Move attribute assignment outside `_init_tokenizer` --- vllm/engine/llm_engine.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7ce8021a205ee..0be3d3140b6b1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -207,11 +207,11 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - tokenizer = self._init_tokenizer() + self.tokenizer = tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(tokenizer) else: - self.detokenizer = None self.tokenizer = None + self.detokenizer = None self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( @@ -376,7 +376,9 @@ def __del__(self): MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " "skip_tokenizer_init is True") - def get_tokenizer_group(self, fail_msg: str = MISSING_TOKENIZER_GROUP_MSG): + def get_tokenizer_group( + self, + fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: if self.tokenizer is None: raise ValueError(fail_msg) @@ -400,10 +402,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer = get_tokenizer_group( - self.parallel_config.tokenizer_pool_config, **init_kwargs) - return self.tokenizer + return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, + **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) From d78de94c1c8be82f19213a0d68860925e40edf75 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 03:12:31 +0000 Subject: [PATCH 38/43] Only emit warning once --- vllm/engine/llm_engine.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0be3d3140b6b1..a0898562d4ccf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -213,6 +213,8 @@ def __init__( self.tokenizer = None self.detokenizer = None + self._eos_warn_count = 0 + self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) @@ -414,6 +416,18 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) + def _get_eos_token_id( + self, lora_request: Optional[LoRARequest]) -> Optional[int]: + if self.tokenizer: + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + else: + if self._eos_warn_count == 0: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + + self._eos_warn_count += 1 + return None + def _add_processed_request( self, request_id: str, @@ -425,14 +439,8 @@ def _add_processed_request( # Create the sequences. 
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) + eos_token_id = self._get_eos_token_id(lora_request) - if self.tokenizer: - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - else: - eos_token_id = None - logger.warning("Use None for EOS token id because tokenizer is " - "not initialized") seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request) From 8a868299fab4f9505f3a7568114b54a78513abf7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 03:16:24 +0000 Subject: [PATCH 39/43] Simplify assignment expression --- vllm/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a0898562d4ccf..3c50033e46e30 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -207,8 +207,8 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - self.tokenizer = tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(tokenizer) + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) else: self.tokenizer = None self.detokenizer = None From 731ac0e2cf03006ef653fe3079a9812644a02825 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 03:19:11 +0000 Subject: [PATCH 40/43] Place special case at the start --- vllm/engine/llm_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3c50033e46e30..cbc9bf741476a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -418,9 +418,7 @@ def _verify_args(self) -> None: def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: - if self.tokenizer: - return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - else: + if self.tokenizer is None: if self._eos_warn_count == 0: logger.warning("Using None for EOS token id because tokenizer " "is not initialized") @@ -428,6 +426,8 @@ def _get_eos_token_id( self._eos_warn_count += 1 return None + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + def _add_processed_request( self, request_id: str, From 2d1a0bccf100b6e9ca6912b69caac72b8ba81ded Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 25 May 2024 12:24:44 -0700 Subject: [PATCH 41/43] move API reference to under developer doc --- docs/source/{ => dev}/offline_inference/llm.rst | 0 .../{ => dev}/offline_inference/llm_inputs.rst | 0 docs/source/dev/offline_inference/offline_index.rst | 8 ++++++++ .../{offline_inference => dev}/sampling_params.rst | 0 docs/source/index.rst | 12 +++--------- 5 files changed, 11 insertions(+), 9 deletions(-) rename docs/source/{ => dev}/offline_inference/llm.rst (100%) rename docs/source/{ => dev}/offline_inference/llm_inputs.rst (100%) create mode 100644 docs/source/dev/offline_inference/offline_index.rst rename docs/source/{offline_inference => dev}/sampling_params.rst (100%) diff --git a/docs/source/offline_inference/llm.rst b/docs/source/dev/offline_inference/llm.rst similarity index 100% rename from docs/source/offline_inference/llm.rst rename to docs/source/dev/offline_inference/llm.rst diff --git a/docs/source/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst similarity index 100% rename from docs/source/offline_inference/llm_inputs.rst rename to docs/source/dev/offline_inference/llm_inputs.rst diff --git 
a/docs/source/dev/offline_inference/offline_index.rst b/docs/source/dev/offline_inference/offline_index.rst new file mode 100644 index 0000000000000..27dfb0e9df90e --- /dev/null +++ b/docs/source/dev/offline_inference/offline_index.rst @@ -0,0 +1,8 @@ +Offline Inference +================================= + +.. toctree:: + :maxdepth: 1 + + llm + llm_inputs diff --git a/docs/source/offline_inference/sampling_params.rst b/docs/source/dev/sampling_params.rst similarity index 100% rename from docs/source/offline_inference/sampling_params.rst rename to docs/source/dev/sampling_params.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 6383680f2b512..acf02c1c22251 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -68,14 +68,6 @@ Documentation getting_started/quickstart getting_started/examples/examples_index -.. toctree:: - :maxdepth: 1 - :caption: Offline Inference - - offline_inference/llm - offline_inference/llm_inputs - offline_inference/sampling_params - .. toctree:: :maxdepth: 1 :caption: Serving @@ -109,7 +101,9 @@ Documentation .. toctree:: :maxdepth: 2 :caption: Developer Documentation - + + dev/sampling_params + dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention dev/dockerfile/dockerfile From 7b8ce2c271c2f5d2d726fc084dbe8942be0d5199 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 01:32:29 +0000 Subject: [PATCH 42/43] Fix links in docs --- docs/source/serving/openai_compatible_server.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a775c6addf1d9..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -48,7 +48,7 @@ completion = client.chat.completions.create( ``` ### Extra Parameters for Chat API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -65,7 +65,7 @@ The following extra parameters are supported: ``` ### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. 
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python From fff21a1ec0e7520f5b4f46dcb49c36e305e80894 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 01:34:08 +0000 Subject: [PATCH 43/43] Remove unnecessary code to avoid repeated warning --- vllm/engine/llm_engine.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cbc9bf741476a..0dd42a1867c46 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -213,8 +213,6 @@ def __init__( self.tokenizer = None self.detokenizer = None - self._eos_warn_count = 0 - self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) @@ -419,11 +417,8 @@ def _verify_args(self) -> None: def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: if self.tokenizer is None: - if self._eos_warn_count == 0: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") - - self._eos_warn_count += 1 + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") return None return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
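
Taken together, the series funnels the legacy `prompts`, `prompt_token_ids`, and `multi_modal_data` keywords into a single inputs argument. A minimal before/after sketch under the same assumptions as above (placeholder model; greedy sampling, so both calls should yield the same completion):

```python
from vllm import LLM, SamplingParams

# Placeholder model; any small HF model works for this sketch.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, max_tokens=16)

# Legacy keyword form, now wrapped by `deprecate_kwargs` with a pointer
# to the 'inputs' parameter.
legacy = llm.generate(prompts=["Hello, my name is"], sampling_params=params)

# New-style form: one prompt dict (or plain string) per request.
new = llm.generate([{"prompt": "Hello, my name is"}], sampling_params=params)

print(legacy[0].outputs[0].text)
print(new[0].outputs[0].text)
```
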