Simplify registration
njhill committed Oct 22, 2024
1 parent f57746a commit c4886c3
Showing 7 changed files with 133 additions and 85 deletions.
6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/api_server.py
@@ -14,7 +14,6 @@
from typing import AsyncIterator, Set

import uvloop
-from entrypoints.openai.fim.fim_encoder import FIMEncoderManager
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
@@ -34,6 +33,7 @@
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
+from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -541,11 +541,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f"(chose from {{ {','.join(valid_tool_parsers.keys())} }})")

if args.fim is not None:
-        valid_fim_encoders = FIMEncoderManager.fim_encoders
+        valid_fim_encoders = get_supported_fim_encoders()
if args.fim not in valid_fim_encoders:
raise KeyError(
f"invalid FIM encoder: {args.fim} "
f"(chose from {{ {','.join(valid_fim_encoders.keys())} }})")
f"(chose from {{ {','.join(valid_fim_encoders)} }})")

# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/cli_args.py
@@ -11,7 +11,7 @@

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
-from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoderManager
+from vllm.entrypoints.openai.fim.fim_encoder import get_supported_fim_encoders
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -214,7 +214,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")

-    valid_fim_encoders = FIMEncoderManager.fim_encoders.keys()
+    valid_fim_encoders = get_supported_fim_encoders()
parser.add_argument(
"--fim",
type=str,
41 changes: 41 additions & 0 deletions vllm/entrypoints/openai/fim/codellama_fim.py
@@ -0,0 +1,41 @@
from typing import List

from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
from vllm.transformers_utils.tokenizer import AnyTokenizer


class CodeLlamaFIMEncoder(FIMEncoder):
"""
FIM Encoder for Meta CodeLlama models
Adapted from https://github.com/meta-llama/codellama/blob/e81b597e44dbecc2a0dedb9949fdf84adfc22395/llama/generation.py#L474
"""

def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)

if not hasattr(tokenizer, "convert_tokens_to_ids"):
raise ValueError(
"tokenizer incompatible with 'codellama' FIM encoder")

self.bos_id = tokenizer.convert_tokens_to_ids("<s>")
self.prefix_id = tokenizer.convert_tokens_to_ids("▁<PRE>")
self.suffix_id = tokenizer.convert_tokens_to_ids("▁<SUF>")
self.middle_id = tokenizer.convert_tokens_to_ids("▁<MID>")

unk_token_id = getattr(tokenizer, "unk_token_id", None)
if any(tid in
{self.bos_id, self.prefix_id, self.suffix_id, self.middle_id}
for tid in (None, unk_token_id)):
raise ValueError(
"tokenizer incompatible with 'codellama' FIM encoder")

def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
return ([self.bos_id, self.prefix_id] +
                self.tokenizer(prefix, add_special_tokens=False).input_ids +
[self.suffix_id] + self._encode_infilling(suffix) +
[self.middle_id])

def _encode_infilling(self, s: str) -> List[int]:
"""Encode a string without an implicit leading space."""
return self.tokenizer("☺" + s, add_special_tokens=False)[2:]
141 changes: 74 additions & 67 deletions vllm/entrypoints/openai/fim/fim_encoder.py
@@ -1,8 +1,28 @@
from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Optional, Type, Union
+from functools import partial
+from inspect import isclass
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union

+from vllm.entrypoints.openai.fim.codellama_fim import CodeLlamaFIMEncoder
+from vllm.entrypoints.openai.fim.mistral_fim import MistralFIMEncoder
from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import is_list_of

+# Entries are either an FIMEncoder implementation class or
+# tuple of (template, special_tokens_list).
+_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = {
+    "mistral":
+    MistralFIMEncoder,
+    "codellama":
+    CodeLlamaFIMEncoder,
+    "deepseek": (
+        "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>",
+        ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"),
+    ),
+    "starcoder": (
+        "<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>",
+        ("<fim_prefix>", "<fim_suffix>", "<fim_middle>"),
+    )
+}


class FIMEncoder(ABC):
@@ -15,7 +35,7 @@ def __init__(self, tokenizer: AnyTokenizer):
self.tokenizer = tokenizer

@abstractmethod
-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
"""
Encode the provided prompt prefix and suffix
to a list of token ids
@@ -31,73 +51,60 @@ def for_tokenizer(cls: Type, tokenizer: AnyTokenizer) -> "FIMEncoder":
return fim_encoder


-class FIMEncoderManager:
-    fim_encoders: Dict[str, Type] = {}
+class StringTemplateFIMEncoder(FIMEncoder):
+    """FIMEncoder implementation using a simple string template
+    with prefix and suffix variables."""

-    @classmethod
-    def get_fim_encoder_class(cls, name: Optional[str]) -> Optional[Type]:
-        """
-        Get FIM encoder by name which is registered by `register_module`.
+    def __init__(
+        self,
+        name: str,
+        tokenizer: AnyTokenizer,
+        template: str,
+        special_tokens: Optional[Iterable[str]] = None,
+    ):
+        super().__init__(tokenizer)

-        Raise a KeyError exception if the name is not registered.
-        """
-        if name is None:
-            return None
+        if not hasattr(tokenizer, "convert_tokens_to_ids"):
+            raise ValueError(
+                f"tokenizer incompatible with '{name}' FIM encoder")

-        if (encoder := cls.fim_encoders.get(name)) is None:
-            raise KeyError(f"fim encoder '{name}' not recognized")
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        for special_token in special_tokens or ():
+            token_id = tokenizer.convert_tokens_to_ids(special_token)
+            if token_id is None or token_id == unk_token_id:
+                raise ValueError(
+                    f"tokenizer incompatible with '{name}' FIM encoder")
+        self.template = template

-        return encoder
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        prompt = self.template.format(prefix=prefix, suffix=suffix)
+        return self.tokenizer(prompt, add_special_tokens=False).input_ids

-    @classmethod
-    def _register_module(cls,
-                         module: Type,
-                         module_name: Optional[Union[str, List[str]]] = None,
-                         force: bool = True) -> None:
-        if not issubclass(module, FIMEncoder):
-            raise TypeError(
-                f'module must be subclass of FIMEncoder, but got {type(module)}'
-            )
-        if module_name is None:
-            module_name = module.__name__
-        if isinstance(module_name, str):
-            module_name = [module_name]
-        for name in module_name:
-            if not force and (exist_module := cls.fim_encoders.get(name)
-                              is not None):
-                raise KeyError(f'{name} is already registered '
-                               f'at {exist_module.__module__}')
-            cls.fim_encoders[name] = module
-
-    @classmethod
-    def register_module(
-            cls,
-            name: Optional[Union[str, List[str]]] = None,
-            force: bool = True,
-            module: Union[Type, None] = None) -> Union[type, Callable]:
-        """
-        Register module with the given name or name list. it can be used as a
-        decoder(with module as None) or normal function(with module as not
-        None).
-        """
-        if not isinstance(force, bool):
-            raise TypeError(f'force must be a boolean, but got {type(force)}')
-
-        # raise the error ahead of time
-        if not (name is None or isinstance(name, str)
-                or is_list_of(name, str)):
-            raise TypeError(
-                'name must be None, an instance of str, or a sequence of str, '
-                f'but got {type(name)}')
-
-        # use it as a normal method: x.register_module(module=SomeClass)
-        if module is not None:
-            cls._register_module(module=module, module_name=name, force=force)
-            return module
-
-        # use it as a decorator: @x.register_module()
-        def _register(module):
-            cls._register_module(module=module, module_name=name, force=force)
-            return module
-
-        return _register
+def get_supported_fim_encoders() -> Iterable[str]:
+    """Return set of supported FIM encoder types."""
+    return _FIM_ENCODERS.keys()


+def get_fim_encoder_factory(
+        name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]:
+    """
+    Get FIM encoder by name.
+    Raise a KeyError exception if the name is not recognized.
+    """
+    if name is None:
+        return None
+
+    if (encoder := _FIM_ENCODERS.get(name)) is None:
+        raise KeyError(f"fim encoder '{name}' not recognized")
+
+    if isclass(encoder):
+        assert issubclass(encoder, FIMEncoder)
+        return encoder  # type: ignore[return-value]
+
+    # assert isinstance(encoder, Tuple[str, Iterable[str]])
+    template, special_tokens = encoder
+    return partial(StringTemplateFIMEncoder,
+                   name=name,
+                   template=template,
+                   special_tokens=special_tokens)
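
A hedged usage sketch of the two helpers introduced above (the Hugging Face checkpoint is an assumption for illustration, not something this diff prescribes):

from transformers import AutoTokenizer

from vllm.entrypoints.openai.fim.fim_encoder import (
    get_fim_encoder_factory, get_supported_fim_encoders)

print(sorted(get_supported_fim_encoders()))
# -> ['codellama', 'deepseek', 'mistral', 'starcoder']

# For template entries the factory is a partial with name, template and
# special_tokens already bound, so the tokenizer is passed by keyword.
factory = get_fim_encoder_factory("starcoder")
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")  # assumed model
encoder = factory(tokenizer=tokenizer)
input_ids = encoder.encode_with_suffix(prefix="def add(a, b):\n    ",
                                       suffix="\n")

Passing the tokenizer by keyword also works for the class-based entries, since FIMEncoder.__init__ takes the tokenizer as its only argument.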
16 changes: 8 additions & 8 deletions vllm/entrypoints/openai/fim/mistral_fim.py
@@ -1,21 +1,21 @@
from typing import List

-from transformers_utils.tokenizer import AnyTokenizer
-from transformers_utils.tokenizers import MistralTokenizer
+from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV1

-from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
-                                                     FIMEncoderManager)
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers import MistralTokenizer


-@FIMEncoderManager.register_module("mistral")
class MistralFIMEncoder(FIMEncoder):

def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)

-        if not isinstance(tokenizer, MistralTokenizer):
+        if not isinstance(tokenizer, MistralTokenizer) \
+                or isinstance(tokenizer.instruct, InstructTokenizerV1):
raise ValueError(
"tokenizer incompatible with 'mistral' FIM encoder")

-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
-        return self.tokenizer.encode_with_suffix(prompt=prompt, suffix=suffix)
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        return self.tokenizer.encode_with_suffix(prefix=prefix, suffix=suffix)
6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/serving_engine.py
@@ -12,7 +12,7 @@
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder,
-                                                     FIMEncoderManager)
+                                                     get_fim_encoder_factory)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -122,7 +122,7 @@ def __init__(
self.return_tokens_as_token_ids = return_tokens_as_token_ids

self.fim_class: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \
-            FIMEncoderManager.get_fim_encoder_class(fim_encoder)
+            get_fim_encoder_factory(fim_encoder)

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
@@ -223,7 +223,7 @@ def _normalize_prompt_text_to_input(
"truncate_prompt_tokens is not supported with suffix")
fim_encoder = fim_class.for_tokenizer( # type: ignore[attr-defined]
tokenizer)
-                input_ids = fim_encoder.encode_with_suffix(prompt=prompt,
+                input_ids = fim_encoder.encode_with_suffix(prefix=prompt,
suffix=suffix)
else:
if truncate_prompt_tokens is None:
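To make the template path above concrete, a standalone sketch (not from the diff) of the formatting step StringTemplateFIMEncoder performs before tokenizing, using the "deepseek" template from _FIM_ENCODERS:

template = "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>"
prompt = template.format(prefix="def add(a, b):", suffix="return a + b")
print(prompt)
# -> <|fim▁begin|>def add(a, b):<|fim▁hole|>return a + b<|fim▁end|>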
4 changes: 2 additions & 2 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -190,8 +190,8 @@ def encode(self, prompt: str) -> List[int]:
# For chat completion use `apply_chat_template`
return self.tokenizer.encode(prompt, bos=True, eos=False)

-    def encode_with_suffix(self, prompt: str, suffix: str) -> List[int]:
-        fim = FIMRequest(prompt=prompt, suffix=suffix)
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        fim = FIMRequest(prefix=prefix, suffix=suffix)
return self.mistral.encode_fim(fim)

def apply_chat_template(self,
