feat: updated inline vllm inference provider #880

Open · wants to merge 24 commits into base: main
Changes from 16 commits
43 changes: 19 additions & 24 deletions llama_stack/providers/inline/inference/vllm/config.py
@@ -5,19 +5,17 @@
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator

from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field


@json_schema_type
class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
"""Configuration for the vLLM inference provider.

Note that the model name is no longer part of this static configuration.
You can bind an instance of this provider to a specific model with the
``models.register()`` API call."""

model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
@@ -26,34 +24,31 @@ class VLLMConfig(BaseModel):
default=4096,
description="Maximum number of tokens to generate.",
)
max_model_len: int = Field(
default=4096, description="Maximum context length to use during serving."
)
max_num_seqs: int = Field(
default=4, description="Maximum parallel batch size for generation"
)
enforce_eager: bool = Field(
default=False,
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
)
gpu_memory_utilization: float = Field(
default=0.3,
description=(
"How much GPU memory will be allocated when this provider has finished "
"loading, including memory that was already allocated before loading."
),
)

@classmethod
def sample_run_config(cls):
return {
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:4}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
}
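# For illustration only (not part of this diff): the "${env.NAME:default}"
# placeholders above are resolved by Llama Stack's run-config loading, not by
# this class. A rough, hypothetical Python equivalent of that substitution is:
import os
import re

def resolve_env_placeholders(value: str) -> str:
    # Replace "${env.NAME:default}" with os.environ["NAME"] if set,
    # otherwise fall back to the given default.
    return re.sub(
        r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}",
        lambda m: os.environ.get(m.group(1), m.group(2)),
        value,
    )

resolve_env_placeholders("${env.MAX_TOKENS:4096}")  # -> "4096" unless MAX_TOKENS is set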

@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()

descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(
f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]"
)
return model
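Since the static model field and its validator are removed, the model is now bound to the provider at runtime, as the new docstring notes. A minimal sketch of that flow, assuming the llama-stack-client Python API; the base URL and argument names below are illustrative and not part of this diff:

from llama_stack_client import LlamaStackClient

# Illustrative only: choose the model via the models API instead of VLLMConfig.
client = LlamaStackClient(base_url="http://localhost:5000")
client.models.register(
    model_id="Llama3.2-3B-Instruct",  # descriptor from `llama model list`
    provider_id="vllm",               # the inline vLLM inference provider
)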
186 changes: 186 additions & 0 deletions llama_stack/providers/inline/inference/vllm/openai_utils.py
@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional

import vllm

from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition

from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
JsonSchemaResponseFormat,
Message,
ToolChoice,
UserMessage,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)


###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.


def _merge_context_into_content(message: Message) -> Message: # type: ignore
"""
Merge the ``context`` field of a Llama Stack ``Message`` object into
the content field for compatibility with OpenAI-style APIs.

Generates a content string that emulates the current behavior
of ``llama_models.llama3.api.chat_format.encode_message()``.

:param message: Message that may include ``context`` field

:returns: A version of ``message`` with any context merged into the
``content`` field.
"""
if not isinstance(message, UserMessage): # Separate type check for linter
return message
if message.context is None:
return message
return UserMessage(
role=message.role,
# Emulate llama_models.llama3.api.chat_format.encode_message()
content=message.content + "\n\n" + message.context,
context=None,
)
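# A quick, hypothetical illustration of the merge above; the message text is
# made up, and the assertions mirror the "content, blank line, context"
# behavior implemented in _merge_context_into_content():
def _example_merge_context() -> None:
    msg = UserMessage(
        role="user",
        content="What is the capital of France?",
        context="Paris is the capital of France.",
    )
    merged = _merge_context_into_content(msg)
    assert merged.content == (
        "What is the capital of France?\n\nParis is the capital of France."
    )
    assert merged.context is None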


def _llama_stack_tools_to_openai_tools(
tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.
"""
if tools is None:
return []

result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
else: # if t.parameters is not None
# Convert the "required" flags to a list of required params
required_params = [k for k, v in t.parameters.items() if v.required]
parameters = {
"type": "object", # Mystery value that shows up in OpenAI docs
"properties": {
k: {"type": v.param_type, "description": v.description}
for k, v in t.parameters.items()
},
"required": required_params,
}

function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
name=t.tool_name, description=t.description, parameters=parameters
)

# Every tool definition is double-boxed in a ChatCompletionToolsParam
result.append(
vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(
function=function_def
)
)
return result
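# A hypothetical example of the conversion above. ToolParamDefinition is assumed
# to be the llama_models datatype whose .param_type/.description/.required fields
# the loop reads; the tool itself is made up for illustration:
def _example_tool_conversion():
    from llama_models.llama3.api.datatypes import ToolParamDefinition

    tool = ToolDefinition(
        tool_name="get_weather",
        description="Look up the current weather for a city.",
        parameters={
            "city": ToolParamDefinition(
                param_type="string",
                description="Name of the city",
                required=True,
            )
        },
    )
    converted = _llama_stack_tools_to_openai_tools([tool])
    # Expected shape (illustrative): converted[0].function.name == "get_weather"
    # and converted[0].function.parameters["required"] == ["city"]
    return converted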


async def llama_stack_chat_completion_to_openai_chat_completion_dict(
request: ChatCompletionRequest,
) -> dict:
"""
Convert a chat completion request in Llama Stack format into an
equivalent set of arguments to pass to an OpenAI-compatible
chat completions API.

:param request: Bundled request parameters in Llama Stack format.

:returns: Dictionary of key-value pairs to use as an initializer
for a dataclass or to be converted directly to JSON and sent
over the wire.
"""

converted_messages = [
# convert_message_to_openai_dict() is async, which is why this function must also be async
await convert_message_to_openai_dict(
_merge_context_into_content(m), download=True
)
for m in request.messages
]
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)

# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if (
request.tool_choice == ToolChoice.auto
and request.tools is not None
and len(request.tools) > 0
):
converted_tool_choice = "auto"

# TODO: Figure out what to do with the tool_prompt_format argument.
# Other connectors appear to drop it quietly.

# Use Llama Stack shared code to translate sampling parameters.
sampling_options = get_sampling_options(request.sampling_params)

# get_sampling_options() translates repetition penalties to an option that
# OpenAI's APIs don't know about.
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
# For now, translate repetition penalties into a format that vLLM's broken
# API will handle correctly. Two wrongs make a right...
if "repeat_penalty" in sampling_options:
del sampling_options["repeat_penalty"]
if (
request.sampling_params.repetition_penalty is not None
and request.sampling_params.repetition_penalty != 1.0
):
sampling_options["repetition_penalty"] = (
request.sampling_params.repetition_penalty
)

# Convert Llama Stack's single response_format field into the guided decoding
# parameters that vLLM's OpenAI-compatible API understands
guided_decoding_options = dict()
if request.response_format is None:
# Use defaults
pass
elif isinstance(request.response_format, JsonSchemaResponseFormat):
guided_decoding_options["guided_json"] = request.response_format.json_schema
elif isinstance(request.response_format, GrammarResponseFormat):
guided_decoding_options["guided_grammar"] = request.response_format.bnf
else:
raise TypeError(
f"ResponseFormat object is of unexpected "
f"subtype '{type(request.response_format)}'"
)

logprob_options = dict()
if request.logprobs is not None:
logprob_options["logprobs"] = request.logprobs.top_k

# Marshal together all the arguments for a ChatCompletionRequest
request_options = {
"model": request.model,
"messages": converted_messages,
"tools": converted_tools,
"tool_choice": converted_tool_choice,
"stream": request.stream,
**sampling_options,
**guided_decoding_options,
**logprob_options,
}

return request_options
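# A sketch of how a caller might consume the dictionary returned above; handing
# it straight to vLLM's OpenAI protocol request model is an assumption about the
# connector code, not something shown in this file:
async def _example_build_vllm_request(request: ChatCompletionRequest):
    request_dict = await llama_stack_chat_completion_to_openai_chat_completion_dict(
        request
    )
    # The keys are chosen to line up with vLLM's ChatCompletionRequest fields,
    # so the dict can be passed as keyword arguments to that pydantic model.
    return vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_dict)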