diff --git a/engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py b/engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py
new file mode 100644
index 000000000..b8daee9e0
--- /dev/null
+++ b/engines/python/setup/djl_python/chat_completions/vllm_chat_properties.py
@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+#
+# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+from typing import Optional
+from pydantic import Field
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+
+
+class ChatProperties(ChatCompletionRequest):
+    """
+    Chat input parameters for chat completions API.
+    See https://platform.openai.com/docs/api-reference/chat/create
+    """
+
+    model: Optional[str] = Field(default=None, exclude=True)  # Unused
diff --git a/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
new file mode 100644
index 000000000..8cea73c56
--- /dev/null
+++ b/engines/python/setup/djl_python/chat_completions/vllm_chat_utils.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python
+#
+# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+from typing import Dict, List, Optional, Union
+
+from djl_python.chat_completions.vllm_chat_properties import ChatProperties
+from djl_python.properties_manager.properties import Properties
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
+                                         parse_chat_messages)
+
+
+def is_chat_completions_request(inputs: Dict) -> bool:
+    return "messages" in inputs
+
+
+def parse_chat_completions_request_vllm(
+    input_map: Dict,
+    is_rolling_batch: bool,
+    rolling_batch,
+    tokenizer,
+    chat_template: Optional[str] = None,
+    image_token: Optional[str] = None,
+    configs: Properties = None,
+    is_mistral_tokenizer: bool = False,
+):
+    # Chat completions can either be a rolling batch or no-batching.
+    if not (is_rolling_batch or configs.batch_size == 1):
+        raise ValueError(
+            "chat completions support is not currently available for dynamic batching. "
+            "You must enable rolling batch to use the chat completions format."
+        )
+
+    if not is_mistral_tokenizer and not hasattr(tokenizer,
+                                                "apply_chat_template"):
+        raise AttributeError(
+            f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
+            f"please ensure that your tokenizer supports chat templates.")
+
+    chat_params = ChatProperties(**input_map)
+    exclude = {"messages"}
+    param = chat_params.model_dump(exclude_none=True, exclude=exclude)
+
+    conversation, mm_data = parse_chat_messages(
+        chat_params.messages, rolling_batch.get_model_config(), tokenizer)
+
+    prompt_data: Union[str, List[int]]
+    if is_mistral_tokenizer:
+        text_inputs = apply_mistral_chat_template(
+            tokenizer,
+            messages=chat_params.messages,
+            chat_template=chat_template,
+            add_generation_prompt=True,
+        )
+    else:
+        text_inputs = apply_hf_chat_template(
+            tokenizer,
+            conversation=conversation,
+            chat_template=chat_template,
+            add_generation_prompt=True,
+        )
+
+    param["details"] = True  # Enable details for chat completions
+    param[
+        "output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
+
+    if mm_data:
+        param["mm_data"] = mm_data
+
+    # In the case of mistral, text_inputs = List[TokenIds], else = str
+    return text_inputs, param
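For orientation, here is a rough sketch of how the new parser is exercised. The payload follows the OpenAI chat-completions shape; `my_rolling_batch` and `my_tokenizer` are placeholders for the objects the handler normally supplies, not real wiring.

```python
from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm

# Hypothetical request body in the OpenAI chat-completions format.
input_map = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the Deep Java Library?"},
    ],
    "temperature": 0.7,
    "max_tokens": 256,
    "stream": False,
}

# `my_rolling_batch` must expose get_model_config(); `my_tokenizer` is a HF tokenizer
# with a chat template. Both are stand-ins for this illustration.
text_inputs, param = parse_chat_completions_request_vllm(
    input_map,
    is_rolling_batch=True,
    rolling_batch=my_rolling_batch,
    tokenizer=my_tokenizer,
)
# text_inputs: the templated prompt (a string for HF tokenizers, token ids for mistral).
# param: the remaining sampling parameters plus "details", "output_formatter", and,
# for multimodal conversations, "mm_data".
```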
diff --git a/engines/python/setup/djl_python/input_parser.py b/engines/python/setup/djl_python/input_parser.py
index 6f87e2dd9..401160345 100644
--- a/engines/python/setup/djl_python/input_parser.py
+++ b/engines/python/setup/djl_python/input_parser.py
@@ -16,6 +16,7 @@
 
 from djl_python import Input
 from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
+from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm
 from djl_python.encode_decode import decode
 from djl_python.properties_manager.properties import is_rolling_batch_enabled
 from djl_python.request import Request
@@ -140,14 +141,27 @@ def parse_text_inputs_params(request_input: TextInput, input_item: Input,
     if configs is not None:
         is_bedrock = configs.bedrock_compat
     if is_chat_completions_request(input_map):
-        inputs, param = parse_chat_completions_request(
-            input_map,
-            kwargs.get("is_rolling_batch"),
-            tokenizer,
-            image_token=image_token,
-            configs=configs,
-            is_mistral_tokenizer=is_mistral_tokenizer,
-        )
+        rolling_batch = kwargs.get("rolling_batch")
+        if rolling_batch is not None and rolling_batch.use_vllm_chat_completions(
+        ):
+            inputs, param = parse_chat_completions_request_vllm(
+                input_map,
+                kwargs.get("is_rolling_batch"),
+                rolling_batch,
+                tokenizer,
+                image_token=image_token,
+                configs=configs,
+                is_mistral_tokenizer=is_mistral_tokenizer,
+            )
+        else:
+            inputs, param = parse_chat_completions_request(
+                input_map,
+                kwargs.get("is_rolling_batch"),
+                tokenizer,
+                image_token=image_token,
+                configs=configs,
+                is_mistral_tokenizer=is_mistral_tokenizer,
+            )
     elif is_bedrock:
         inputs, param = parse_3p_request(input_map,
                                          kwargs.get("is_rolling_batch"),
diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index d43cf6c9b..5f78489d7 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -126,6 +126,19 @@ def get_tokenizer(self):
             return self.engine.preprocessor.tokenizer
         return self.engine.preprocessor.tokenizer.tokenizer
 
+    def get_model_config(self):
+        # TODO: this is a hack right now to get the model config from the engine. We should expose this as
+        # an interface method and retrieve it from there after v12
+        return self.engine.preprocessor.model_config if not self.is_t5_model else None
+
+    def use_vllm_chat_completions(self):
+        return True
+
+    def get_huggingface_model_config(self):
+        # TODO: this is a hack right now to get the model config from the engine. We should expose this as
+        # an interface method and retrieve it from there after v12
+        return self.engine.preprocessor.model_config.hf_config if not self.is_t5_model else None
+
     def get_huggingface_model_config(self):
         # TODO: this is a hack right now to get the model config from the engine. We should expose this as
         # an interface method and retrieve it from there after v12
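The new dispatch in `parse_text_inputs_params` is driven entirely by the `use_vllm_chat_completions()` hook, so a backend opts into the vLLM path simply by overriding it (lmi-dist above and the vLLM rolling batch below both return True). A minimal sketch of the selection logic; `pick_chat_parser` is an invented helper name for illustration only:

```python
from djl_python.chat_completions.chat_utils import parse_chat_completions_request
from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm


def pick_chat_parser(rolling_batch):
    # Mirrors the branch added to parse_text_inputs_params: use the vLLM-backed
    # parser only when the active rolling batch opts in; otherwise fall back to
    # the existing HuggingFace-based chat parsing.
    if rolling_batch is not None and rolling_batch.use_vllm_chat_completions():
        return parse_chat_completions_request_vllm
    return parse_chat_completions_request
```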
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
index 30f5f68d0..66e566a68 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -101,6 +101,12 @@ def get_tokenizer(self):
         """
         raise RuntimeError("get_tokenizer function not supported")
 
+    def get_model_config(self):
+        """
+        :return: the model config if available
+        """
+        raise RuntimeError("get_model_config must be implemented by subclass")
+
     def get_huggingface_model_config(self):
         """
         :return: the huggingface pretrained config if available
@@ -108,6 +114,12 @@ def get_huggingface_model_config(self):
         raise RuntimeError(
             "get_huggingface_model_config must be implemented by subclass")
 
+    def use_vllm_chat_completions(self):
+        """
+        :return: whether to use the vllm chat completions.
+        """
+        return False
+
     @abstractmethod
     def inference(self, new_requests: List[Request]) -> List:
         """
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
index bad3cc8eb..9e9f3b6ae 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -304,18 +304,9 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
     )
 
 
-def get_multi_modal_data(request: Request) -> Optional[dict]:
-    parameters = request.parameters
-    images = parameters.pop("images", None)
-    multi_modal_data = None
-    if images:
-        multi_modal_data = {"image": images}
-    return multi_modal_data
-
-
 def get_prompt_inputs(request: Request):
     text_prompt = request.request_input.input_text
-    multi_modal_data = get_multi_modal_data(request)
+    multi_modal_data = request.parameters.pop("mm_data", None)
     # TODO: In chat cases, we need to apply the chat template to the messages object to get a string
     # In both HuggingFace and mistral cases, that process can also yield token-ids directly
     # that we may want to consider passing directly to the engine
diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
index 66abbf811..6745ced4b 100644
--- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -57,9 +57,15 @@ def __init__(self, model_id_or_path: str, properties: dict,
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
 
+    def get_model_config(self):
+        return self.engine.model_config
+
     def get_huggingface_model_config(self):
         return self.engine.model_config.hf_config
 
+    def use_vllm_chat_completions(self):
+        return True
+
     def reset(self) -> None:
         """
         Aborts all requests
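With the old `get_multi_modal_data()` helper removed, multimodal inputs now travel end to end as the `"mm_data"` entry produced by `parse_chat_messages()`: the chat parser stores it in `param`, and `get_prompt_inputs()` pops it off the request parameters. A simplified illustration of the resulting flow; `build_prompt` is an invented name, and the dict shown only approximates the vLLM prompt inputs the real code constructs:

```python
def build_prompt(text_prompt: str, parameters: dict) -> dict:
    # "mm_data" is set by parse_chat_completions_request_vllm for multimodal chats;
    # plain text requests simply will not carry it.
    multi_modal_data = parameters.pop("mm_data", None)
    prompt = {"prompt": text_prompt}
    if multi_modal_data:
        # vLLM accepts a multimodal payload alongside the text prompt,
        # e.g. {"image": [<PIL.Image>, ...]}.
        prompt["multi_modal_data"] = multi_modal_data
    return prompt
```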
diff --git a/engines/python/setup/setup.py b/engines/python/setup/setup.py
index 6883bbf89..1cf04e748 100644
--- a/engines/python/setup/setup.py
+++ b/engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
     test_requirements = [
         'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
         'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
-        'pydantic>=2.0', "objgraph"
+        'pydantic>=2.0', "objgraph", "vllm==0.6.3.post1"
     ]
 
     setup(name='djl_python',
diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py
index 1ea80d093..be37ea68b 100644
--- a/tests/integration/llm/client.py
+++ b/tests/integration/llm/client.py
@@ -1918,7 +1918,7 @@ def get_multimodal_prompt(batch_size):
         "messages": messages,
         "temperature": 0.9,
         "top_p": 0.6,
-        "max_new_tokens": 512,
+        "max_tokens": 512,
     }
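The integration-test tweak reflects that chat requests are now validated against vLLM's OpenAI-style ChatCompletionRequest (via ChatProperties), which names the generation cap `max_tokens` rather than `max_new_tokens`. A chat-completions request against a deployed endpoint would now look roughly like this; the endpoint URL and port are illustrative and depend on how the model is served:

```python
import requests

# Illustrative request only; adjust the URL to the actual deployment.
payload = {
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "temperature": 0.9,
    "top_p": 0.6,
    "max_tokens": 512,
}
resp = requests.post("http://localhost:8080/invocations",
                     json=payload,
                     headers={"Content-Type": "application/json"})
print(resp.json())
```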