[Frontend] Support suffix in completions API (fill-in-the-middle)
Handle model-specific FIM encoding rules in a similar way to how we're handling different tool parsers.
njhill committed Oct 21, 2024
1 parent 711f3a7 commit f57746a
Showing 9 changed files with 215 additions and 36 deletions.
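Before the per-file diffs, a minimal client-side sketch of what this change enables. It assumes a vLLM OpenAI-compatible server launched with a FIM-capable model and the new `--fim` flag; the model name and server address are placeholders, not part of the commit.

```python
# Sketch only: assumes the server was started roughly like
#   vllm serve mistralai/Codestral-22B-v0.1 --fim mistral
# The model name is illustrative; any model supported by the chosen FIM
# encoder should work.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="mistralai/Codestral-22B-v0.1",  # placeholder model name
    prompt="def fibonacci(n):\n    ",      # text before the cursor (prefix)
    suffix="\n    return result\n",        # text after the cursor
    max_tokens=64,
)
print(completion.choices[0].text)          # the generated middle section
```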
20 changes: 15 additions & 5 deletions vllm/entrypoints/openai/api_server.py
@@ -14,6 +14,7 @@
from typing import AsyncIterator, Set

import uvloop
from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
@@ -507,6 +508,7 @@ def init_app_state(
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
fim_encoder=args.fim,
)
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
@@ -531,11 +533,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

valide_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")
if args.enable_auto_tool_choice:
valid_tool_parsers = ToolParserManager.tool_parsers
if args.tool_call_parser not in valid_tool_parsers:
raise KeyError(
f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parsers.keys())} }})")

if args.fim is not None:
valid_fim_encoders = FIMEncoderManager.fim_encoders
if args.fim not in valid_fim_encoders:
raise KeyError(
f"invalid FIM encoder: {args.fim} "
f"(chose from {{ {','.join(valid_fim_encoders.keys())} }})")

# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
11 changes: 11 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -11,6 +11,7 @@

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -213,6 +214,16 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")

valid_fim_encoders = FIMEncoderManager.fim_encoders.keys()
parser.add_argument(
"--fim",
type=str,
metavar="{" + ",".join(valid_fim_encoders) + "}",
default=None,
help="Select the fill-in-the-middle (FIM) encoder depending on the"
" model that you're using. Required to use the suffix parameter of the"
" OpenAI Completions API.")

parser = AsyncEngineArgs.add_cli_args(parser)

parser.add_argument('--max-log-len',
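A hedged sketch of how the `--fim` option added above surfaces through the parser. Only the "mistral" encoder is visible in this diff; other names may be registered in files that are not expanded here.

```python
# Sketch: build the OpenAI-server argument parser and pass the new flag.
# Assumes vllm.utils.FlexibleArgumentParser, as used by the real entrypoint.
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "mistralai/Codestral-22B-v0.1",  # placeholder model
    "--fim", "mistral",                         # encoder added in this commit
])
print(args.fim)  # -> "mistral", later passed through to the completion server
```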
0 changes: 0 additions & 0 deletions vllm/entrypoints/openai/fim/__init__.py
Empty file.
103 changes: 103 additions & 0 deletions vllm/entrypoints/openai/fim/fim_encoder.py
@@ -0,0 +1,103 @@
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional, Type, Union

from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import is_list_of


class FIMEncoder(ABC):
"""
An encoder of fill-in-the-middle (FIM) prompts comprising prefix
and suffix strings.
"""

def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer = tokenizer

@abstractmethod
def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
"""
Encode the provided prompt prefix and suffix
to a list of token ids
"""
pass

@classmethod
def for_tokenizer(cls: Type, tokenizer: AnyTokenizer) -> "FIMEncoder":
fim_encoder = getattr(tokenizer, "fim_encoder", None)
if fim_encoder is None:
fim_encoder = cls(tokenizer)
tokenizer.fim_encoder = fim_encoder # type: ignore[union-attr]
return fim_encoder


class FIMEncoderManager:
fim_encoders: Dict[str, Type] = {}

@classmethod
def get_fim_encoder_class(cls, name: Optional[str]) -> Optional[Type]:
"""
        Get the FIM encoder class registered under `name` by `register_module`.
        Raises a KeyError if the name is not registered.
"""
if name is None:
return None

if (encoder := cls.fim_encoders.get(name)) is None:
raise KeyError(f"fim encoder '{name}' not recognized")

return encoder

@classmethod
def _register_module(cls,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = True) -> None:
if not issubclass(module, FIMEncoder):
raise TypeError(
                f'module must be a subclass of FIMEncoder, but got {module}'
)
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
            if not force and (exist_module :=
                              cls.fim_encoders.get(name)) is not None:
raise KeyError(f'{name} is already registered '
f'at {exist_module.__module__}')
cls.fim_encoders[name] = module

@classmethod
def register_module(
cls,
name: Optional[Union[str, List[str]]] = None,
force: bool = True,
module: Union[Type, None] = None) -> Union[type, Callable]:
"""
        Register module with the given name or name list. It can be used as a
        decorator (when `module` is None) or as a normal function (when
        `module` is not None).
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')

# raise the error ahead of time
if not (name is None or isinstance(name, str)
or is_list_of(name, str)):
raise TypeError(
'name must be None, an instance of str, or a sequence of str, '
f'but got {type(name)}')

# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module

# use it as a decorator: @x.register_module()
def _register(module):
cls._register_module(module=module, module_name=name, force=force)
return module

return _register
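To illustrate the registration API above, a hypothetical third-party encoder. The StarCoder-style sentinel tokens and the "starcoder" name are assumptions for this sketch, not something the commit adds.

```python
from typing import List

from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
                                                     FIMEncoderManager)


@FIMEncoderManager.register_module("starcoder")  # name then usable with --fim
class StarCoderFIMEncoder(FIMEncoder):
    """Hypothetical encoder using <fim_prefix>/<fim_suffix>/<fim_middle>."""

    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
        # Prefix-suffix-middle layout; assumes a Hugging Face tokenizer whose
        # vocabulary actually contains these sentinel tokens.
        text = f"<fim_prefix>{prompt}<fim_suffix>{suffix}<fim_middle>"
        return self.tokenizer.encode(text, add_special_tokens=False)
```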
21 changes: 21 additions & 0 deletions vllm/entrypoints/openai/fim/mistral_fim.py
@@ -0,0 +1,21 @@
from typing import List

from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer

from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
FIMEncoderManager)


@FIMEncoderManager.register_module("mistral")
class MistralFIMEncoder(FIMEncoder):

def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)

if not isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"tokenizer incompatible with 'mistral' FIM encoder")

def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
return self.tokenizer.encode_with_suffix(prompt=prompt, suffix=suffix)
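A rough guess at how these pieces fit together at request time; the actual call sites live in files not expanded in this view, so the local variable names below are assumptions.

```python
# Illustrative wiring of the classes introduced in this commit.
from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager

# `tokenizer` is assumed to already hold the model's tokenizer.
encoder_cls = FIMEncoderManager.get_fim_encoder_class("mistral")  # --fim value
if encoder_cls is not None:
    # for_tokenizer() caches the encoder instance on the tokenizer object,
    # so later requests reuse it.
    encoder = encoder_cls.for_tokenizer(tokenizer)
    token_ids = encoder.encode_with_suffix(
        prompt="def add(a, b):\n    ",
        suffix="\n    return result\n",
    )
```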
33 changes: 16 additions & 17 deletions vllm/entrypoints/openai/serving_completion.py
@@ -55,14 +55,16 @@ def __init__(
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
fim_encoder: Optional[str] = None,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
return_tokens_as_token_ids=return_tokens_as_token_ids,
fim_encoder=fim_encoder)

async def create_completion(
self,
@@ -74,9 +76,6 @@ async def create_completion(
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
@@ -88,11 +87,6 @@
if self.engine_client.errored:
raise self.engine_client.dead_error

# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")

model_name = self.base_model_paths[0].name
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
@@ -116,6 +110,7 @@
request,
tokenizer,
request.prompt,
suffix=request.suffix,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
))
@@ -315,6 +310,14 @@ async def completion_stream_generator(
# Chunked prefill case, don't return empty chunks
continue

previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason

if finish_reason and request.echo and request.suffix:
delta_text += request.suffix

if request.logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
@@ -328,11 +331,6 @@
else:
logprobs = None

previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason

chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
@@ -400,6 +398,8 @@ def request_output_to_completion_response(
num_prompt_tokens = 0
num_generated_tokens = 0

suffix = "" if request.suffix is None else request.suffix

for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
@@ -409,15 +409,14 @@
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]

for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo:
assert prompt_text is not None
if request.max_tokens == 0:
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
output_text = prompt_text + suffix
else:
token_ids = [*prompt_token_ids, *output.token_ids]

Expand All @@ -431,7 +430,7 @@ def request_output_to_completion_response(
*output.logprobs,
]

output_text = prompt_text + output.text
output_text = prompt_text + output.text + suffix
else:
token_ids = output.token_ids
out_logprobs = output.logprobs
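A tiny worked example (values invented) of the echo handling changed above: with `echo=True`, a non-zero `max_tokens`, and a `suffix` supplied, the returned text is the prompt prefix, the generated middle, and the suffix concatenated.

```python
# Made-up values illustrating `output_text = prompt_text + output.text + suffix`
prompt_text = "def add(a, b):\n    "  # request.prompt (the prefix)
generated = "result = a + b"          # output.text from the engine
suffix = "\n    return result\n"      # request.suffix

output_text = prompt_text + generated + suffix
assert output_text == "def add(a, b):\n    result = a + b\n    return result\n"
```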
8 changes: 5 additions & 3 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -135,9 +135,11 @@ async def create_embedding(
pooling_params = request.to_pooling_params()

prompts = list(
self._tokenize_prompt_input_or_inputs(request, tokenizer,
request.input,
truncate_prompt_tokens))
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens))

for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"