This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Frontend] Move async logic outside of constructor (vllm-project#4674)
DarkLight1337 authored and robertgshaw2-redhat committed May 19, 2024
1 parent b0d3937; commit 294e480
Showing 7 changed files with 96 additions and 102 deletions.
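
What this change does: OpenAIServing previously launched its own asynchronous initialization from inside __init__ (fetching the model config and building the tokenizer on whatever event loop happened to be available). This commit moves that work to the caller: the server resolves the ModelConfig first and passes it into constructors that are now purely synchronous, which also lets the chat-template tests below drop @pytest.mark.asyncio. A minimal toy sketch of the pattern (illustrative names only, not vLLM's API):

import asyncio


class Engine:
    """Stand-in for an async engine such as AsyncLLMEngine."""

    async def get_model_config(self):
        return {"max_model_len": 4096}


class Serving:
    """After the change: plain constructor, async dependency resolved by the caller."""

    def __init__(self, engine, model_config):
        self.engine = engine
        self.max_model_len = model_config["max_model_len"]


async def main():
    engine = Engine()
    model_config = await engine.get_model_config()  # async work happens here...
    serving = Serving(engine, model_config)         # ...so __init__ stays synchronous
    print(serving.max_model_len)


asyncio.run(main())
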

tests/async_engine/test_chat_template.py (13 additions, 17 deletions)
@@ -60,13 +60,12 @@ class MockServingChat:
     tokenizer: MockTokenizer


-@pytest.mark.asyncio
-async def test_load_chat_template():
+def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(
-        mock_serving_chat, chat_template=chatml_jinja_path)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=chatml_jinja_path)

     template_content = tokenizer.chat_template
@@ -77,44 +76,41 @@ async def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501


-@pytest.mark.asyncio
-async def test_no_load_chat_template_filelike():
+def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()

     mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                    chat_template=template)
+        OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                              chat_template=template)


-@pytest.mark.asyncio
-async def test_no_load_chat_template_literallike():
+def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()

     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                chat_template=template)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=template)
     template_content = tokenizer.chat_template

     assert template_content == template


-@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model,template,add_generation_prompt,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
-async def test_get_gen_prompt(model, template, add_generation_prompt,
-                              expected_output):
+def test_get_gen_prompt(model, template, add_generation_prompt,
+                        expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                chat_template=template)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(

tests/entrypoints/openai/test_serving_chat.py (6 additions, 2 deletions)
@@ -20,11 +20,15 @@ class MockModelConfig:
 class MockEngine:

     async def get_model_config(self):
-        return MockModelConfig
+        return MockModelConfig()


 async def _async_serving_chat_init():
-    serving_completion = OpenAIServingChat(MockEngine(),
+    engine = MockEngine()
+    model_config = await engine.get_model_config()
+
+    serving_completion = OpenAIServingChat(engine,
+                                           model_config,
                                            served_model_names=[MODEL_NAME],
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE)

vllm/engine/arg_utils.py (1 addition, 1 deletion)
@@ -531,7 +531,7 @@ def add_cli_args(
         return parser

     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
+    def from_cli_args(cls, args: argparse.Namespace):
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.

vllm/entrypoints/openai/api_server.py (20 additions, 3 deletions)
@@ -4,7 +4,7 @@
 import re
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Set
+from typing import Optional, Set

 import fastapi
 import uvicorn
@@ -164,15 +164,32 @@ async def authentication(request: Request, call_next):
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]
+
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
+
+    event_loop: Optional[asyncio.AbstractEventLoop]
+    try:
+        event_loop = asyncio.get_running_loop()
+    except RuntimeError:
+        event_loop = None
+
+    if event_loop is not None and event_loop.is_running():
+        # If the current is instanced by Ray Serve,
+        # there is already a running event loop
+        model_config = event_loop.run_until_complete(engine.get_model_config())
+    else:
+        # When using single vLLM without engine_use_ray
+        model_config = asyncio.run(engine.get_model_config())
+
+    openai_serving_chat = OpenAIServingChat(engine, model_config,
+                                            served_model_names,
                                             args.response_role,
                                             args.lora_modules,
                                             args.chat_template)
     openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model_names, args.lora_modules)
+        engine, model_config, served_model_names, args.lora_modules)

     app.root_path = args.root_path
     uvicorn.run(app,
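
For embedders, construction after this change looks roughly like the sketch below. This is illustrative only: the model name is an assumed example, and LoRA modules and a chat template are omitted; it simply mirrors the api_server.py hunk above.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # assumed example model
engine = AsyncLLMEngine.from_engine_args(engine_args)

# The async call is resolved here, outside any constructor.
model_config = asyncio.run(engine.get_model_config())

serving_chat = OpenAIServingChat(engine,
                                 model_config,
                                 served_model_names=["facebook/opt-125m"],
                                 response_role="assistant")
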

vllm/entrypoints/openai/serving_chat.py (35 additions, 37 deletions)
@@ -1,4 +1,3 @@
-import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -8,6 +7,7 @@
 from openai.types.chat import (ChatCompletionContentPartParam,
                                ChatCompletionRole)

+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
@@ -35,17 +35,47 @@ class OpenAIServingChat(OpenAIServing):

     def __init__(self,
                  engine: AsyncLLMEngine,
+                 model_config: ModelConfig,
                  served_model_names: List[str],
                  response_role: str,
                  lora_modules: Optional[List[LoRAModulePath]] = None,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
+                         model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules,
-                         await_post_init=self._load_chat_template(
-                             chat_template=chat_template))
+                         lora_modules=lora_modules)

         self.response_role = response_role
+        self._load_chat_template(chat_template)
+
+    def _load_chat_template(self, chat_template: Optional[str]):
+        tokenizer = self.tokenizer
+
+        if chat_template is not None:
+            try:
+                with open(chat_template, "r") as f:
+                    tokenizer.chat_template = f.read()
+            except OSError as e:
+                JINJA_CHARS = "{}\n"
+                if not any(c in chat_template for c in JINJA_CHARS):
+                    msg = (f"The supplied chat template ({chat_template}) "
+                           f"looks like a file path, but it failed to be "
+                           f"opened. Reason: {e}")
+                    raise ValueError(msg) from e
+
+                # If opening a file fails, set chat template to be args to
+                # ensure we decode so our escape are interpreted correctly
+                tokenizer.chat_template = codecs.decode(
+                    chat_template, "unicode_escape")
+
+            logger.info("Using supplied chat template:\n%s",
+                        tokenizer.chat_template)
+        elif tokenizer.chat_template is not None:
+            logger.info("Using default chat template:\n%s",
+                        tokenizer.chat_template)
+        else:
+            logger.warning(
+                "No chat template provided. Chat API will not work.")

     def _parse_chat_message_content(
         self,
@@ -357,36 +387,4 @@ async def chat_completion_full_generator(
             usage=usage,
         )

-        return response
-
-    async def _load_chat_template(self, chat_template: Optional[str]):
-        while self.tokenizer is None:
-            # Give the parent class time to load the tokenizer
-            await asyncio.sleep(0.1)
-        tokenizer = self.tokenizer
-
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    tokenizer.chat_template = f.read()
-            except OSError as e:
-                JINJA_CHARS = "{}\n"
-                if not any(c in chat_template for c in JINJA_CHARS):
-                    msg = (f"The supplied chat template ({chat_template}) "
-                           f"looks like a file path, but it failed to be "
-                           f"opened. Reason: {e}")
-                    raise ValueError(msg) from e
-
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using supplied chat template:\n%s",
-                        tokenizer.chat_template)
-        elif tokenizer.chat_template is not None:
-            logger.info("Using default chat template:\n%s",
-                        tokenizer.chat_template)
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
+        return response
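
On the codecs.decode(chat_template, "unicode_escape") fallback above: when a template is supplied as a literal string (for example through a CLI flag), escape sequences arrive as backslash-plus-letter text, and decoding turns them into real characters before the template is stored on the tokenizer. A small illustration (not part of the commit):

import codecs

raw = "{{ messages }}\\n"  # ends in two characters: a backslash and an "n"
decoded = codecs.decode(raw, "unicode_escape")
assert decoded == "{{ messages }}\n"  # a real newline after decoding
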

vllm/entrypoints/openai/serving_completion.py (4 additions, 3 deletions)
@@ -4,6 +4,7 @@

 from fastapi import Request

+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponse,
@@ -52,11 +53,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:

 class OpenAIServingCompletion(OpenAIServing):

-    def __init__(self,
-                 engine: AsyncLLMEngine,
+    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]]):
         super().__init__(engine=engine,
+                         model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)


vllm/entrypoints/openai/serving_engine.py (17 additions, 39 deletions)
@@ -1,13 +1,12 @@
-import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 from pydantic import Field
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Annotated

+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest, ErrorResponse,
@@ -29,13 +28,24 @@ class LoRAModulePath:

 class OpenAIServing:

-    def __init__(self,
-                 engine: AsyncLLMEngine,
+    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 await_post_init: Optional[Awaitable[Any]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]]):
         super().__init__()

         self.engine = engine
+        self.max_model_len = model_config.max_model_len
+
+        # A separate tokenizer to map token IDs to strings.
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer,
+            tokenizer_mode=model_config.tokenizer_mode,
+            tokenizer_revision=model_config.tokenizer_revision,
+            trust_remote_code=model_config.trust_remote_code,
+            truncation_side="left")
+
         self.served_model_names = served_model_names

         if lora_modules is None:
             self.lora_requests = []
         else:
@@ -47,38 +57,6 @@ def __init__(self,
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]

-        self.max_model_len = 0
-        # Lazy initialized
-        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-
-        try:
-            event_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            event_loop = None
-
-        if event_loop is not None and event_loop.is_running():
-            # If the current is instanced by Ray Serve,
-            # there is already a running event loop
-            event_loop.create_task(self._post_init(await_post_init))
-        else:
-            # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init(await_post_init))
-
-    async def _post_init(self, await_post_init):
-        engine_model_config = await self.engine.get_model_config()
-        self.max_model_len = engine_model_config.max_model_len
-
-        # A separate tokenizer to map token IDs to strings.
-        self.tokenizer = get_tokenizer(
-            engine_model_config.tokenizer,
-            tokenizer_mode=engine_model_config.tokenizer_mode,
-            tokenizer_revision=engine_model_config.tokenizer_revision,
-            trust_remote_code=engine_model_config.trust_remote_code,
-            truncation_side="left")
-
-        if await_post_init is not None:
-            await await_post_init
-
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [