Fix/correct streaming resource lock #1879

Merged: 5 commits, Jan 8, 2025
Changes from 3 commits
233 changes: 112 additions & 121 deletions llama_cpp/server/app.py
@@ -7,7 +7,7 @@

from anyio import Lock
from functools import partial
from typing import Iterator, List, Optional, Union, Dict
from typing import List, Optional, Union, Dict

import llama_cpp

@@ -155,34 +155,74 @@ def create_app(
return app


def prepare_request_resources(
body: CreateCompletionRequest | CreateChatCompletionRequest,
llama_proxy: LlamaProxy,
body_model: str,
kwargs,
) -> llama_cpp.Llama:
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Service is not available",
)
llama = llama_proxy(body_model)
if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
return llama
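
The new prepare_request_resources helper above consolidates the per-request setup (model lookup through the proxy, logit-bias mapping, grammar compilation, and the min-tokens logits processor) that was previously duplicated in the completion and chat-completion handlers. A minimal usage sketch, assuming a hypothetical handler that has already parsed body and entered llama_proxy (the names mirror the diff, but the snippet itself is illustrative, not part of the change):

# Sketch only: body, body_model, and llama_proxy are assumed to come from the
# surrounding handler, as in the endpoints further down in this file.
kwargs = body.model_dump(exclude={"n", "logit_bias_type", "user", "min_tokens"})
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
completion = llama(**kwargs)  # or llama.create_chat_completion(**kwargs) for chat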


async def get_event_publisher(
request: Request,
inner_send_chan: MemoryObjectSendStream[typing.Any],
iterator: Iterator[typing.Any],
on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
body: CreateCompletionRequest | CreateChatCompletionRequest,
body_model: str,
llama_call,
kwargs,
):
server_settings = next(get_server_settings())
interrupt_requests = (
server_settings.interrupt_requests if server_settings else False
)
async with inner_send_chan:
try:
async for chunk in iterate_in_threadpool(iterator):
await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
if interrupt_requests and llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
print("disconnected")
with anyio.move_on_after(1, shield=True):
print(f"Disconnected from client (via refresh/close) {request.client}")
raise e
finally:
if on_complete:
await on_complete()
async with contextlib.AsyncExitStack() as exit_stack:
llama_proxy: LlamaProxy = await exit_stack.enter_async_context(
contextlib.asynccontextmanager(get_llama_proxy)()
)
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
async with inner_send_chan:
try:
iterator = await run_in_threadpool(llama_call, llama, **kwargs)
async for chunk in iterate_in_threadpool(iterator):
await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
if interrupt_requests and llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
print("disconnected")
with anyio.move_on_after(1, shield=True):
print(
f"Disconnected from client (via refresh/close) {request.client}"
)
raise e
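
The key change above: the streaming publisher now acquires the LlamaProxy itself, via an AsyncExitStack wrapping the get_llama_proxy dependency, and releases it only when the stream finishes, instead of being handed a pre-built iterator whose resources FastAPI could clean up before the StreamingResponse completes. A small self-contained sketch of that ownership pattern with stand-in names (fake_llama_proxy and publish are illustrative, not the real dependency or publisher):

import contextlib

import anyio


async def fake_llama_proxy():  # stand-in for the async get_llama_proxy dependency
    print("lock acquired")
    try:
        yield "proxy"
    finally:
        print("lock released")


async def publish(chunks):
    # The publisher, not the endpoint, owns the resource for the stream's lifetime.
    async with contextlib.AsyncExitStack() as stack:
        proxy = await stack.enter_async_context(
            contextlib.asynccontextmanager(fake_llama_proxy)()
        )
        for chunk in chunks:
            print(proxy, chunk)  # the real code forwards each chunk over the SSE channel
    # Exiting the stack here releases the lock only after streaming has completed.


anyio.run(publish, ["chunk-1", "chunk-2"])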


def _logit_bias_tokens_to_input_ids(
@@ -267,18 +307,11 @@ async def create_completion(
request: Request,
body: CreateCompletionRequest,
) -> llama_cpp.Completion:
exit_stack = contextlib.AsyncExitStack()
llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Service is not available",
)
if isinstance(body.prompt, list):
assert len(body.prompt) <= 1
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

llama = llama_proxy(
body_model = (
body.model
if request.url.path != "/v1/engines/copilot-codex/completions"
else "copilot-codex"
@@ -293,60 +326,41 @@
}
kwargs = body.model_dump(exclude=exclude)

if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)

try:
iterator_or_completion: Union[
llama_cpp.CreateCompletionResponse,
Iterator[llama_cpp.CreateCompletionStreamResponse],
] = await run_in_threadpool(llama, **kwargs)
except Exception as err:
await exit_stack.aclose()
raise err

if isinstance(iterator_or_completion, Iterator):
# EAFP: It's easier to ask for forgiveness than permission
first_response = await run_in_threadpool(next, iterator_or_completion)

# If no exception was raised from first_response, we can assume that
# the iterator is valid and we can use it to stream the response.
def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
yield first_response
yield from iterator_or_completion

# handle streaming request
if kwargs.get("stream", False):
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher,
request=request,
inner_send_chan=send_chan,
iterator=iterator(),
on_complete=exit_stack.aclose,
body=body,
body_model=body_model,
llama_call=llama_cpp.Llama.__call__,
kwargs=kwargs,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
await exit_stack.aclose()
return iterator_or_completion

# handle regular request
async with contextlib.AsyncExitStack() as exit_stack:
llama_proxy: LlamaProxy = await exit_stack.enter_async_context(
contextlib.asynccontextmanager(get_llama_proxy)()
)
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

if await request.is_disconnected():
print(
f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Client closed request",
)

return await run_in_threadpool(llama, **kwargs)
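
With the streaming branch above, the model and its lock stay bound to the EventSourceResponse until the final [DONE] event, so a slowly consuming client no longer races the dependency cleanup. A rough client-side sketch against a locally running server (host, port, and payload values are assumptions, not part of this diff):

import json
import urllib.request

# Assumed local server address; adjust host, port, and request fields to your deployment.
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps({"prompt": "Hello", "max_tokens": 16, "stream": True}).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw in resp:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue  # skip SSE pings and blank lines
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        print(json.loads(payload)["choices"][0]["text"], end="", flush=True)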


@router.post(
@@ -474,74 +488,51 @@ async def create_chat_completion(
# where the dependency is cleaned up before a StreamingResponse
# is complete.
# https://github.com/tiangolo/fastapi/issues/11143
exit_stack = contextlib.AsyncExitStack()
llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Service is not available",
)

body_model = body.model
exclude = {
"n",
"logit_bias_type",
"user",
"min_tokens",
}
kwargs = body.model_dump(exclude=exclude)
llama = llama_proxy(body.model)
if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)

try:
iterator_or_completion: Union[
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
except Exception as err:
await exit_stack.aclose()
raise err

if isinstance(iterator_or_completion, Iterator):
# EAFP: It's easier to ask for forgiveness than permission
first_response = await run_in_threadpool(next, iterator_or_completion)

# If no exception was raised from first_response, we can assume that
# the iterator is valid and we can use it to stream the response.
def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
yield first_response
yield from iterator_or_completion

# handle streaming request
if kwargs.get("stream", False):
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher,
request=request,
inner_send_chan=send_chan,
iterator=iterator(),
on_complete=exit_stack.aclose,
body=body,
body_model=body_model,
llama_call=llama_cpp.Llama.create_chat_completion,
kwargs=kwargs,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
await exit_stack.aclose()
return iterator_or_completion

# handle regular request
async with contextlib.AsyncExitStack() as exit_stack:
llama_proxy: LlamaProxy = await exit_stack.enter_async_context(
contextlib.asynccontextmanager(get_llama_proxy)()
)
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

if await request.is_disconnected():
print(
f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Client closed request",
)

return await run_in_threadpool(llama.create_chat_completion, **kwargs)


@router.get(