diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index ae44b26a6c55a..1ca9d25411af7 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -14,6 +14,7 @@
 from typing import AsyncIterator, Set
 
 import uvloop
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager
 from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
@@ -507,6 +508,7 @@ def init_app_state(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+        fim_encoder=args.fim,
     )
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
@@ -531,11 +533,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
         ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 
-    valide_tool_parses = ToolParserManager.tool_parsers.keys()
-    if args.enable_auto_tool_choice \
-        and args.tool_call_parser not in valide_tool_parses:
-        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
-                       f"(chose from {{ {','.join(valide_tool_parses)} }})")
+    if args.enable_auto_tool_choice:
+        valid_tool_parsers = ToolParserManager.tool_parsers
+        if args.tool_call_parser not in valid_tool_parsers:
+            raise KeyError(
+                f"invalid tool call parser: {args.tool_call_parser} "
+                f"(choose from {{ {','.join(valid_tool_parsers.keys())} }})")
+
+    if args.fim is not None:
+        valid_fim_encoders = FIMEncoderManager.fim_encoders
+        if args.fim not in valid_fim_encoders:
+            raise KeyError(
+                f"invalid FIM encoder: {args.fim} "
+                f"(choose from {{ {','.join(valid_fim_encoders.keys())} }})")
 
     # workaround to make sure that we bind the port before the engine is set up.
     # This avoids race conditions with ray.
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index a089985ac9758..31bda24262415 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -11,6 +11,7 @@
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import validate_chat_template
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -213,6 +214,16 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         " into OpenAI API format, the name register in this plugin can be used "
         "in --tool-call-parser.")
 
+    valid_fim_encoders = FIMEncoderManager.fim_encoders.keys()
+    parser.add_argument(
+        "--fim",
+        type=str,
+        metavar="{" + ",".join(valid_fim_encoders) + "}",
+        default=None,
+        help="Select the fill-in-the-middle (FIM) encoder depending on the"
+        " model that you're using. Required to use the suffix parameter of the"
+        " OpenAI Completions API.")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
 
     parser.add_argument('--max-log-len',
diff --git a/vllm/entrypoints/openai/fim/__init__.py b/vllm/entrypoints/openai/fim/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py
new file mode 100644
index 0000000000000..2593936f21ebd
--- /dev/null
+++ b/vllm/entrypoints/openai/fim/fim_encoder.py
@@ -0,0 +1,103 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Dict, List, Optional, Type, Union
+
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import is_list_of
+
+
+class FIMEncoder(ABC):
+    """
+    An encoder of fill-in-the-middle (FIM) prompts comprising prefix
+    and suffix strings.
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
+        """
+        Encode the provided prompt prefix and suffix
+        to a list of token ids.
+        """
+        pass
+
+    @classmethod
+    def for_tokenizer(cls: Type, tokenizer: AnyTokenizer) -> "FIMEncoder":
+        fim_encoder = getattr(tokenizer, "fim_encoder", None)
+        if fim_encoder is None:
+            fim_encoder = cls(tokenizer)
+            tokenizer.fim_encoder = fim_encoder  # type: ignore[union-attr]
+        return fim_encoder
+
+
+class FIMEncoderManager:
+    fim_encoders: Dict[str, Type] = {}
+
+    @classmethod
+    def get_fim_encoder_class(cls, name: Optional[str]) -> Optional[Type]:
+        """
+        Get the FIM encoder class registered under `name` by `register_module`.
+
+        Raise a KeyError if the name is not registered.
+        """
+        if name is None:
+            return None
+
+        if (encoder := cls.fim_encoders.get(name)) is None:
+            raise KeyError(f"fim encoder '{name}' not recognized")
+
+        return encoder
+
+    @classmethod
+    def _register_module(cls,
+                         module: Type,
+                         module_name: Optional[Union[str, List[str]]] = None,
+                         force: bool = True) -> None:
+        if not issubclass(module, FIMEncoder):
+            raise TypeError(
+                f'module must be subclass of FIMEncoder, but got {type(module)}'
+            )
+        if module_name is None:
+            module_name = module.__name__
+        if isinstance(module_name, str):
+            module_name = [module_name]
+        for name in module_name:
+            if not force and (exist_module :=
+                              cls.fim_encoders.get(name)) is not None:
+                raise KeyError(f'{name} is already registered '
+                               f'at {exist_module.__module__}')
+            cls.fim_encoders[name] = module
+
+    @classmethod
+    def register_module(
+            cls,
+            name: Optional[Union[str, List[str]]] = None,
+            force: bool = True,
+            module: Union[Type, None] = None) -> Union[type, Callable]:
+        """
+        Register module with the given name or name list. It can be used as a
+        decorator (with module as None) or a normal function (with module as
+        not None).
+        """
+        if not isinstance(force, bool):
+            raise TypeError(f'force must be a boolean, but got {type(force)}')
+
+        # raise the error ahead of time
+        if not (name is None or isinstance(name, str)
+                or is_list_of(name, str)):
+            raise TypeError(
+                'name must be None, an instance of str, or a sequence of str, '
+                f'but got {type(name)}')
+
+        # use it as a normal method: x.register_module(module=SomeClass)
+        if module is not None:
+            cls._register_module(module=module, module_name=name, force=force)
+            return module
+
+        # use it as a decorator: @x.register_module()
+        def _register(module):
+            cls._register_module(module=module, module_name=name, force=force)
+            return module
+
+        return _register
diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
new file mode 100644
index 0000000000000..a545338de683d
--- /dev/null
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -0,0 +1,21 @@
+from typing import List
+
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers import MistralTokenizer
+
+from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
+                                                     FIMEncoderManager)
+
+
+@FIMEncoderManager.register_module("mistral")
+class MistralFIMEncoder(FIMEncoder):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        if not isinstance(tokenizer, MistralTokenizer):
+            raise ValueError(
+                "tokenizer incompatible with 'mistral' FIM encoder")
+
+    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
+        return self.tokenizer.encode_with_suffix(prompt=prompt, suffix=suffix)
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 56e35950410a0..7fcf0ca3e3a66 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -55,6 +55,7 @@ def __init__(
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        fim_encoder: Optional[str] = None,
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
@@ -62,7 +63,8 @@ def __init__(
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
                          request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids)
+                         return_tokens_as_token_ids=return_tokens_as_token_ids,
+                         fim_encoder=fim_encoder)
 
     async def create_completion(
         self,
@@ -74,9 +76,6 @@ async def create_completion(
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
 
-        NOTE: Currently we do not support the following feature:
-            - suffix (the language models we currently support do not support
-              suffix)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -88,11 +87,6 @@ async def create_completion(
         if self.engine_client.errored:
             raise self.engine_client.dead_error
 
-        # Return error for unsupported features.
-        if request.suffix is not None:
-            return self.create_error_response(
-                "suffix is not currently supported")
-
         model_name = self.base_model_paths[0].name
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
@@ -116,6 +110,7 @@ async def create_completion(
                     request,
                     tokenizer,
                     request.prompt,
+                    suffix=request.suffix,
                     truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 ))
@@ -315,6 +310,14 @@ async def completion_stream_generator(
                         # Chunked prefill case, don't return empty chunks
                         continue
 
+                previous_text_lens[i] += len(output.text)
+                previous_num_tokens[i] += len(output.token_ids)
+                finish_reason = output.finish_reason
+                stop_reason = output.stop_reason
+
+                if finish_reason and request.echo and request.suffix:
+                    delta_text += request.suffix
+
                 if request.logprobs is not None:
                     assert out_logprobs is not None, (
                         "Did not output logprobs")
@@ -328,11 +331,6 @@ async def completion_stream_generator(
                 else:
                     logprobs = None
 
-                previous_text_lens[i] += len(output.text)
-                previous_num_tokens[i] += len(output.token_ids)
-                finish_reason = output.finish_reason
-                stop_reason = output.stop_reason
-
                 chunk = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
@@ -400,6 +398,8 @@ def request_output_to_completion_response(
         num_prompt_tokens = 0
         num_generated_tokens = 0
 
+        suffix = "" if request.suffix is None else request.suffix
+
        for final_res in final_res_batch:
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
@@ -409,7 +409,6 @@ def request_output_to_completion_response(
             token_ids: GenericSequence[int]
             out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                  Logprob]]]]
-
             for output in final_res.outputs:
                 assert request.max_tokens is not None
                 if request.echo:
@@ -417,7 +416,7 @@ def request_output_to_completion_response(
                     if request.max_tokens == 0:
                         token_ids = prompt_token_ids
                         out_logprobs = prompt_logprobs
-                        output_text = prompt_text
+                        output_text = prompt_text + suffix
                     else:
                         token_ids = [*prompt_token_ids, *output.token_ids]
 
@@ -431,7 +430,7 @@ def request_output_to_completion_response(
                             *output.logprobs,
                         ]
 
-                    output_text = prompt_text + output.text
+                    output_text = prompt_text + output.text + suffix
                 else:
                     token_ids = output.token_ids
                     out_logprobs = output.logprobs
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 6c46aae2838f6..eb498df1e52d5 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -135,9 +135,11 @@ async def create_embedding(
             pooling_params = request.to_pooling_params()
 
             prompts = list(
-                self._tokenize_prompt_input_or_inputs(request, tokenizer,
-                                                      request.input,
-                                                      truncate_prompt_tokens))
+                self._tokenize_prompt_input_or_inputs(
+                    request,
+                    tokenizer,
+                    request.input,
+                    truncate_prompt_tokens=truncate_prompt_tokens))
 
             for i, prompt_inputs in enumerate(prompts):
                 request_id_item = f"{request_id}-{i}"
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index e6d2ab93d3363..74ac887fa78e8 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -2,7 +2,8 @@
 import pathlib
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
+from typing import (Callable, Iterable, Iterator, List, Optional, Tuple,
+                    TypedDict, Union)
 
 from pydantic import Field
 from typing_extensions import Annotated
@@ -10,6 +11,8 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
+                                                     FIMEncoderManager)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -77,6 +80,7 @@ def __init__(
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        fim_encoder: Optional[str] = None,
     ):
         super().__init__()
 
@@ -117,6 +121,9 @@ def __init__(
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
+        self.fim_class: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \
+            FIMEncoderManager.get_fim_encoder_class(fim_encoder)
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
@@ -204,22 +211,33 @@ def _normalize_prompt_text_to_input(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt: str,
+        suffix: Optional[str],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
-        if truncate_prompt_tokens is None:
-            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+        if suffix:
+            if not (fim_class := self.fim_class):
+                raise ValueError("fim support must be enabled to use suffix")
+            if truncate_prompt_tokens is not None:
+                raise ValueError(
+                    "truncate_prompt_tokens is not supported with suffix")
+            fim_encoder = fim_class.for_tokenizer(  # type: ignore[attr-defined]
+                tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prompt=prompt,
+                                                       suffix=suffix)
         else:
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=truncate_prompt_tokens)
-
-        input_ids = encoded.input_ids
+            if truncate_prompt_tokens is None:
+                encoded = tokenizer(prompt,
+                                    add_special_tokens=add_special_tokens)
+            else:
+                encoded = tokenizer(prompt,
+                                    add_special_tokens=add_special_tokens,
+                                    truncation=True,
+                                    max_length=truncate_prompt_tokens)
 
-        input_text = prompt
+            input_ids = encoded.input_ids
 
-        return self._validate_input(request, input_ids, input_text)
+        return self._validate_input(request, input_ids, input_text=prompt)
 
     def _normalize_prompt_tokens_to_input(
         self,
@@ -307,6 +325,7 @@ def _tokenize_prompt_inputs(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, List[int]]],
+        suffix: Optional[str] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
@@ -320,10 +339,14 @@ def _tokenize_prompt_inputs(
                     request,
                     tokenizer,
                     prompt=text,
+                    suffix=suffix,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
+                if suffix:
+                    raise ValueError(
+                        "suffix is only supported with string prompt input")
                 yield self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
@@ -336,6 +359,7 @@ def _tokenize_prompt_input_or_inputs(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
+        suffix: Optional[str] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
@@ -356,10 +380,14 @@ def _tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     prompt=prompt_input["content"],
+                    suffix=suffix,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
+                if suffix:
+                    raise ValueError(
+                        "suffix is only supported with string prompt input")
                 yield self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 23ea657ffb0a9..fee78871b07a6 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -7,6 +7,7 @@
 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.instruct.request import FIMRequest
 # yapf: disable
 from mistral_common.tokens.tokenizers.mistral import (
     MistralTokenizer as PublicMistralTokenizer)
@@ -189,6 +190,10 @@ def encode(self, prompt: str) -> List[int]:
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)
 
+    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
+        fim = FIMRequest(prompt=prompt, suffix=suffix)
+        return self.mistral.encode_fim(fim)
+
     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
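For reviewers trying this change locally, the sketch below shows how a project-specific encoder could plug into the FIMEncoderManager registry introduced in vllm/entrypoints/openai/fim/fim_encoder.py. It is not part of the patch: the "sentinel" registry name and the <fim_*> marker strings are hypothetical placeholders, and a real encoder must emit the exact FIM special tokens of the model being served.

from typing import List

from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
                                                     FIMEncoderManager)
from vllm.transformers_utils.tokenizer import AnyTokenizer


# The registered name becomes selectable on the command line, e.g. `--fim sentinel`.
@FIMEncoderManager.register_module("sentinel")
class SentinelFIMEncoder(FIMEncoder):
    """Hypothetical encoder that frames prefix/suffix with sentinel strings."""

    # Placeholder markers; replace with the served model's FIM special tokens.
    PREFIX = "<fim_prefix>"
    SUFFIX = "<fim_suffix>"
    MIDDLE = "<fim_middle>"

    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
        # Prefix-suffix-middle layout: the model is asked to generate the
        # "middle" span that belongs between the prompt and the suffix.
        text = f"{self.PREFIX}{prompt}{self.SUFFIX}{suffix}{self.MIDDLE}"
        return self.tokenizer(text, add_special_tokens=False).input_ids

With the OpenAI-compatible server launched with `--fim mistral` (or a custom registered name), clients can then pass the standard `suffix` field of the Completions API, for example `client.completions.create(model=..., prompt="def add(a, b):", suffix="\n    return result\n")` using the OpenAI Python client pointed at the vLLM endpoint.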