From f57746ae20c4cdebe172a8530650ab4432227525 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 16 Oct 2024 16:40:04 -0700 Subject: [PATCH 1/7] [Frontend] Support suffix in completions API (fill-in-the-middle) Handle model-specific FIM encoding rules in a similar way to how we're handling different tool parsers. --- vllm/entrypoints/openai/api_server.py | 20 +++- vllm/entrypoints/openai/cli_args.py | 11 ++ vllm/entrypoints/openai/fim/__init__.py | 0 vllm/entrypoints/openai/fim/fim_encoder.py | 103 ++++++++++++++++++ vllm/entrypoints/openai/fim/mistral_fim.py | 21 ++++ vllm/entrypoints/openai/serving_completion.py | 33 +++--- vllm/entrypoints/openai/serving_embedding.py | 8 +- vllm/entrypoints/openai/serving_engine.py | 50 +++++++-- vllm/transformers_utils/tokenizers/mistral.py | 5 + 9 files changed, 215 insertions(+), 36 deletions(-) create mode 100644 vllm/entrypoints/openai/fim/__init__.py create mode 100644 vllm/entrypoints/openai/fim/fim_encoder.py create mode 100644 vllm/entrypoints/openai/fim/mistral_fim.py diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ae44b26a6c55a..1ca9d25411af7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,6 +14,7 @@ from typing import AsyncIterator, Set import uvloop +from entrypoints.openai.fim.fim_encoder import FIMEncoderManager from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -507,6 +508,7 @@ def init_app_state( prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, + fim_encoder=args.fim, ) state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, @@ -531,11 +533,19 @@ async def run_server(args, **uvicorn_kwargs) -> None: if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) - valide_tool_parses = ToolParserManager.tool_parsers.keys() - if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valide_tool_parses: - raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " - f"(chose from {{ {','.join(valide_tool_parses)} }})") + if args.enable_auto_tool_choice: + valid_tool_parsers = ToolParserManager.tool_parsers + if args.tool_call_parser not in valid_tool_parsers: + raise KeyError( + f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parsers.keys())} }})") + + if args.fim is not None: + valid_fim_encoders = FIMEncoderManager.fim_encoders + if args.fim not in valid_fim_encoders: + raise KeyError( + f"invalid FIM encoder: {args.fim} " + f"(chose from {{ {','.join(valid_fim_encoders.keys())} }})") # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. 
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a089985ac9758..31bda24262415 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,6 +11,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import validate_chat_template +from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -213,6 +214,16 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: " into OpenAI API format, the name register in this plugin can be used " "in --tool-call-parser.") + valid_fim_encoders = FIMEncoderManager.fim_encoders.keys() + parser.add_argument( + "--fim", + type=str, + metavar="{" + ",".join(valid_fim_encoders) + "}", + default=None, + help="Select the fill-in-the-middle (FIM) encoder depending on the" + " model that you're using. Required to use the suffix parameter of the" + " OpenAI Completions API.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/fim/__init__.py b/vllm/entrypoints/openai/fim/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py new file mode 100644 index 0000000000000..2593936f21ebd --- /dev/null +++ b/vllm/entrypoints/openai/fim/fim_encoder.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod +from typing import Callable, Dict, List, Optional, Type, Union + +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import is_list_of + + +class FIMEncoder(ABC): + """ + An encoder of fill-in-the-middle (FIM) prompts comprising prefix + and suffix strings. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.tokenizer = tokenizer + + @abstractmethod + def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]: + """ + Encode the provided prompt prefix and suffix + to a list of token ids + """ + pass + + @classmethod + def for_tokenizer(cls: Type, tokenizer: AnyTokenizer) -> "FIMEncoder": + fim_encoder = getattr(tokenizer, "fim_encoder", None) + if fim_encoder is None: + fim_encoder = cls(tokenizer) + tokenizer.fim_encoder = fim_encoder # type: ignore[union-attr] + return fim_encoder + + +class FIMEncoderManager: + fim_encoders: Dict[str, Type] = {} + + @classmethod + def get_fim_encoder_class(cls, name: Optional[str]) -> Optional[Type]: + """ + Get FIM encoder by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. 
+ """ + if name is None: + return None + + if (encoder := cls.fim_encoders.get(name)) is None: + raise KeyError(f"fim encoder '{name}' not recognized") + + return encoder + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, FIMEncoder): + raise TypeError( + f'module must be subclass of FIMEncoder, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and (exist_module := cls.fim_encoders.get(name) + is not None): + raise KeyError(f'{name} is already registered ' + f'at {exist_module.__module__}') + cls.fim_encoders[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py new file mode 100644 index 0000000000000..a545338de683d --- /dev/null +++ b/vllm/entrypoints/openai/fim/mistral_fim.py @@ -0,0 +1,21 @@ +from typing import List + +from transformers_utils.tokenizer import AnyTokenizer +from transformers_utils.tokenizers import MistralTokenizer + +from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder, + FIMEncoderManager) + + +@FIMEncoderManager.register_module("mistral") +class MistralFIMEncoder(FIMEncoder): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "tokenizer incompatible with 'mistral' FIM encoder") + + def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]: + return self.tokenizer.encode_with_suffix(prompt=prompt, suffix=suffix) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 56e35950410a0..7fcf0ca3e3a66 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -55,6 +55,7 @@ def __init__( prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + fim_encoder: Optional[str] = None, ): super().__init__(engine_client=engine_client, model_config=model_config, @@ -62,7 +63,8 @@ def __init__( lora_modules=lora_modules, prompt_adapters=prompt_adapters, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) + return_tokens_as_token_ids=return_tokens_as_token_ids, + fim_encoder=fim_encoder) async def 
create_completion( self, @@ -74,9 +76,6 @@ async def create_completion( See https://platform.openai.com/docs/api-reference/completions/create for the API specification. This API mimics the OpenAI Completion API. - NOTE: Currently we do not support the following feature: - - suffix (the language models we currently support do not support - suffix) """ error_check_ret = await self._check_model(request) if error_check_ret is not None: @@ -88,11 +87,6 @@ async def create_completion( if self.engine_client.errored: raise self.engine_client.dead_error - # Return error for unsupported features. - if request.suffix is not None: - return self.create_error_response( - "suffix is not currently supported") - model_name = self.base_model_paths[0].name request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) @@ -116,6 +110,7 @@ async def create_completion( request, tokenizer, request.prompt, + suffix=request.suffix, truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, )) @@ -315,6 +310,14 @@ async def completion_stream_generator( # Chunked prefill case, don't return empty chunks continue + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) + finish_reason = output.finish_reason + stop_reason = output.stop_reason + + if finish_reason and request.echo and request.suffix: + delta_text += request.suffix + if request.logprobs is not None: assert out_logprobs is not None, ( "Did not output logprobs") @@ -328,11 +331,6 @@ async def completion_stream_generator( else: logprobs = None - previous_text_lens[i] += len(output.text) - previous_num_tokens[i] += len(output.token_ids) - finish_reason = output.finish_reason - stop_reason = output.stop_reason - chunk = CompletionStreamResponse( id=request_id, created=created_time, @@ -400,6 +398,8 @@ def request_output_to_completion_response( num_prompt_tokens = 0 num_generated_tokens = 0 + suffix = "" if request.suffix is None else request.suffix + for final_res in final_res_batch: prompt_token_ids = final_res.prompt_token_ids assert prompt_token_ids is not None @@ -409,7 +409,6 @@ def request_output_to_completion_response( token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[Dict[int, Logprob]]]] - for output in final_res.outputs: assert request.max_tokens is not None if request.echo: @@ -417,7 +416,7 @@ def request_output_to_completion_response( if request.max_tokens == 0: token_ids = prompt_token_ids out_logprobs = prompt_logprobs - output_text = prompt_text + output_text = prompt_text + suffix else: token_ids = [*prompt_token_ids, *output.token_ids] @@ -431,7 +430,7 @@ def request_output_to_completion_response( *output.logprobs, ] - output_text = prompt_text + output.text + output_text = prompt_text + output.text + suffix else: token_ids = output.token_ids out_logprobs = output.logprobs diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 6c46aae2838f6..eb498df1e52d5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -135,9 +135,11 @@ async def create_embedding( pooling_params = request.to_pooling_params() prompts = list( - self._tokenize_prompt_input_or_inputs(request, tokenizer, - request.input, - truncate_prompt_tokens)) + self._tokenize_prompt_input_or_inputs( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens)) for i, prompt_inputs in enumerate(prompts): request_id_item = 
f"{request_id}-{i}" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index e6d2ab93d3363..74ac887fa78e8 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,8 @@ import pathlib from dataclasses import dataclass from http import HTTPStatus -from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union +from typing import (Callable, Iterable, Iterator, List, Optional, Tuple, + TypedDict, Union) from pydantic import Field from typing_extensions import Annotated @@ -10,6 +11,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder, + FIMEncoderManager) # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -77,6 +80,7 @@ def __init__( prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + fim_encoder: Optional[str] = None, ): super().__init__() @@ -117,6 +121,9 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids + self.fim_class: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \ + FIMEncoderManager.get_fim_encoder_class(fim_encoder) + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -204,22 +211,33 @@ def _normalize_prompt_text_to_input( request: AnyRequest, tokenizer: AnyTokenizer, prompt: str, + suffix: Optional[str], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], add_special_tokens: bool, ) -> TextTokensPrompt: - if truncate_prompt_tokens is None: - encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + if suffix: + if not (fim_class := self.fim_class): + raise ValueError("fim support must be enabled to use suffix") + if truncate_prompt_tokens is not None: + raise ValueError( + "truncate_prompt_tokens is not supported with suffix") + fim_encoder = fim_class.for_tokenizer( # type: ignore[attr-defined] + tokenizer) + input_ids = fim_encoder.encode_with_suffix(prompt=prompt, + suffix=suffix) else: - encoded = tokenizer(prompt, - add_special_tokens=add_special_tokens, - truncation=True, - max_length=truncate_prompt_tokens) - - input_ids = encoded.input_ids + if truncate_prompt_tokens is None: + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens) + else: + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) - input_text = prompt + input_ids = encoded.input_ids - return self._validate_input(request, input_ids, input_text) + return self._validate_input(request, input_ids, input_text=prompt) def _normalize_prompt_tokens_to_input( self, @@ -307,6 +325,7 @@ def _tokenize_prompt_inputs( request: AnyRequest, tokenizer: AnyTokenizer, prompt_inputs: Iterable[Union[str, List[int]]], + suffix: Optional[str] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, ) -> Iterator[TextTokensPrompt]: @@ -320,10 +339,14 @@ def _tokenize_prompt_inputs( request, tokenizer, prompt=text, + suffix=suffix, truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ) else: + if suffix: + raise ValueError( + "suffix is only supported with string prompt 
input") yield self._normalize_prompt_tokens_to_input( request, tokenizer, @@ -336,6 +359,7 @@ def _tokenize_prompt_input_or_inputs( request: AnyRequest, tokenizer: AnyTokenizer, input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + suffix: Optional[str] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, ) -> Iterator[TextTokensPrompt]: @@ -356,10 +380,14 @@ def _tokenize_prompt_input_or_inputs( request, tokenizer, prompt=prompt_input["content"], + suffix=suffix, truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, ) else: + if suffix: + raise ValueError( + "suffix is only supported with string prompt input") yield self._normalize_prompt_tokens_to_input( request, tokenizer, diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 23ea657ffb0a9..fee78871b07a6 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -7,6 +7,7 @@ import huggingface_hub from huggingface_hub import HfApi, hf_hub_download from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.instruct.request import FIMRequest # yapf: disable from mistral_common.tokens.tokenizers.mistral import ( MistralTokenizer as PublicMistralTokenizer) @@ -189,6 +190,10 @@ def encode(self, prompt: str) -> List[int]: # For chat completion use `apply_chat_template` return self.tokenizer.encode(prompt, bos=True, eos=False) + def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]: + fim = FIMRequest(prompt=prompt, suffix=suffix) + return self.mistral.encode_fim(fim) + def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], tools: Optional[Dict[str, Any]] = None, From 72b2eb1110168245d429a50f917cc386ad75103c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 21 Oct 2024 16:43:40 -0700 Subject: [PATCH 2/7] Simplify registration --- vllm/entrypoints/openai/api_server.py | 6 +- vllm/entrypoints/openai/cli_args.py | 4 +- vllm/entrypoints/openai/fim/codellama_fim.py | 41 +++++ vllm/entrypoints/openai/fim/fim_encoder.py | 163 +++++++++--------- vllm/entrypoints/openai/fim/mistral_fim.py | 16 +- vllm/entrypoints/openai/serving_engine.py | 13 +- vllm/transformers_utils/tokenizers/mistral.py | 4 +- 7 files changed, 148 insertions(+), 99 deletions(-) create mode 100644 vllm/entrypoints/openai/fim/codellama_fim.py diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 1ca9d25411af7..976ad18b94897 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,7 +14,6 @@ from typing import AsyncIterator, Set import uvloop -from entrypoints.openai.fim.fim_encoder import FIMEncoderManager from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware @@ -34,6 +33,7 @@ from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -541,11 +541,11 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"(chose from {{ {','.join(valid_tool_parsers.keys())} }})") if args.fim is not None: - valid_fim_encoders = 
FIMEncoderManager.fim_encoders + valid_fim_encoders = get_supported_fim_encoders() if args.fim not in valid_fim_encoders: raise KeyError( f"invalid FIM encoder: {args.fim} " - f"(chose from {{ {','.join(valid_fim_encoders.keys())} }})") + f"(chose from {{ {','.join(valid_fim_encoders)} }})") # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 31bda24262415..ddbf13ec77abd 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,7 +11,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import validate_chat_template -from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager +from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -214,7 +214,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: " into OpenAI API format, the name register in this plugin can be used " "in --tool-call-parser.") - valid_fim_encoders = FIMEncoderManager.fim_encoders.keys() + valid_fim_encoders = get_supported_fim_encoders() parser.add_argument( "--fim", type=str, diff --git a/vllm/entrypoints/openai/fim/codellama_fim.py b/vllm/entrypoints/openai/fim/codellama_fim.py new file mode 100644 index 0000000000000..ecbb212b45632 --- /dev/null +++ b/vllm/entrypoints/openai/fim/codellama_fim.py @@ -0,0 +1,41 @@ +from typing import List + +from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class CodeLlamaFIMEncoder(FIMEncoder): + """ + FIM Encoder for Meta CodeLlama models + + Adapted from https://github.com/meta-llama/codellama/blob/e81b597e44dbecc2a0dedb9949fdf84adfc22395/llama/generation.py#L474 + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not hasattr(tokenizer, "convert_tokens_to_ids"): + raise ValueError( + "tokenizer incompatible with 'codellama' FIM encoder") + + self.bos_id = tokenizer.convert_tokens_to_ids("") + self.prefix_id = tokenizer.convert_tokens_to_ids("▁
")
+        self.suffix_id = tokenizer.convert_tokens_to_ids("▁")
+        self.middle_id = tokenizer.convert_tokens_to_ids("▁")
+
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        if any(tid in
+               {self.bos_id, self.prefix_id, self.suffix_id, self.middle_id}
+               for tid in (None, unk_token_id)):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        return ([self.bos_id, self.prefix_id] +
+                self.tokenizer(prefix, add_special_tokens=False) +
+                [self.suffix_id] + self._encode_infilling(suffix) +
+                [self.middle_id])
+
+    def _encode_infilling(self, s: str) -> List[int]:
+        """Encode a string without an implicit leading space."""
+        return self.tokenizer("☺" + s, add_special_tokens=False)[2:]
diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py
index 2593936f21ebd..7cee051dfcfb6 100644
--- a/vllm/entrypoints/openai/fim/fim_encoder.py
+++ b/vllm/entrypoints/openai/fim/fim_encoder.py
@@ -1,8 +1,28 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Optional, Type, Union
+from functools import partial
+from inspect import isclass
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
 
+from vllm.entrypoints.openai.fim.codellama_fim import CodeLlamaFIMEncoder
+from vllm.entrypoints.openai.fim.mistral_fim import MistralFIMEncoder
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import is_list_of
+
+# Entries are either an FIMEncoder implementation class or
+# tuple of (template, special_tokens_list).
+_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = {
+    "mistral":
+    MistralFIMEncoder,
+    "codellama":
+    CodeLlamaFIMEncoder,
+    "deepseek": (
+        "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>",
+        ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"),
+    ),
+    "starcoder": (
+        "{prefix}{suffix}",
+        ("", "", ""),
+    )
+}
 
 
 class FIMEncoder(ABC):
@@ -15,89 +35,78 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.tokenizer = tokenizer
 
     @abstractmethod
-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
         """
         Encode the provided prompt prefix and suffix
         to a list of token ids
         """
         pass
 
-    @classmethod
-    def for_tokenizer(cls: Type, tokenizer: AnyTokenizer) -> "FIMEncoder":
+
+class StringTemplateFIMEncoder(FIMEncoder):
+    """FIMEncoder implementation using a simple string template
+    with prefix and suffix variables."""
+
+    def __init__(
+        self,
+        name: str,
+        tokenizer: AnyTokenizer,
+        template: str,
+        special_tokens: Optional[Iterable[str]] = None,
+    ):
+        super().__init__(tokenizer)
+
+        if not hasattr(tokenizer, "convert_tokens_to_ids"):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        for special_token in special_tokens or ():
+            token_id = tokenizer.convert_tokens_to_ids(special_token)
+            if token_id is None or token_id == unk_token_id:
+                raise ValueError(
+                    f"tokenizer incompatible with '{name}' FIM encoder")
+        self.template = template
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        prompt = self.template.format(prefix=prefix, suffix=suffix)
+        return self.tokenizer(prompt, add_special_tokens=False)
+
+
+def get_supported_fim_encoders() -> Iterable[str]:
+    """Return set of supported FIM encoder types."""
+    return _FIM_ENCODERS.keys()
+
+
+def get_fim_encoder_lookup(
+        name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]:
+    """
+    Get a function that returns a FIMEncoder instance for a given tokenizer.
+    Raise a KeyError exception if the name is not recognized.
+    """
+    if name is None:
+        return None
+
+    if (encoder := _FIM_ENCODERS.get(name)) is None:
+        raise KeyError(f"fim encoder '{name}' not recognized")
+
+    factory: Callable[[AnyTokenizer], FIMEncoder]
+    if isclass(encoder):
+        assert issubclass(encoder, FIMEncoder)
+        factory = encoder
+    else:
+        assert isinstance(encoder, tuple)
+        template, special_tokens = encoder
+        factory = partial(StringTemplateFIMEncoder,
+                          name=name,
+                          template=template,
+                          special_tokens=special_tokens)
+
+    def for_tokenizer(tokenizer: AnyTokenizer) -> FIMEncoder:
         fim_encoder = getattr(tokenizer, "fim_encoder", None)
         if fim_encoder is None:
-            fim_encoder = cls(tokenizer)
+            fim_encoder = factory(tokenizer)
             tokenizer.fim_encoder = fim_encoder  # type: ignore[union-attr]
         return fim_encoder
 
-
-class FIMEncoderManager:
-    fim_encoders: Dict[str, Type] = {}
-
-    @classmethod
-    def get_fim_encoder_class(cls, name: Optional[str]) -> Optional[Type]:
-        """
-        Get FIM encoder by name which is registered by `register_module`.
-
-        Raise a KeyError exception if the name is not registered.
-        """
-        if name is None:
-            return None
-
-        if (encoder := cls.fim_encoders.get(name)) is None:
-            raise KeyError(f"fim encoder '{name}' not recognized")
-
-        return encoder
-
-    @classmethod
-    def _register_module(cls,
-                         module: Type,
-                         module_name: Optional[Union[str, List[str]]] = None,
-                         force: bool = True) -> None:
-        if not issubclass(module, FIMEncoder):
-            raise TypeError(
-                f'module must be subclass of FIMEncoder, but got {type(module)}'
-            )
-        if module_name is None:
-            module_name = module.__name__
-        if isinstance(module_name, str):
-            module_name = [module_name]
-        for name in module_name:
-            if not force and (exist_module := cls.fim_encoders.get(name)
-                              is not None):
-                raise KeyError(f'{name} is already registered '
-                               f'at {exist_module.__module__}')
-            cls.fim_encoders[name] = module
-
-    @classmethod
-    def register_module(
-            cls,
-            name: Optional[Union[str, List[str]]] = None,
-            force: bool = True,
-            module: Union[Type, None] = None) -> Union[type, Callable]:
-        """
-        Register module with the given name or name list. it can be used as a
-        decoder(with module as None) or normal function(with module as not
-        None).
-        """
-        if not isinstance(force, bool):
-            raise TypeError(f'force must be a boolean, but got {type(force)}')
-
-        # raise the error ahead of time
-        if not (name is None or isinstance(name, str)
-                or is_list_of(name, str)):
-            raise TypeError(
-                'name must be None, an instance of str, or a sequence of str, '
-                f'but got {type(name)}')
-
-        # use it as a normal method: x.register_module(module=SomeClass)
-        if module is not None:
-            cls._register_module(module=module, module_name=name, force=force)
-            return module
-
-        # use it as a decorator: @x.register_module()
-        def _register(module):
-            cls._register_module(module=module, module_name=name, force=force)
-            return module
-
-        return _register
+    return for_tokenizer
diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
index a545338de683d..a18e848c67fa9 100644
--- a/vllm/entrypoints/openai/fim/mistral_fim.py
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -1,21 +1,21 @@
 from typing import List
 
-from transformers_utils.tokenizer import AnyTokenizer
-from transformers_utils.tokenizers import MistralTokenizer
+from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV1
 
-from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
-                                                     FIMEncoderManager)
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers import MistralTokenizer
 
 
-@FIMEncoderManager.register_module("mistral")
 class MistralFIMEncoder(FIMEncoder):
 
     def __init__(self, tokenizer: AnyTokenizer):
         super().__init__(tokenizer)
 
-        if not isinstance(tokenizer, MistralTokenizer):
+        if not isinstance(tokenizer, MistralTokenizer) \
+            or isinstance(tokenizer.instruct, InstructTokenizerV1):
             raise ValueError(
                 "tokenizer incompatible with 'mistral' FIM encoder")
 
-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
-        return self.tokenizer.encode_with_suffix(prompt=prompt, suffix=suffix)
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        return self.tokenizer.encode_with_suffix(prefix=prefix, suffix=suffix)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 74ac887fa78e8..2765b6424ae97 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -12,7 +12,7 @@
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
-                                                     FIMEncoderManager)
+                                                     get_fim_encoder_lookup)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -121,8 +121,8 @@ def __init__(
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
-        self.fim_class: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \
-            FIMEncoderManager.get_fim_encoder_class(fim_encoder)
+        self.get_fim_encoder: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \
+            get_fim_encoder_lookup(fim_encoder)
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -216,14 +216,13 @@ def _normalize_prompt_text_to_input(
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
         if suffix:
-            if not (fim_class := self.fim_class):
+            if not (get_fim_encoder := self.get_fim_encoder):
                 raise ValueError("fim support must be enabled to use suffix")
             if truncate_prompt_tokens is not None:
                 raise ValueError(
                     "truncate_prompt_tokens is not supported with suffix")
-            fim_encoder = fim_class.for_tokenizer(  # type: ignore[attr-defined]
-                tokenizer)
-            input_ids = fim_encoder.encode_with_suffix(prompt=prompt,
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt,
                                                        suffix=suffix)
         else:
             if truncate_prompt_tokens is None:
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index fee78871b07a6..939f3590cb5b9 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -190,8 +190,8 @@ def encode(self, prompt: str) -> List[int]:
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)
 
-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
-        fim = FIMRequest(prompt=prompt, suffix=suffix)
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        fim = FIMRequest(prefix=prefix, suffix=suffix)
         return self.mistral.encode_fim(fim)
 
     def apply_chat_template(self,

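With the pieces above in place, fill-in-the-middle can be exercised end to end through the standard Completions API. A minimal sketch, assuming a FIM-capable Mistral model served locally with the new flag (model name, port and generation settings are placeholders):

    # Server side (placeholder model):
    #   vllm serve mistralai/Codestral-22B-v0.1 --tokenizer-mode mistral --fim mistral
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    resp = client.completions.create(
        model="mistralai/Codestral-22B-v0.1",   # placeholder model name
        prompt="def fibonacci(n):\n    ",       # code before the cursor
        suffix="\n    return result\n",         # code after the cursor
        max_tokens=64,
        temperature=0.0,
    )
    print(resp.choices[0].text)                 # the generated middle
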
From 2cd52d6cc55219eb1318d562f589a217ea653549 Mon Sep 17 00:00:00 2001
From: Nick Hill 
Date: Tue, 22 Oct 2024 12:10:16 -0700
Subject: [PATCH 3/7] Fixes

---
 vllm/entrypoints/openai/api_server.py        |  2 +-
 vllm/entrypoints/openai/cli_args.py          |  2 +-
 vllm/entrypoints/openai/fim/__init__.py      | 69 ++++++++++++++++++++
 vllm/entrypoints/openai/fim/codellama_fim.py | 16 ++---
 vllm/entrypoints/openai/fim/fim_encoder.py   | 66 +------------------
 vllm/entrypoints/openai/serving_engine.py    |  3 +-
 6 files changed, 83 insertions(+), 75 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 976ad18b94897..4a0959ea149d2 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -33,7 +33,7 @@
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders
+from vllm.entrypoints.openai.fim import get_supported_fim_encoders
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index ddbf13ec77abd..9559304d6ecdc 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -11,7 +11,7 @@
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import validate_chat_template
-from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders
+from vllm.entrypoints.openai.fim import get_supported_fim_encoders
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
diff --git a/vllm/entrypoints/openai/fim/__init__.py b/vllm/entrypoints/openai/fim/__init__.py
index e69de29bb2d1d..d5f3431130165 100644
--- a/vllm/entrypoints/openai/fim/__init__.py
+++ b/vllm/entrypoints/openai/fim/__init__.py
@@ -0,0 +1,69 @@
+from functools import partial
+from inspect import isclass
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, Union
+
+from vllm.entrypoints.openai.fim.codellama_fim import CodeLlamaFIMEncoder
+from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
+                                                     StringTemplateFIMEncoder)
+from vllm.entrypoints.openai.fim.mistral_fim import MistralFIMEncoder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+__all__ = [
+    "FIMEncoder", "get_supported_fim_encoders", "get_fim_encoder_lookup"
+]
+
+# Entries are either an FIMEncoder implementation class or
+# tuple of (template, special_tokens_list).
+_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = {
+    "mistral":
+    MistralFIMEncoder,
+    "codellama":
+    CodeLlamaFIMEncoder,
+    "deepseek": (
+        "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>",
+        ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"),
+    ),
+    "starcoder": (
+        "{prefix}{suffix}",
+        ("", "", ""),
+    )
+}
+
+
+def get_supported_fim_encoders() -> Iterable[str]:
+    """Return set of supported FIM encoder types."""
+    return _FIM_ENCODERS.keys()
+
+
+def get_fim_encoder_lookup(
+        name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]:
+    """
+    Get a function that returns a FIMEncoder instance for a given tokenizer.
+    Raise a ValueError if the name is not recognized.
+    """
+    if name is None:
+        return None
+
+    if (encoder := _FIM_ENCODERS.get(name)) is None:
+        raise ValueError(f"fim encoder '{name}' not recognized")
+
+    factory: Callable[[AnyTokenizer], FIMEncoder]
+    if isclass(encoder):
+        assert issubclass(encoder, FIMEncoder)
+        factory = encoder
+    else:
+        assert isinstance(encoder, tuple)
+        template, special_tokens = encoder
+        factory = partial(StringTemplateFIMEncoder,
+                          name=name,
+                          template=template,
+                          special_tokens=special_tokens)
+
+    def for_tokenizer(tokenizer: AnyTokenizer) -> FIMEncoder:
+        fim_encoder = getattr(tokenizer, "fim_encoder", None)
+        if fim_encoder is None:
+            fim_encoder = factory(tokenizer)
+            tokenizer.fim_encoder = fim_encoder  # type: ignore[union-attr]
+        return fim_encoder
+
+    return for_tokenizer
diff --git a/vllm/entrypoints/openai/fim/codellama_fim.py b/vllm/entrypoints/openai/fim/codellama_fim.py
index ecbb212b45632..224d34d2288d5 100644
--- a/vllm/entrypoints/openai/fim/codellama_fim.py
+++ b/vllm/entrypoints/openai/fim/codellama_fim.py
@@ -31,11 +31,11 @@ def __init__(self, tokenizer: AnyTokenizer):
                 "tokenizer incompatible with 'codellama' FIM encoder")
 
     def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
-        return ([self.bos_id, self.prefix_id] +
-                self.tokenizer(prefix, add_special_tokens=False) +
-                [self.suffix_id] + self._encode_infilling(suffix) +
-                [self.middle_id])
-
-    def _encode_infilling(self, s: str) -> List[int]:
-        """Encode a string without an implicit leading space."""
-        return self.tokenizer("☺" + s, add_special_tokens=False)[2:]
+        prefix_tokens = self.tokenizer(prefix,
+                                       add_special_tokens=False).input_ids
+        # Encode a string without an implicit leading space.
+        suffix_tokens = self.tokenizer("☺" + suffix,
+                                       add_special_tokens=False).input_ids[2:]
+
+        return ([self.bos_id, self.prefix_id] + prefix_tokens +
+                [self.suffix_id] + suffix_tokens + [self.middle_id])
diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py
index 7cee051dfcfb6..9b6f27a7a1e36 100644
--- a/vllm/entrypoints/openai/fim/fim_encoder.py
+++ b/vllm/entrypoints/openai/fim/fim_encoder.py
@@ -1,29 +1,8 @@
 from abc import ABC, abstractmethod
-from functools import partial
-from inspect import isclass
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional
 
-from vllm.entrypoints.openai.fim.codellama_fim import CodeLlamaFIMEncoder
-from vllm.entrypoints.openai.fim.mistral_fim import MistralFIMEncoder
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
-# Entries are either an FIMEncoder implementation class or
-# tuple of (template, special_tokens_list).
-_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = {
-    "mistral":
-    MistralFIMEncoder,
-    "codellama":
-    CodeLlamaFIMEncoder,
-    "deepseek": (
-        "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>",
-        ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"),
-    ),
-    "starcoder": (
-        "{prefix}{suffix}",
-        ("", "", ""),
-    )
-}
-
 
 class FIMEncoder(ABC):
     """
@@ -49,8 +28,8 @@ class StringTemplateFIMEncoder(FIMEncoder):
 
     def __init__(
         self,
-        name: str,
         tokenizer: AnyTokenizer,
+        name: str,
         template: str,
         special_tokens: Optional[Iterable[str]] = None,
     ):
@@ -70,43 +49,4 @@ def __init__(
 
     def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
         prompt = self.template.format(prefix=prefix, suffix=suffix)
-        return self.tokenizer(prompt, add_special_tokens=False)
-
-
-def get_supported_fim_encoders() -> Iterable[str]:
-    """Return set of supported FIM encoder types."""
-    return _FIM_ENCODERS.keys()
-
-
-def get_fim_encoder_lookup(
-        name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]:
-    """
-    Get a function that returns a FIMEncoder instance for a given tokenizer.
-    Raise a KeyError exception if the name is not recognized.
-    """
-    if name is None:
-        return None
-
-    if (encoder := _FIM_ENCODERS.get(name)) is None:
-        raise KeyError(f"fim encoder '{name}' not recognized")
-
-    factory: Callable[[AnyTokenizer], FIMEncoder]
-    if isclass(encoder):
-        assert issubclass(encoder, FIMEncoder)
-        factory = encoder
-    else:
-        assert isinstance(encoder, tuple)
-        template, special_tokens = encoder
-        factory = partial(StringTemplateFIMEncoder,
-                          name=name,
-                          template=template,
-                          special_tokens=special_tokens)
-
-    def for_tokenizer(tokenizer: AnyTokenizer) -> FIMEncoder:
-        fim_encoder = getattr(tokenizer, "fim_encoder", None)
-        if fim_encoder is None:
-            fim_encoder = factory(tokenizer)
-            tokenizer.fim_encoder = fim_encoder  # type: ignore[union-attr]
-        return fim_encoder
-
-    return for_tokenizer
+        return self.tokenizer(prompt, add_special_tokens=False).input_ids
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 2765b6424ae97..1a313a5c63fdc 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -11,8 +11,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
-                                                     get_fim_encoder_lookup)
+from vllm.entrypoints.openai.fim import FIMEncoder, get_fim_encoder_lookup
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,

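The registry in fim/__init__.py is deliberately table-driven: an entry is either a FIMEncoder subclass (mistral, codellama) or a (template, special_tokens) pair that StringTemplateFIMEncoder expands by plain string formatting. As a sketch, supporting another model family with a purely template-based FIM prompt would only need one more entry; the token strings below are hypothetical and not taken from any real tokenizer:

    # Hypothetical entry; the listed tokens must resolve to ids in the model's vocabulary.
    _FIM_ENCODERS["examplecoder"] = (
        "<fim-pre>{prefix}<fim-suf>{suffix}<fim-mid>",
        ("<fim-pre>", "<fim-suf>", "<fim-mid>"),
    )

    # At request time, StringTemplateFIMEncoder.encode_with_suffix() then does roughly:
    #   prompt = template.format(prefix=prefix, suffix=suffix)
    #   return tokenizer(prompt, add_special_tokens=False).input_ids
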
From ad877c00f19242bb10999870cac097ee7fd706a7 Mon Sep 17 00:00:00 2001
From: Nick Hill 
Date: Tue, 22 Oct 2024 13:39:43 -0700
Subject: [PATCH 4/7] Mistral fixes

---
 vllm/entrypoints/openai/fim/mistral_fim.py    | 5 +++--
 vllm/transformers_utils/tokenizers/mistral.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
index a18e848c67fa9..21fd1cca9e217 100644
--- a/vllm/entrypoints/openai/fim/mistral_fim.py
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -1,6 +1,6 @@
 from typing import List
 
-from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV1
+from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV2
 
 from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -12,8 +12,9 @@ class MistralFIMEncoder(FIMEncoder):
     def __init__(self, tokenizer: AnyTokenizer):
         super().__init__(tokenizer)
 
+        # InstructTokenizerV3 is a subclass of InstructTokenizerV2
         if not isinstance(tokenizer, MistralTokenizer) \
-            or isinstance(tokenizer.instruct, InstructTokenizerV1):
+            or not isinstance(tokenizer.instruct, InstructTokenizerV2):
             raise ValueError(
                 "tokenizer incompatible with 'mistral' FIM encoder")
 
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 939f3590cb5b9..17fe5fc544921 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -191,8 +191,8 @@ def encode(self, prompt: str) -> List[int]:
         return self.tokenizer.encode(prompt, bos=True, eos=False)
 
     def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
-        fim = FIMRequest(prefix=prefix, suffix=suffix)
-        return self.mistral.encode_fim(fim)
+        fim = FIMRequest(prompt=prefix, suffix=suffix)
+        return self.mistral.encode_fim(fim).tokens
 
     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],

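On the Mistral side, encode_with_suffix now defers to mistral_common's native FIM support. A minimal sketch of the underlying call, assuming a v3 (or newer) tokenizer, since FIM requires InstructTokenizerV2 or later:

    from mistral_common.tokens.instruct.request import FIMRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    tok = MistralTokenizer.v3()
    fim = FIMRequest(prompt="def add(a, b):\n    return ", suffix="\n")
    token_ids = tok.encode_fim(fim).tokens   # the same call the vLLM wrapper makes
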
From 64eacabc34242d092996f0bad9c3765e413506b3 Mon Sep 17 00:00:00 2001
From: Thomas Bouamoud 
Date: Sat, 15 Feb 2025 18:26:28 +0000
Subject: [PATCH 5/7] static analysis warning

---
 vllm/entrypoints/openai/serving_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 386be2ba39c40..52144f43253b9 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -465,6 +465,7 @@ async def _preprocess_chat(
                 request,
                 tokenizer,
                 request_prompt,
+                suffix=None,
                 truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
             )

From 24c7d0f0927f3fc3688f48251d5bb988cc9632e3 Mon Sep 17 00:00:00 2001
From: Nick Hill 
Date: Tue, 18 Feb 2025 09:41:57 -0800
Subject: [PATCH 6/7] Fix linting

Signed-off-by: Nick Hill 
---
 vllm/entrypoints/openai/api_server.py         |  1 -
 vllm/entrypoints/openai/cli_args.py           |  2 +-
 vllm/entrypoints/openai/fim/__init__.py       |  1 +
 vllm/entrypoints/openai/fim/codellama_fim.py  |  1 +
 vllm/entrypoints/openai/fim/fim_encoder.py    |  1 +
 vllm/entrypoints/openai/fim/mistral_fim.py    |  1 +
 vllm/entrypoints/openai/serving_engine.py     | 15 ++++++---------
 vllm/transformers_utils/tokenizers/mistral.py |  2 --
 8 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index bfd96cb858874..9a4a7593f4ef9 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -910,7 +910,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
                 f"invalid FIM encoder: {args.fim} "
                 f"(chose from {{ {','.join(valid_fim_encoders)} }})")
 
-
     valid_tool_parses = ToolParserManager.tool_parsers.keys()
     if args.enable_auto_tool_choice \
         and args.tool_call_parser not in valid_tool_parses:
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 1e65dde2f1d5b..79d08ebf0f7b6 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -13,8 +13,8 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
 from vllm.entrypoints.openai.fim import get_supported_fim_encoders
+from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
 from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
diff --git a/vllm/entrypoints/openai/fim/__init__.py b/vllm/entrypoints/openai/fim/__init__.py
index d5f3431130165..60616ba74fb1d 100644
--- a/vllm/entrypoints/openai/fim/__init__.py
+++ b/vllm/entrypoints/openai/fim/__init__.py
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
 from functools import partial
 from inspect import isclass
 from typing import Callable, Dict, Iterable, Optional, Tuple, Type, Union
diff --git a/vllm/entrypoints/openai/fim/codellama_fim.py b/vllm/entrypoints/openai/fim/codellama_fim.py
index 224d34d2288d5..c783873961ac9 100644
--- a/vllm/entrypoints/openai/fim/codellama_fim.py
+++ b/vllm/entrypoints/openai/fim/codellama_fim.py
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
 from typing import List
 
 from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py
index 9b6f27a7a1e36..afcd2fce936fb 100644
--- a/vllm/entrypoints/openai/fim/fim_encoder.py
+++ b/vllm/entrypoints/openai/fim/fim_encoder.py
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
 from abc import ABC, abstractmethod
 from typing import Iterable, List, Optional
 
diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
index 21fd1cca9e217..9bfc9bff87908 100644
--- a/vllm/entrypoints/openai/fim/mistral_fim.py
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
 from typing import List
 
 from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV2
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 52144f43253b9..c2e2ae9908689 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -189,7 +189,8 @@ def _normalize_prompt_text_to_input(
 
         else:
             if truncate_prompt_tokens is None:
-                encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+                encoded = tokenizer(prompt,
+                                    add_special_tokens=add_special_tokens)
             else:
                 encoded = tokenizer(prompt,
                                     add_special_tokens=add_special_tokens,
@@ -289,12 +290,10 @@ def _tokenize_prompt_input(
         return next(
             self._tokenize_prompt_inputs(
                 request,
-                tokenizer,
-                [prompt_input],
+                tokenizer, [prompt_input],
                 truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
-                suffix=suffix
-            ))
+                suffix=suffix))
 
     def _tokenize_prompt_inputs(
         self,
@@ -357,8 +356,7 @@ def _tokenize_prompt_input_or_inputs(
                 prompt=prompt_input["content"],
                 truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
-                suffix=suffix)
-            if prompt_input["is_tokens"] is False else
+                suffix=suffix) if prompt_input["is_tokens"] is False else
             self._normalize_prompt_tokens_to_input(
                 request,
                 tokenizer,
@@ -382,8 +380,7 @@ async def _preprocess_completion(
             input_or_inputs,
             truncate_prompt_tokens=truncate_prompt_tokens,
             add_special_tokens=add_special_tokens,
-            suffix=suffix
-        )
+            suffix=suffix)
 
         engine_prompts = [
             TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index f22cd2b57e264..cf2351a3647ac 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -21,8 +21,6 @@
     from mistral_common.tokens.instruct.request import FIMRequest
     from mistral_common.tokens.tokenizers.mistral import (
         MistralTokenizer as PublicMistralTokenizer)
-    from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
-                                                         Tekkenizer)
 
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 

From d5d97d8212eef253bd4a160ea9dab3f1dd991c68 Mon Sep 17 00:00:00 2001
From: Nick Hill 
Date: Tue, 18 Feb 2025 09:52:13 -0800
Subject: [PATCH 7/7] Fix mypy

Signed-off-by: Nick Hill 
---
 vllm/entrypoints/openai/fim/mistral_fim.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
index 9bfc9bff87908..115d739192f69 100644
--- a/vllm/entrypoints/openai/fim/mistral_fim.py
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -20,4 +20,5 @@ def __init__(self, tokenizer: AnyTokenizer):
                 "tokenizer incompatible with 'mistral' FIM encoder")
 
     def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        assert isinstance(self.tokenizer, MistralTokenizer)
         return self.tokenizer.encode_with_suffix(prefix=prefix, suffix=suffix)
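
A closing note on the lookup design: get_fim_encoder_lookup returns a factory closure rather than an encoder instance because the tokenizer is only resolved per request (it can differ across LoRA adapters, for example), and the constructed encoder is cached on the tokenizer object so later requests reuse it. A rough sketch of that behaviour with a stand-in tokenizer (the fake class is purely illustrative):

    from types import SimpleNamespace

    from vllm.entrypoints.openai.fim import get_fim_encoder_lookup

    class FakeTokenizer:
        """Stand-in for AnyTokenizer: maps every special token to a dummy id."""
        unk_token_id = 0

        def convert_tokens_to_ids(self, token: str) -> int:
            return 42

        def __call__(self, text: str, add_special_tokens: bool = False):
            return SimpleNamespace(input_ids=[ord(c) for c in text])

    get_encoder = get_fim_encoder_lookup("starcoder")  # chosen once at startup
    tok = FakeTokenizer()
    enc = get_encoder(tok)             # built on first use, cached as tok.fim_encoder
    assert enc is get_encoder(tok)     # subsequent calls reuse the same instance
    ids = enc.encode_with_suffix(prefix="def f():", suffix="    return 1")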