Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] merge beam search implementations #9296

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 6 additions & 104 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,31 @@
from weakref import ReferenceType

import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TokensPrompt
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
RequestOutput)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
random_uuid, weak_bind)
from vllm.utils import deprecate_kwargs, weak_bind

logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
Expand Down Expand Up @@ -541,7 +539,7 @@ async def build_guided_decoding_logits_processor_async(
return sampling_params


class AsyncLLMEngine:
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for :class:`LLMEngine`.

This class is used to wrap the :class:`LLMEngine` class to make it
Expand Down Expand Up @@ -1039,102 +1037,6 @@ async def generate(
):
yield LLMEngine.validate_output(output, RequestOutput)

async def beam_search(
self,
prompt: Union[PromptType, List[int]],
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:

beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty

tokenizer = await self.get_tokenizer()
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
completed = []

for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]

tasks = []

request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
tasks.append(task)

output = await asyncio.gather(*tasks)

output = [x[0] for x in output]

logger.info(output)

new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]

if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(new_beam)
else:
new_beams.append(new_beam)

sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]

completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]

for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])

beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
index=i,
logprobs=beam.cum_logprob,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_logprobs=None)

yield LLMEngine.validate_output(beam_search_output, RequestOutput)

async def encode(
self,
prompt: PromptType,
Expand Down
122 changes: 9 additions & 113 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, overload)
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union, overload)

import cloudpickle
import zmq
Expand All @@ -12,7 +12,6 @@
from zmq.asyncio import Socket

from vllm import PoolingParams
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
Expand All @@ -26,18 +25,17 @@
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType, TokensPrompt
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
RequestOutput)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
random_uuid)
from vllm.utils import deprecate_kwargs

logger = init_logger(__name__)

Expand All @@ -53,7 +51,7 @@ class MQClientClosedError(Exception):
"""


class MQLLMEngineClient:
class MQLLMEngineClient(EngineClient):
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.

Expand Down Expand Up @@ -316,7 +314,7 @@ async def _check_success(error_message: str, socket: Socket):
or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message)

async def get_tokenizer(self, lora_request: LoRARequest):
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

async def get_decoding_config(self) -> DecodingConfig:
Expand Down Expand Up @@ -344,10 +342,6 @@ async def abort(self, request_id: str):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), socket=self.input_socket)

async def do_log_stats(self):
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
pass

async def check_health(self):
"""
The check health loop probes the health status of the
Expand Down Expand Up @@ -444,104 +438,6 @@ def generate(
lora_request, trace_headers,
prompt_adapter_request, priority)

async def beam_search(
self,
prompt: Union[PromptType, List[int]],
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:

beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty

tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt)

sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)

beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
completed = []

for _ in range(max_tokens):
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]

tasks = []

request_id = f"beam_search-{random_uuid()}"
for i, individual_prompt in enumerate(prompts_batch):
request_id_item = f"{request_id}-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.generate(individual_prompt, beam_search_params,
request_id_item)))
tasks.append(task)

output = await asyncio.gather(*tasks)

output = [x[0] for x in output]

logger.info(output)

new_beams = []
for i, current_beam in enumerate(all_beams):
result = output[i]

if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
completed.append(new_beam)
else:
new_beams.append(new_beam)

sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]

completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]

for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])

beam_search_output = RequestOutput(
request_id=request_id,
prompt=prompt,
outputs=[
CompletionOutput(
text=beam.text,
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
index=i,
logprobs=beam.cum_logprob,
) for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=tokenizedPrompt,
prompt_logprobs=None)

logger.info(beam_search_output)

yield beam_search_output

@overload # DEPRECATED
def encode(
self,
Expand Down
Loading