From c4886c31f985123bdff2b459f1c8c71bd607da03 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 21 Oct 2024 16:43:40 -0700
Subject: [PATCH] Simplify registration

---
 vllm/entrypoints/openai/api_server.py         |   6 +-
 vllm/entrypoints/openai/cli_args.py           |   4 +-
 vllm/entrypoints/openai/fim/codellama_fim.py  |  41 +++++
 vllm/entrypoints/openai/fim/fim_encoder.py    | 141 +++++++++---------
 vllm/entrypoints/openai/fim/mistral_fim.py    |  16 +-
 vllm/entrypoints/openai/serving_engine.py     |   6 +-
 vllm/transformers_utils/tokenizers/mistral.py |   4 +-
 7 files changed, 133 insertions(+), 85 deletions(-)
 create mode 100644 vllm/entrypoints/openai/fim/codellama_fim.py

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 1ca9d25411af7..976ad18b94897 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -14,7 +14,6 @@
 from typing import AsyncIterator, Set
 
 import uvloop
-from entrypoints.openai.fim.fim_encoder import FIMEncoderManager
 from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
@@ -34,6 +33,7 @@
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
+from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -541,11 +541,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
             f"(chose from {{ {','.join(valid_tool_parsers.keys())} }})")
 
     if args.fim is not None:
-        valid_fim_encoders = FIMEncoderManager.fim_encoders
+        valid_fim_encoders = get_supported_fim_encoders()
         if args.fim not in valid_fim_encoders:
             raise KeyError(
                 f"invalid FIM encoder: {args.fim} "
-                f"(chose from {{ {','.join(valid_fim_encoders.keys())} }})")
+                f"(chose from {{ {','.join(valid_fim_encoders)} }})")
 
 # workaround to make sure that we bind the port before the engine is set up.
 # This avoids race conditions with ray.
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 31bda24262415..ddbf13ec77abd 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -11,7 +11,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import validate_chat_template
-from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager
+from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -214,7 +214,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         " into OpenAI API format, the name register in this plugin can be used "
         "in --tool-call-parser.")
 
-    valid_fim_encoders = FIMEncoderManager.fim_encoders.keys()
+    valid_fim_encoders = get_supported_fim_encoders()
     parser.add_argument(
         "--fim",
         type=str,
diff --git a/vllm/entrypoints/openai/fim/codellama_fim.py b/vllm/entrypoints/openai/fim/codellama_fim.py
new file mode 100644
index 0000000000000..ecbb212b45632
--- /dev/null
+++ b/vllm/entrypoints/openai/fim/codellama_fim.py
@@ -0,0 +1,47 @@
+from typing import List
+
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+class CodeLlamaFIMEncoder(FIMEncoder):
+    """
+    FIM Encoder for Meta CodeLlama models
+
+    Adapted from https://github.com/meta-llama/codellama/blob/e81b597e44dbecc2a0dedb9949fdf84adfc22395/llama/generation.py#L474
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        if not hasattr(tokenizer, "convert_tokens_to_ids"):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+        self.bos_id = tokenizer.convert_tokens_to_ids("<s>")
+        self.prefix_id = tokenizer.convert_tokens_to_ids("▁<PRE>")
+        self.suffix_id = tokenizer.convert_tokens_to_ids("▁<SUF>")
+        self.middle_id = tokenizer.convert_tokens_to_ids("▁<MID>")
+
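+        # Lookups for tokens missing from the vocab yield None or the unk
+        # id (tokenizer-dependent); treat either case as incompatible.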
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        if any(tid in
+               {self.bos_id, self.prefix_id, self.suffix_id, self.middle_id}
+               for tid in (None, unk_token_id)):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
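+        # Token layout is <s> ▁<PRE> {prefix} ▁<SUF> {suffix} ▁<MID>;
+        # the model then generates the infill between prefix and suffix.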
+        return ([self.bos_id, self.prefix_id] +
+                self.tokenizer(prefix, add_special_tokens=False).input_ids +
+                [self.suffix_id] + self._encode_infilling(suffix) +
+                [self.middle_id])
+
+    def _encode_infilling(self, s: str) -> List[int]:
+        """Encode a string without an implicit leading space."""
+        return self.tokenizer("☺" + s, add_special_tokens=False)[2:]
diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py
index 2593936f21ebd..867db1c1385cb 100644
--- a/vllm/entrypoints/openai/fim/fim_encoder.py
+++ b/vllm/entrypoints/openai/fim/fim_encoder.py
@@ -1,8 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Optional, Type, Union
+from functools import partial
+from inspect import isclass
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
 
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import is_list_of
 
 
 class FIMEncoder(ABC):
@@ -15,7 +35,7 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.tokenizer = tokenizer
 
     @abstractmethod
-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
         """
         Encode the provided prompt prefix and suffix
         to a list of token ids
@@ -31,73 +32,89 @@ def for_tokenizer(cls: Type, tokenizer: AnyTokenizer) -> "FIMEncoder":
         return fim_encoder
 
 
-class FIMEncoderManager:
-    fim_encoders: Dict[str, Type] = {}
+class StringTemplateFIMEncoder(FIMEncoder):
+    """FIMEncoder implementation using a simple string template
+    with prefix and suffix variables."""
 
-    @classmethod
-    def get_fim_encoder_class(cls, name: Optional[str]) -> Optional[Type]:
-        """
-        Get FIM encoder by name which is registered by `register_module`.
+    def __init__(
+        self,
+        tokenizer: AnyTokenizer,
+        name: str,
+        template: str,
+        special_tokens: Optional[Iterable[str]] = None,
+    ):
+        super().__init__(tokenizer)
 
-        Raise a KeyError exception if the name is not registered.
-        """
-        if name is None:
-            return None
+        if not hasattr(tokenizer, "convert_tokens_to_ids"):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
 
-        if (encoder := cls.fim_encoders.get(name)) is None:
-            raise KeyError(f"fim encoder '{name}' not recognized")
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        for special_token in special_tokens or ():
+            token_id = tokenizer.convert_tokens_to_ids(special_token)
+            if token_id is None or token_id == unk_token_id:
+                raise ValueError(
+                    f"tokenizer incompatible with '{name}' FIM encoder")
+        self.template = template
 
-        return encoder
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
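+        # Fill the model's FIM template, then tokenize the combined prompt.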
+        prompt = self.template.format(prefix=prefix, suffix=suffix)
+        return self.tokenizer(prompt, add_special_tokens=False).input_ids
 
-    @classmethod
-    def _register_module(cls,
-                         module: Type,
-                         module_name: Optional[Union[str, List[str]]] = None,
-                         force: bool = True) -> None:
-        if not issubclass(module, FIMEncoder):
-            raise TypeError(
-                f'module must be subclass of FIMEncoder, but got {type(module)}'
-            )
-        if module_name is None:
-            module_name = module.__name__
-        if isinstance(module_name, str):
-            module_name = [module_name]
-        for name in module_name:
-            if not force and (exist_module := cls.fim_encoders.get(name)
-                              is not None):
-                raise KeyError(f'{name} is already registered '
-                               f'at {exist_module.__module__}')
-            cls.fim_encoders[name] = module
 
-    @classmethod
-    def register_module(
-            cls,
-            name: Optional[Union[str, List[str]]] = None,
-            force: bool = True,
-            module: Union[Type, None] = None) -> Union[type, Callable]:
-        """
-        Register module with the given name or name list. it can be used as a
-        decoder(with module as None) or normal function(with module as not
-        None).
-        """
-        if not isinstance(force, bool):
-            raise TypeError(f'force must be a boolean, but got {type(force)}')
-
-        # raise the error ahead of time
-        if not (name is None or isinstance(name, str)
-                or is_list_of(name, str)):
-            raise TypeError(
-                'name must be None, an instance of str, or a sequence of str, '
-                f'but got {type(name)}')
-
-        # use it as a normal method: x.register_module(module=SomeClass)
-        if module is not None:
-            cls._register_module(module=module, module_name=name, force=force)
-            return module
-
-        # use it as a decorator: @x.register_module()
-        def _register(module):
-            cls._register_module(module=module, module_name=name, force=force)
-            return module
-
-        return _register
+# Imported after the FIMEncoder base class is defined because the concrete
+# encoder modules import FIMEncoder from this module (avoids a circular
+# import).
+from vllm.entrypoints.openai.fim.codellama_fim import (  # noqa: E402
+    CodeLlamaFIMEncoder)
+from vllm.entrypoints.openai.fim.mistral_fim import (  # noqa: E402
+    MistralFIMEncoder)
+
+# Entries are either an FIMEncoder implementation class or
+# tuple of (template, special_tokens_list).
+_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = {
+    "mistral":
+    MistralFIMEncoder,
+    "codellama":
+    CodeLlamaFIMEncoder,
+    "deepseek": (
+        "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>",
+        ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"),
+    ),
+    "starcoder": (
+        "<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>",
+        ("<fim_prefix>", "<fim_suffix>", "<fim_middle>"),
+    )
+}
+
+
+def get_supported_fim_encoders() -> Iterable[str]:
+    """Return set of supported FIM encoder types."""
+    return _FIM_ENCODERS.keys()
+
+
+def get_fim_encoder_factory(
+        name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]:
+    """
+    Get an FIM encoder factory by name.
+    Raise a KeyError exception if the name is not recognized.
+    """
+    if name is None:
+        return None
+
+    if (encoder := _FIM_ENCODERS.get(name)) is None:
+        raise KeyError(f"fim encoder '{name}' not recognized")
+
+    if isclass(encoder):
+        assert issubclass(encoder, FIMEncoder)
+        return encoder  # type: ignore[return-value]
+
+    # Otherwise the entry is a (template, special_tokens) tuple.
+    template, special_tokens = encoder
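+    # Bind the template details now; the tokenizer is supplied by the
+    # caller when the factory is invoked.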
+    return partial(StringTemplateFIMEncoder,
+                   name=name,
+                   template=template,
+                   special_tokens=special_tokens)
diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
index a545338de683d..a18e848c67fa9 100644
--- a/vllm/entrypoints/openai/fim/mistral_fim.py
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -1,21 +1,23 @@
 from typing import List
 
-from transformers_utils.tokenizer import AnyTokenizer
-from transformers_utils.tokenizers import MistralTokenizer
+from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV1
 
-from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
-                                                     FIMEncoderManager)
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers import MistralTokenizer
 
 
-@FIMEncoderManager.register_module("mistral")
 class MistralFIMEncoder(FIMEncoder):
 
     def __init__(self, tokenizer: AnyTokenizer):
         super().__init__(tokenizer)
 
-        if not isinstance(tokenizer, MistralTokenizer):
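+        # InstructTokenizerV1 does not support FIM encoding, so it is
+        # rejected here along with non-Mistral tokenizers.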
+        if not isinstance(tokenizer, MistralTokenizer) \
+            or isinstance(tokenizer.instruct, InstructTokenizerV1):
             raise ValueError(
                 "tokenizer incompatible with 'mistral' FIM encoder")
 
-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
-        return self.tokenizer.encode_with_suffix(prompt=prompt, suffix=suffix)
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        return self.tokenizer.encode_with_suffix(prefix=prefix, suffix=suffix)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 74ac887fa78e8..dde2d5c920d9e 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -12,7 +12,7 @@
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
-                                                     FIMEncoderManager)
+                                                     get_fim_encoder_factory)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -122,7 +122,7 @@ def __init__(
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
         self.fim_class: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \
-            FIMEncoderManager.get_fim_encoder_class(fim_encoder)
+            get_fim_encoder_factory(fim_encoder)
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -223,7 +223,8 @@ def _normalize_prompt_text_to_input(
                     "truncate_prompt_tokens is not supported with suffix")
-            fim_encoder = fim_class.for_tokenizer(  # type: ignore[attr-defined]
-                tokenizer)
-            input_ids = fim_encoder.encode_with_suffix(prompt=prompt,
+            # fim_class may be an FIMEncoder subclass or a factory partial,
+            # so call it directly rather than via for_tokenizer.
+            fim_encoder = fim_class(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt,
                                                        suffix=suffix)
         else:
             if truncate_prompt_tokens is None:
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index fee78871b07a6..939f3590cb5b9 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -190,8 +190,8 @@ def encode(self, prompt: str) -> List[int]:
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)
 
-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
-        fim = FIMRequest(prompt=prompt, suffix=suffix)
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        fim = FIMRequest(prompt=prefix, suffix=suffix)
         return self.mistral.encode_fim(fim)
 
     def apply_chat_template(self,