Commit df05374

Merge pull request #257 from l4b4r4b4b4/patch-2

bug fix for `served_model_name` from `request.model`

jeffreymeetkai authored Aug 22, 2024
2 parents: a3400cd + c510631
1 changed file: functionary/vllm_inference.py (8 additions, 8 deletions)
```diff
@@ -48,7 +48,7 @@ def create_error_response(
 
 
 async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
-    if request.model != served_model:
+    if request.model not in served_model:
         return create_error_response(
             status_code=HTTPStatus.NOT_FOUND,
             message=f"The model `{request.model}` does not exist.",
@@ -144,7 +144,7 @@ async def process_chat_completion(
     request: ChatCompletionRequest,
     raw_request: Optional[Request],
     tokenizer: Any,
-    served_model: str,
+    served_model: List[str],
     engine_model_config: Any,
     enable_grammar_sampling: bool,
     engine: Any,
@@ -250,7 +250,7 @@ async def completion_stream_generator(
         async for response in generate_openai_format_from_stream_async(
             generator, prompt_template, tool_choice, tools_or_functions
         ):
-
+            # Convert tool_calls to function_call if request.functions is provided
             if (
                 functions
@@ -290,8 +290,7 @@ async def completion_stream_generator(
                 }
                 if response["finish_reason"] == "function_call":
                     response["finish_reason"] = "tool_calls"
-
-
+
                 # Workaround Fixes
                 response["delta"]["role"] = "assistant"
                 if (
@@ -302,10 +302,11 @@ async def completion_stream_generator(
                     for tool_call in response["delta"]["tool_calls"]:
                         if tool_call.get("type") is None:
                             tool_call["type"] = "function"
-
-
+
                 chunk = StreamChoice(**response)
-                result = ChatCompletionChunk(id=request_id, choices=[chunk], model=served_model)
+                result = ChatCompletionChunk(
+                    id=request_id, choices=[chunk], model=model_name
+                )
                 chunk_dic = result.dict(exclude_unset=True)
                 chunk_data = json.dumps(chunk_dic, ensure_ascii=False)
                 yield f"data: {chunk_data}\n\n"
```
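The substantive change is the model check in `check_all_errors`. vLLM can register one model under several names (for example a full path plus a short alias), so `served_model` becomes a `List[str]`; against a list, the old equality test `request.model != served_model` is always true, and every request would have been rejected with a 404. A minimal sketch of the corrected membership test, using hypothetical model names:

```python
from http import HTTPStatus
from typing import List, Optional


def check_model(requested: str, served_model: List[str]) -> Optional[str]:
    """Mirror of the fixed check: accept any name registered for the server."""
    if requested not in served_model:
        # Corresponds to create_error_response(status_code=HTTPStatus.NOT_FOUND, ...)
        return f"{HTTPStatus.NOT_FOUND.value}: The model `{requested}` does not exist."
    return None


# Hypothetical served names: base model plus an alias.
served = ["meetkai/functionary-small-v2.5", "functionary"]
assert check_model("functionary", served) is None    # alias accepted
assert check_model("gpt-4", served) is not None      # unknown name -> 404
```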

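The last hunk follows from the same type change: the streamed `ChatCompletionChunk` must carry a single model string in its `model` field, and passing the `served_model` list through unchanged would leak a list into the OpenAI-format payload. Per the PR title, `model_name` is assumed to be derived from the validated `request.model`. A hedged sketch of the resulting chunk shape (field values are illustrative only):

```python
import json

served_model = ["meetkai/functionary-small-v2.5", "functionary"]  # hypothetical
request_model = "functionary"  # what the client asked for; already validated

model_name = request_model  # assumed source of `model_name` in the fix
chunk_dic = {
    "id": "cmpl-123",  # stands in for request_id
    "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hi"}}],
    "model": model_name,  # a single string, never the served_model list
}
# Mirrors the SSE framing used by completion_stream_generator.
print(f"data: {json.dumps(chunk_dic, ensure_ascii=False)}\n\n")
```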