diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index 8d6ad6706fb0e..64bcba67c3437 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -60,12 +60,13 @@ class MockServingChat:
     tokenizer: MockTokenizer
 
 
-def test_load_chat_template():
+@pytest.mark.asyncio
+async def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=chatml_jinja_path)
+    await OpenAIServingChat._load_chat_template(
+        mock_serving_chat, chat_template=chatml_jinja_path)
 
     template_content = tokenizer.chat_template
 
@@ -76,7 +77,8 @@ def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
 
 
-def test_no_load_chat_template_filelike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
@@ -84,18 +86,19 @@ def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)
 
     with pytest.raises(ValueError, match="looks like a file path"):
-        OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                              chat_template=template)
+        await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                    chat_template=template)
 
 
-def test_no_load_chat_template_literallike():
+@pytest.mark.asyncio
+async def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
 
     template_content = tokenizer.chat_template
 
     assert template_content == template
@@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                          chat_template=template)
+    await OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                                chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
new file mode 100644
index 0000000000000..269b0823fec05
--- /dev/null
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -0,0 +1,37 @@
+import asyncio
+from dataclasses import dataclass
+
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+
+MODEL_NAME = "openai-community/gpt2"
+CHAT_TEMPLATE = "Dummy chat template for testing {}"
+
+
+@dataclass
+class MockModelConfig:
+    tokenizer = MODEL_NAME
+    trust_remote_code = False
+    tokenizer_mode = "auto"
+    max_model_len = 100
+    tokenizer_revision = None
+
+
+@dataclass
+class MockEngine:
+
+    async def get_model_config(self):
+        return MockModelConfig
+
+
+async def _async_serving_chat_init():
+    serving_completion = OpenAIServingChat(MockEngine(),
+                                           served_model_names=[MODEL_NAME],
+                                           response_role="assistant",
+                                           chat_template=CHAT_TEMPLATE)
+    return serving_completion
+
+
+def test_async_serving_chat_init():
+    serving_completion = asyncio.run(_async_serving_chat_init())
+    assert serving_completion.tokenizer is not None
+    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 1323dba469117..e53e64a0c1ff8 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -150,7 +150,7 @@ def server(zephyr_lora_files):
     ray.shutdown()
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 599f99e56a726..c8f4a6b315db0 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,3 +1,4 @@
+import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -40,9 +41,11 @@ def __init__(self,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         await_post_init=self._load_chat_template(
+                             chat_template=chat_template))
+
         self.response_role = response_role
-        self._load_chat_template(chat_template)
 
     def _parse_chat_message_content(
         self,
@@ -356,7 +359,10 @@ async def chat_completion_full_generator(
 
         return response
 
-    def _load_chat_template(self, chat_template: Optional[str]):
+    async def _load_chat_template(self, chat_template: Optional[str]):
+        while self.tokenizer is None:
+            # Give the parent class time to load the tokenizer
+            await asyncio.sleep(0.1)
         tokenizer = self.tokenizer
 
         if chat_template is not None:
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 3d5ed328b9d19..21baea2e5e7f6 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -2,7 +2,7 @@
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -29,8 +29,11 @@ class LoRAModulePath:
 
 class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 await_post_init: Optional[Awaitable[Any]] = None):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:
@@ -56,12 +59,12 @@ def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
         if event_loop is not None and event_loop.is_running():
             # If the current is instanced by Ray Serve,
             # there is already a running event loop
-            event_loop.create_task(self._post_init())
+            event_loop.create_task(self._post_init(await_post_init))
         else:
             # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init())
+            asyncio.run(self._post_init(await_post_init))
 
-    async def _post_init(self):
+    async def _post_init(self, await_post_init):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
 
@@ -73,6 +76,9 @@ async def _post_init(self):
             trust_remote_code=engine_model_config.trust_remote_code,
             truncation_side="left")
 
+        if await_post_init is not None:
+            await await_post_init
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
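
For context on the ordering the patch relies on: OpenAIServingChat now creates the _load_chat_template(...) coroutine in its constructor and passes it to the parent as await_post_init; the parent awaits it only at the end of _post_init, once the tokenizer has been loaded. Below is a minimal, self-contained sketch of that pattern. The names Base, Child, and _configure are illustrative assumptions, not vLLM APIs, and the tokenizer stand-in is a plain object().

# Minimal standalone sketch of the await_post_init ordering (assumed names;
# not vLLM code, only the pattern).
import asyncio
from typing import Any, Awaitable, Optional


class Base:

    def __init__(self, await_post_init: Optional[Awaitable[Any]] = None):
        self.tokenizer = None
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        if loop is not None:
            # Already inside a running loop (e.g. Ray Serve): run as a task.
            loop.create_task(self._post_init(await_post_init))
        else:
            # No running loop: block until async initialization finishes.
            asyncio.run(self._post_init(await_post_init))

    async def _post_init(self, await_post_init):
        self.tokenizer = object()  # stands in for loading the real tokenizer
        if await_post_init is not None:
            # Run the subclass's deferred setup only after the tokenizer exists.
            await await_post_init


class Child(Base):

    def __init__(self):
        # The coroutine is created before the parent initializes, but the
        # parent awaits it only at the end of _post_init.
        super().__init__(await_post_init=self._configure())

    async def _configure(self):
        while self.tokenizer is None:
            await asyncio.sleep(0.1)
        self.configured = True


child = Child()           # takes the asyncio.run() branch here
assert child.configured   # _configure ran after _post_init set the tokenizer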