[Frontend] Support suffix in completions API (fill-in-the-middle, FIM) #9522

Draft · wants to merge 10 commits into base: main
16 changes: 16 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -39,6 +39,7 @@
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.entrypoints.openai.fim import get_supported_fim_encoders
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -828,6 +829,7 @@ async def init_app_state(
        state.openai_serving_models,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        fim_encoder=args.fim,
    ) if model_config.runner_type == "generate" else None
    state.openai_serving_pooling = OpenAIServingPooling(
        engine_client,
@@ -894,6 +896,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:
    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    if args.enable_auto_tool_choice:
        valid_tool_parsers = ToolParserManager.tool_parsers
        if args.tool_call_parser not in valid_tool_parsers:
            raise KeyError(
                f"invalid tool call parser: {args.tool_call_parser} "
                f"(choose from {{ {','.join(valid_tool_parsers.keys())} }})")

    if args.fim is not None:
        valid_fim_encoders = get_supported_fim_encoders()
        if args.fim not in valid_fim_encoders:
            raise KeyError(
                f"invalid FIM encoder: {args.fim} "
                f"(choose from {{ {','.join(valid_fim_encoders)} }})")

    valid_tool_parses = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
            and args.tool_call_parser not in valid_tool_parses:
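With this check in place, an unrecognized encoder name fails fast at startup rather than on the first completion request. For illustration (model and flag value hypothetical), a matching launch would look like `vllm serve codellama/CodeLlama-7b-hf --fim codellama`.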
11 changes: 11 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -13,6 +13,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         validate_chat_template)
from vllm.entrypoints.openai.fim import get_supported_fim_encoders
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
                                                    PromptAdapterPath)
@@ -249,6 +250,16 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        " into OpenAI API format, the name registered in this plugin can be "
        "used in ``--tool-call-parser``.")

    valid_fim_encoders = get_supported_fim_encoders()
    parser.add_argument(
        "--fim",
        type=str,
        metavar="{" + ",".join(valid_fim_encoders) + "}",
        default=None,
        help="Select the fill-in-the-middle (FIM) encoder depending on the"
        " model that you're using. Required to use the suffix parameter of"
        " the OpenAI Completions API.")

    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
70 changes: 70 additions & 0 deletions vllm/entrypoints/openai/fim/__init__.py
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from inspect import isclass
from typing import Callable, Dict, Iterable, Optional, Tuple, Type, Union

from vllm.entrypoints.openai.fim.codellama_fim import CodeLlamaFIMEncoder
from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
                                                     StringTemplateFIMEncoder)
from vllm.entrypoints.openai.fim.mistral_fim import MistralFIMEncoder
from vllm.transformers_utils.tokenizer import AnyTokenizer

__all__ = [
    "FIMEncoder", "get_supported_fim_encoders", "get_fim_encoder_lookup"
]

# Entries are either an FIMEncoder implementation class or a
# tuple of (template, special_tokens_list).
_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = {
    "mistral": MistralFIMEncoder,
    "codellama": CodeLlamaFIMEncoder,
    "deepseek": (
        "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>",
        ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"),
    ),
    "starcoder": (
        "<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>",
        ("<fim_prefix>", "<fim_suffix>", "<fim_middle>"),
    ),
}


def get_supported_fim_encoders() -> Iterable[str]:
    """Return the set of supported FIM encoder types."""
    return _FIM_ENCODERS.keys()


def get_fim_encoder_lookup(
        name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]:
    """
    Get a function that returns a FIMEncoder instance for a given tokenizer.
    Raises a ValueError if the name is not recognized.
    """
    if name is None:
        return None

    if (encoder := _FIM_ENCODERS.get(name)) is None:
        raise ValueError(f"fim encoder '{name}' not recognized")

    factory: Callable[[AnyTokenizer], FIMEncoder]
    if isclass(encoder):
        assert issubclass(encoder, FIMEncoder)
        factory = encoder
    else:
        assert isinstance(encoder, tuple)
        template, special_tokens = encoder
        factory = partial(StringTemplateFIMEncoder,
                          name=name,
                          template=template,
                          special_tokens=special_tokens)

    def for_tokenizer(tokenizer: AnyTokenizer) -> FIMEncoder:
        # Build the encoder lazily and cache it on the tokenizer so the
        # construction cost is paid only once per tokenizer instance.
        fim_encoder = getattr(tokenizer, "fim_encoder", None)
        if fim_encoder is None:
            fim_encoder = factory(tokenizer)
            tokenizer.fim_encoder = fim_encoder  # type: ignore[union-attr]
        return fim_encoder

    return for_tokenizer
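A minimal sketch of how this lookup is meant to be consumed, mirroring the flow this PR wires into the completions path; the StarCoder-style Hugging Face tokenizer is an assumption for illustration:

from transformers import AutoTokenizer

from vllm.entrypoints.openai.fim import get_fim_encoder_lookup

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase")  # assumed model
get_encoder = get_fim_encoder_lookup("starcoder")  # resolved once from --fim
encoder = get_encoder(tokenizer)  # built lazily, cached on tokenizer.fim_encoder
token_ids = encoder.encode_with_suffix(prefix="def add(a, b):\n    ",
                                       suffix="\n")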
42 changes: 42 additions & 0 deletions vllm/entrypoints/openai/fim/codellama_fim.py
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List

from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
from vllm.transformers_utils.tokenizer import AnyTokenizer


class CodeLlamaFIMEncoder(FIMEncoder):
    """
    FIM Encoder for Meta CodeLlama models

    Adapted from https://github.com/meta-llama/codellama/blob/e81b597e44dbecc2a0dedb9949fdf84adfc22395/llama/generation.py#L474
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        if not hasattr(tokenizer, "convert_tokens_to_ids"):
            raise ValueError(
                "tokenizer incompatible with 'codellama' FIM encoder")

        self.bos_id = tokenizer.convert_tokens_to_ids("<s>")
        self.prefix_id = tokenizer.convert_tokens_to_ids("▁<PRE>")
        self.suffix_id = tokenizer.convert_tokens_to_ids("▁<SUF>")
        self.middle_id = tokenizer.convert_tokens_to_ids("▁<MID>")

        unk_token_id = getattr(tokenizer, "unk_token_id", None)
        if any(tid in
               {self.bos_id, self.prefix_id, self.suffix_id, self.middle_id}
               for tid in (None, unk_token_id)):
            raise ValueError(
                "tokenizer incompatible with 'codellama' FIM encoder")

    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
        prefix_tokens = self.tokenizer(prefix,
                                       add_special_tokens=False).input_ids
        # Encode the suffix without an implicit leading space: prepend a
        # dummy character and drop its token ids from the result.
        suffix_tokens = self.tokenizer("☺" + suffix,
                                       add_special_tokens=False).input_ids[2:]

        return ([self.bos_id, self.prefix_id] + prefix_tokens +
                [self.suffix_id] + suffix_tokens + [self.middle_id])
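For reference, the id sequence built above follows CodeLlama's infilling layout <s> ▁<PRE> {prefix} ▁<SUF> {suffix} ▁<MID>, after which the model generates the "middle" tokens. A tiny self-contained check of the concatenation order (all ids below are placeholders, not real vocabulary ids):

# Illustrative only: placeholder ids standing in for a real CodeLlama vocab.
bos_id, prefix_id, suffix_id, middle_id = 1, 32007, 32008, 32009
prefix_tokens = [822, 788]  # hypothetical ids for the prompt prefix
suffix_tokens = [736]       # hypothetical ids for the prompt suffix

ids = ([bos_id, prefix_id] + prefix_tokens + [suffix_id] + suffix_tokens +
       [middle_id])
assert ids == [1, 32007, 822, 788, 32008, 736, 32009]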
53 changes: 53 additions & 0 deletions vllm/entrypoints/openai/fim/fim_encoder.py
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional

from vllm.transformers_utils.tokenizer import AnyTokenizer


class FIMEncoder(ABC):
    """
    An encoder of fill-in-the-middle (FIM) prompts comprising prefix
    and suffix strings.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        self.tokenizer = tokenizer

    @abstractmethod
    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
        """
        Encode the provided prompt prefix and suffix
        to a list of token ids.
        """
        pass


class StringTemplateFIMEncoder(FIMEncoder):
    """FIMEncoder implementation using a simple string template
    with prefix and suffix variables."""

    def __init__(
        self,
        tokenizer: AnyTokenizer,
        name: str,
        template: str,
        special_tokens: Optional[Iterable[str]] = None,
    ):
        super().__init__(tokenizer)

        if not hasattr(tokenizer, "convert_tokens_to_ids"):
            raise ValueError(
                f"tokenizer incompatible with '{name}' FIM encoder")

        unk_token_id = getattr(tokenizer, "unk_token_id", None)
        for special_token in special_tokens or ():
            token_id = tokenizer.convert_tokens_to_ids(special_token)
            if token_id is None or token_id == unk_token_id:
                raise ValueError(
                    f"tokenizer incompatible with '{name}' FIM encoder")
        self.template = template

    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
        prompt = self.template.format(prefix=prefix, suffix=suffix)
        return self.tokenizer(prompt, add_special_tokens=False).input_ids
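Constructing a template encoder directly shows how the pieces fit together; this mirrors the "deepseek" entry registered in fim/__init__.py (the model name is an assumption, and any tokenizer whose vocabulary contains the three special tokens would do):

from transformers import AutoTokenizer

from vllm.entrypoints.openai.fim.fim_encoder import StringTemplateFIMEncoder

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-6.7b-base")
encoder = StringTemplateFIMEncoder(
    tokenizer=tokenizer,
    name="deepseek",
    template="<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>",
    special_tokens=("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"),
)
ids = encoder.encode_with_suffix("def add(a, b):\n    ", "\n")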
24 changes: 24 additions & 0 deletions vllm/entrypoints/openai/fim/mistral_fim.py
@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List

from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV2

from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer


class MistralFIMEncoder(FIMEncoder):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # InstructTokenizerV3 is a subclass of InstructTokenizerV2
        if not isinstance(tokenizer, MistralTokenizer) \
                or not isinstance(tokenizer.instruct, InstructTokenizerV2):
            raise ValueError(
                "tokenizer incompatible with 'mistral' FIM encoder")

    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
        assert isinstance(self.tokenizer, MistralTokenizer)
        return self.tokenizer.encode_with_suffix(prefix=prefix, suffix=suffix)
14 changes: 11 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py
@@ -45,12 +45,14 @@ def __init__(
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        fim_encoder: Optional[str] = None,
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)
                         return_tokens_as_token_ids=return_tokens_as_token_ids,
                         fim_encoder=fim_encoder)
        diff_sampling_param = self.model_config.get_diff_sampling_param()
        if diff_sampling_param:
            logger.info(
@@ -105,6 +107,7 @@ async def create_completion(
                request,
                tokenizer,
                request.prompt,
                request.suffix,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
@@ -326,6 +329,9 @@ async def completion_stream_generator(
                    finish_reason = output.finish_reason
                    stop_reason = output.stop_reason

                    if finish_reason and request.echo and request.suffix:
                        delta_text += request.suffix

                    chunk = CompletionStreamResponse(
                        id=request_id,
                        created=created_time,
@@ -393,6 +399,8 @@ def request_output_to_completion_response(
        num_prompt_tokens = 0
        num_generated_tokens = 0

        suffix = "" if request.suffix is None else request.suffix

        for final_res in final_res_batch:
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
@@ -416,7 +424,7 @@
                    if request.max_tokens == 0:
                        token_ids = prompt_token_ids
                        out_logprobs = prompt_logprobs
                        output_text = prompt_text
                        output_text = prompt_text + suffix
                    else:
                        token_ids = [*prompt_token_ids, *output.token_ids]

@@ -430,7 +438,7 @@
                        out_logprobs = [
                            *prompt_logprobs,
                            *output.logprobs,
                        ]

                        output_text = prompt_text + output.text
                        output_text = prompt_text + output.text + suffix
                else:
                    token_ids = output.token_ids
                    out_logprobs = output.logprobs
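End to end, the new parameter is exercised through the standard OpenAI client. A sketch assuming a server started with `--fim starcoder` on localhost:8000; the model name is a placeholder for whichever FIM-capable model is being served:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="bigcode/starcoderbase",   # placeholder model name
    prompt="def fib(n):\n    ",      # text before the insertion point
    suffix="\n    return result\n",  # text after the insertion point
    max_tokens=64,
)
# The generated text fills the gap between prompt and suffix.
print(completion.choices[0].text)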