From 168cab6bbfb733f97defc8c1aa13df90c5319f19 Mon Sep 17 00:00:00 2001 From: Brendan Wong <35351983+LunrEclipse@users.noreply.github.com> Date: Sat, 5 Oct 2024 23:39:03 -0700 Subject: [PATCH] [Frontend] API support for beam search (#9087) Co-authored-by: youkaichao --- benchmarks/benchmark_throughput.py | 12 +- tests/conftest.py | 5 +- tests/entrypoints/openai/test_completion.py | 43 +++---- vllm/engine/async_llm_engine.py | 107 +++++++++++++++++- vllm/entrypoints/llm.py | 20 ++-- vllm/entrypoints/logger.py | 5 +- vllm/entrypoints/openai/protocol.py | 36 +++++- vllm/entrypoints/openai/serving_chat.py | 43 +++++-- vllm/entrypoints/openai/serving_completion.py | 46 ++++++-- vllm/entrypoints/openai/serving_engine.py | 5 +- vllm/sampling_params.py | 12 ++ vllm/utils.py | 9 ++ 12 files changed, 275 insertions(+), 68 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 68b401d5bbbb7..c6bc607ff6b8e 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -15,6 +15,7 @@ from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.sampling_params import BeamSearchParams from vllm.utils import FlexibleArgumentParser, merge_async_iterators @@ -145,10 +146,13 @@ def run_vllm( for prompt, input_len, _output_len in requests: assert _output_len == output_len start = time.perf_counter() - llm.beam_search(prompts, - beam_width=n, - max_tokens=output_len, - ignore_eos=True) + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) end = time.perf_counter() return end - start diff --git a/tests/conftest.py b/tests/conftest.py index 177b8a0640278..5de3f1f2a2b90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,6 +35,7 @@ to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, identity, is_cpu) @@ -812,7 +813,9 @@ def generate_beam_search_new( beam_width: int, max_tokens: int, ) -> List[Tuple[List[List[int]], List[str]]]: - outputs = self.model.beam_search(prompts, beam_width, max_tokens) + outputs = self.model.beam_search( + prompts, + BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index d77cd57f12471..61da5513cb130 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -495,25 +495,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): assert len(batch.choices) == 2 assert batch.choices[0].text == batch.choices[1].text - # test n = 2 - batch = await client.completions.create( - model=model_name, - prompt=prompts, - n=2, - max_tokens=5, - temperature=0.0, - extra_body=dict( - # NOTE: this has to be true for n > 1 in vLLM, but not necessary - # for official client. 
- use_beam_search=True), - ) - assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" + try: + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but + # not necessary for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + except BadRequestError as e: + # the only allowed exception is when beam search is not supported + # in the default mqllmengine + assert "--disable-frontend-multiprocessing" in str(e) # test streaming batch = await client.completions.create( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e7d770c976319..a0aaa9e6c372a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -14,23 +14,26 @@ from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase +from vllm.entrypoints.llm import BeamSearchSequence from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType +from vllm.inputs import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, + RequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import deprecate_kwargs, weak_bind +from vllm.utils import (collect_from_async_generator, deprecate_kwargs, + random_uuid, weak_bind) logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -1036,6 +1039,102 @@ async def generate( ): yield LLMEngine.validate_output(output, RequestOutput) + async def beam_search( + self, + prompt: Union[PromptType, List[int]], + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + + tokenizer = await self.get_tokenizer() + tokenizedPrompt = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + tokenizedLength = len(tokenizedPrompt) + + beam_search_params 
= SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, individual_prompt in enumerate(prompts_batch): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, beam_search_params, + request_id_item))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + logger.info(output) + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append(new_beam) + else: + new_beams.append(new_beam) + + sorted_beams = sorted(new_beams, + key=lambda x: x.cum_logprob, + reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, + key=lambda x: x.cum_logprob, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt, + outputs=[ + CompletionOutput( + text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens, + index=i, + logprobs=beam.cum_logprob, + ) for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=tokenizedPrompt, + prompt_logprobs=None) + + yield LLMEngine.validate_output(beam_search_output, RequestOutput) + async def encode( self, prompt: PromptType, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 98d6df944da67..f50ed7288f131 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -22,8 +22,8 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind, - SamplingParams) +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -394,10 +394,7 @@ def generate( def beam_search( self, prompts: List[Union[str, List[int]]], - beam_width: int, - max_tokens: int, - ignore_eos: bool = False, - temperature: float = 0.0, + params: BeamSearchParams, ) -> List[BeamSearchOutput]: """ Generate sequences using beam search. @@ -405,14 +402,17 @@ def beam_search( Args: prompts: A list of prompts. Each prompt can be a string or a list of token IDs. - beam_width: The number of beams to keep at each step. - max_tokens: The max number of tokens to generate for each prompt. - temperature: The temperature to use for generation. - + params: The beam search parameters. + TODO: how does beam search work together with length penalty, frequency penalty, and stopping criteria, etc.? 
""" + beam_width = params.beam_width + max_tokens = params.max_tokens + temperature = params.temperature + ignore_eos = params.ignore_eos + tokenizer = self.get_tokenizer() # generate 2 * beam_width candidates at each step # following the huggingface transformers implementation diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 091896e1c7a69..584ee0d9e1c54 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -4,7 +4,7 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import BeamSearchParams, SamplingParams logger = init_logger(__name__) @@ -21,7 +21,8 @@ def log_inputs( request_id: str, prompt: Optional[str], prompt_token_ids: Optional[List[int]], - params: Optional[Union[SamplingParams, PoolingParams]], + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7c5bd5b091b65..f0aaf3733869d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -11,8 +11,8 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind, - SamplingParams) +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) from vllm.sequence import Logprob from vllm.utils import random_uuid @@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params + def to_beam_search_params(self, + default_max_tokens: int) -> BeamSearchParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + n = self.n if self.n is not None else 1 + temperature = self.temperature if self.temperature is not None else 0.0 + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + ) + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: @@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params + def to_beam_search_params(self, + default_max_tokens: int) -> BeamSearchParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + n = self.n if self.n is not None else 1 + temperature = self.temperature if self.temperature is not None else 0.0 + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + ) + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ce529f6f0ff58..fc6611a754ae5 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,6 +9,7 @@ from fastapi import Request from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, @@ -33,6 +34,7 
@@ from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) @@ -203,9 +205,15 @@ async def create_chat_completion( assert prompt_inputs is not None - sampling_params = request.to_sampling_params( - default_max_tokens=self.max_model_len - - len(prompt_inputs["prompt_token_ids"])) + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + prompt_inputs["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens) + else: + sampling_params = request.to_sampling_params( + default_max_tokens) self._log_inputs(request_id, prompt_inputs, @@ -227,15 +235,26 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.engine_client.generate( - engine_inputs, - sampling_params, - request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=request.priority, - ) + if isinstance(sampling_params, BeamSearchParams): + if not isinstance(self.engine_client, AsyncLLMEngine): + raise ValueError( + "Beam search in the API server is only supported with" + " AsyncLLMEngine. please add " + "`--disable-frontend-multiprocessing` to " + "use beam search.") + result_generator = self.engine_client.beam_search( + engine_inputs['prompt_token_ids'], request_id, + sampling_params) + else: + result_generator = self.engine_client.generate( + engine_inputs, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=request.priority, + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 59e69121deb9e..bf9e9850797a6 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,6 +8,7 @@ from fastapi import Request from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block @@ -28,6 +29,7 @@ PromptAdapterPath) from vllm.logger import init_logger from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) @@ -120,9 +122,15 @@ async def create_completion( )) for i, prompt_inputs in enumerate(prompts): - sampling_params = request.to_sampling_params( - default_max_tokens=self.max_model_len - - len(prompt_inputs["prompt_token_ids"])) + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + prompt_inputs["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens) + else: + sampling_params = request.to_sampling_params( + default_max_tokens) request_id_item = f"{request_id}-{i}" @@ -141,15 +149,29 @@ async def 
create_completion( raw_request.headers): log_tracing_disabled_warning() - generator = self.engine_client.generate( - {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, - sampling_params, - request_id_item, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - priority=request.priority, - ) + if isinstance(sampling_params, BeamSearchParams): + if not isinstance(self.engine_client, AsyncLLMEngine): + raise ValueError( + "Beam search in the API server is only supported" + " with AsyncLLMEngine. please add " + "`--disable-frontend-multiprocessing` to " + "use beam search.") + generator = self.engine_client.beam_search( + prompt_inputs["prompt_token_ids"], request_id_item, + sampling_params) + else: + generator = self.engine_client.generate( + { + "prompt_token_ids": + prompt_inputs["prompt_token_ids"] + }, + sampling_params, + request_id_item, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=request.priority, + ) generators.append(generator) except ValueError as e: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1a0669d8d12c5..e6d2ab93d3363 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -29,7 +29,7 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import AtomicCounter @@ -371,7 +371,8 @@ def _log_inputs( self, request_id: str, inputs: Union[str, List[int], TextTokensPrompt], - params: Optional[Union[SamplingParams, PoolingParams]], + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 83f76410882de..adf0d2dd6ca2f 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -530,3 +530,15 @@ def __repr__(self) -> str: f"{self.spaces_between_special_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}), " f"guided_decoding={self.guided_decoding}") + + +class BeamSearchParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): # type: ignore[call-arg] + """Beam search parameters for text generation.""" + beam_width: int + max_tokens: int + ignore_eos: bool = False + temperature: float = 0.0 diff --git a/vllm/utils.py b/vllm/utils.py index 197584867d8b0..e44365fa24990 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -504,6 +504,15 @@ async def merge_async_iterators( await it.aclose() +async def collect_from_async_generator( + iterator: AsyncGenerator[T, None]) -> List[T]: + """Collect all items from an async generator into a list.""" + items = [] + async for item in iterator: + items.append(item) + return items + + def get_ip() -> str: host_ip = envs.VLLM_HOST_IP if host_ip:
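
For reference, a minimal usage sketch of the interface introduced by this patch, put together from the benchmark, test, and server changes above; the model name, server address, and prompts below are placeholders rather than part of the change:

import openai

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

# --- Offline path -----------------------------------------------------------
# LLM.beam_search now takes a BeamSearchParams struct instead of loose
# keyword arguments (see the vllm/entrypoints/llm.py hunk above).
llm = LLM(model="facebook/opt-125m")  # placeholder model name
outputs = llm.beam_search(
    ["The capital of France is"],
    BeamSearchParams(beam_width=4, max_tokens=16),
)
for output in outputs:
    # Access pattern mirrors tests/conftest.py: each output carries the kept
    # beams as sequences with their token ids.
    print([seq.tokens for seq in output.sequences])

# --- Online path ------------------------------------------------------------
# The OpenAI-compatible server accepts use_beam_search via extra_body, as in
# tests/entrypoints/openai/test_completion.py. Per the error message added in
# serving_chat.py / serving_completion.py, the server must run the in-process
# AsyncLLMEngine, i.e. be launched with --disable-frontend-multiprocessing.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="facebook/opt-125m",  # placeholder; use the served model name
    prompt="The capital of France is",
    n=2,
    max_tokens=16,
    temperature=0.0,
    extra_body=dict(use_beam_search=True),
)
print([choice.text for choice in completion.choices])

Both paths exercise the same BeamSearchParams fields (beam_width, max_tokens, ignore_eos, temperature) that this patch defines in vllm/sampling_params.py.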