
Commit

fix: streaming resource lock (#1879)
* fix: correct issue with handling the lock during streaming

move the resource locking for streaming into the get_event_publisher call so the lock is acquired and released in the same task that serves the streaming response (a minimal sketch of this pattern appears below, before the diff)

* fix: simplify exit stack management for create_chat_completion and create_completion

* fix: correct missing `async with` and format code

* fix: remove unnecessary explicit use of AsyncExitStack

fix: correct type hints for body_model

---------

Co-authored-by: Andrei <[email protected]>
gjpower and abetlen authored Jan 8, 2025
1 parent 1d5f534 commit e8f14ce
Showing 1 changed file with 103 additions and 121 deletions.
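
For context, here is a minimal, self-contained sketch (plain asyncio, not the project's actual FastAPI/anyio/sse-starlette code) of the pattern the fix adopts: the shared model resource is acquired inside the task that streams the response, so the lock is acquired and released by the same task rather than by the request handler. The names acquire_model, event_publisher, and handle_request are illustrative stand-ins, not identifiers from llama_cpp/server/app.py.

import asyncio
import contextlib
from typing import AsyncIterator, Optional

_model_lock: Optional[asyncio.Lock] = None  # stands in for the single shared model lock


@contextlib.asynccontextmanager
async def acquire_model() -> AsyncIterator[str]:
    # Stand-in for `contextlib.asynccontextmanager(get_llama_proxy)()` in the diff below.
    global _model_lock
    if _model_lock is None:  # create lazily, inside the running event loop
        _model_lock = asyncio.Lock()
    async with _model_lock:
        yield "llama-model"  # placeholder for the loaded model


async def event_publisher(send: "asyncio.Queue[str]") -> None:
    # The fix: enter the resource context *inside* the task that streams the
    # response, so the lock is acquired and released by this same task.
    async with acquire_model() as model:
        for i in range(3):
            await send.put(f"chunk {i} from {model}")
        await send.put("[DONE]")


async def handle_request() -> None:
    # The endpoint only wires up the stream; it never holds the lock itself,
    # so it cannot end up releasing it from a different task.
    queue: "asyncio.Queue[str]" = asyncio.Queue()
    publisher = asyncio.create_task(event_publisher(queue))
    while (chunk := await queue.get()) != "[DONE]":
        print("data:", chunk)
    await publisher


if __name__ == "__main__":
    asyncio.run(handle_request())

In the actual diff below, the same idea appears as the `async with contextlib.asynccontextmanager(get_llama_proxy)()` block inside get_event_publisher, which replaces the explicit AsyncExitStack previously managed by the create_completion and create_chat_completion handlers.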
224 changes: 103 additions & 121 deletions llama_cpp/server/app.py
@@ -7,7 +7,7 @@

from anyio import Lock
from functools import partial
from typing import Iterator, List, Optional, Union, Dict
from typing import List, Optional, Union, Dict

import llama_cpp

@@ -155,34 +155,71 @@ def create_app(
return app


def prepare_request_resources(
body: CreateCompletionRequest | CreateChatCompletionRequest,
llama_proxy: LlamaProxy,
body_model: str | None,
kwargs,
) -> llama_cpp.Llama:
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Service is not available",
)
llama = llama_proxy(body_model)
if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
return llama


async def get_event_publisher(
request: Request,
inner_send_chan: MemoryObjectSendStream[typing.Any],
iterator: Iterator[typing.Any],
on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
body: CreateCompletionRequest | CreateChatCompletionRequest,
body_model: str | None,
llama_call,
kwargs,
):
server_settings = next(get_server_settings())
interrupt_requests = (
server_settings.interrupt_requests if server_settings else False
)
async with inner_send_chan:
try:
async for chunk in iterate_in_threadpool(iterator):
await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
if interrupt_requests and llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
print("disconnected")
with anyio.move_on_after(1, shield=True):
print(f"Disconnected from client (via refresh/close) {request.client}")
raise e
finally:
if on_complete:
await on_complete()
async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
async with inner_send_chan:
try:
iterator = await run_in_threadpool(llama_call, llama, **kwargs)
async for chunk in iterate_in_threadpool(iterator):
await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
if interrupt_requests and llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
print("disconnected")
with anyio.move_on_after(1, shield=True):
print(
f"Disconnected from client (via refresh/close) {request.client}"
)
raise e


def _logit_bias_tokens_to_input_ids(
@@ -267,18 +304,11 @@ async def create_completion(
request: Request,
body: CreateCompletionRequest,
) -> llama_cpp.Completion:
exit_stack = contextlib.AsyncExitStack()
llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Service is not available",
)
if isinstance(body.prompt, list):
assert len(body.prompt) <= 1
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

llama = llama_proxy(
body_model = (
body.model
if request.url.path != "/v1/engines/copilot-codex/completions"
else "copilot-codex"
@@ -293,60 +323,38 @@
}
kwargs = body.model_dump(exclude=exclude)

if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)

try:
iterator_or_completion: Union[
llama_cpp.CreateCompletionResponse,
Iterator[llama_cpp.CreateCompletionStreamResponse],
] = await run_in_threadpool(llama, **kwargs)
except Exception as err:
await exit_stack.aclose()
raise err

if isinstance(iterator_or_completion, Iterator):
# EAFP: It's easier to ask for forgiveness than permission
first_response = await run_in_threadpool(next, iterator_or_completion)

# If no exception was raised from first_response, we can assume that
# the iterator is valid and we can use it to stream the response.
def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
yield first_response
yield from iterator_or_completion

# handle streaming request
if kwargs.get("stream", False):
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher,
request=request,
inner_send_chan=send_chan,
iterator=iterator(),
on_complete=exit_stack.aclose,
body=body,
body_model=body_model,
llama_call=llama_cpp.Llama.__call__,
kwargs=kwargs,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
await exit_stack.aclose()
return iterator_or_completion

# handle regular request
async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

if await request.is_disconnected():
print(
f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Client closed request",
)

return await run_in_threadpool(llama, **kwargs)


@router.post(
@@ -474,74 +482,48 @@ async def create_chat_completion(
# where the dependency is cleaned up before a StreamingResponse
# is complete.
# https://github.com/tiangolo/fastapi/issues/11143
exit_stack = contextlib.AsyncExitStack()
llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Service is not available",
)

body_model = body.model
exclude = {
"n",
"logit_bias_type",
"user",
"min_tokens",
}
kwargs = body.model_dump(exclude=exclude)
llama = llama_proxy(body.model)
if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)

if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)

try:
iterator_or_completion: Union[
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
except Exception as err:
await exit_stack.aclose()
raise err

if isinstance(iterator_or_completion, Iterator):
# EAFP: It's easier to ask for forgiveness than permission
first_response = await run_in_threadpool(next, iterator_or_completion)

# If no exception was raised from first_response, we can assume that
# the iterator is valid and we can use it to stream the response.
def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
yield first_response
yield from iterator_or_completion

# handle streaming request
if kwargs.get("stream", False):
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher,
request=request,
inner_send_chan=send_chan,
iterator=iterator(),
on_complete=exit_stack.aclose,
body=body,
body_model=body_model,
llama_call=llama_cpp.Llama.create_chat_completion,
kwargs=kwargs,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
await exit_stack.aclose()
return iterator_or_completion

# handle regular request
async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

if await request.is_disconnected():
print(
f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Client closed request",
)

return await run_in_threadpool(llama.create_chat_completion, **kwargs)


@router.get(
