feat(agents-api): Make chat route tests pass
Signed-off-by: Diwank Tomer <[email protected]>
Diwank Tomer committed Aug 12, 2024
1 parent 7ab4075 commit 5f91ac3
Showing 34 changed files with 258 additions and 408 deletions.
8 changes: 4 additions & 4 deletions agents-api/agents_api/autogen/Chat.py
@@ -10,7 +10,7 @@
 
 from .Common import LogitBias
 from .Docs import DocReference
-from .Entries import ChatMLMessage, InputChatMLMessage
+from .Entries import InputChatMLMessage
 from .Tools import FunctionTool, NamedToolChoice
 
 
@@ -90,7 +90,7 @@ class ChatOutputChunk(BaseChatOutput):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    delta: ChatMLMessage
+    delta: InputChatMLMessage
     """
     The message generated by the model
     """
@@ -166,7 +166,7 @@ class MultipleChatOutput(BaseChatOutput):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    messages: list[ChatMLMessage]
+    messages: list[InputChatMLMessage]
 
 
 class OpenAISettings(BaseModel):
@@ -199,7 +199,7 @@ class SingleChatOutput(BaseChatOutput):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    message: ChatMLMessage
+    message: InputChatMLMessage
 
 
 class TokenLogProb(BaseTokenLogProb):
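Note: with ChatMLMessage deleted from Entries.py (next file), the generated chat-output models reuse InputChatMLMessage for both request and response messages. A minimal sketch of the resulting shape; the field lists are trimmed for illustration and are not the full generated models:

    from typing import Literal
    from pydantic import BaseModel

    class InputChatMLMessage(BaseModel):  # trimmed; the generated model has more fields
        role: Literal["user", "assistant", "system"]
        content: str

    class SingleChatOutput(BaseModel):
        message: InputChatMLMessage  # was ChatMLMessage before this commit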
41 changes: 2 additions & 39 deletions agents-api/agents_api/autogen/Entries.py
@@ -17,7 +17,7 @@ class BaseEntry(BaseModel):
     )
     role: Literal[
         "user",
-        "agent",
+        "assistant",
         "system",
         "function",
         "function_response",
@@ -67,43 +67,6 @@ class ChatMLImageContentPart(BaseModel):
     """
 
 
-class ChatMLMessage(BaseModel):
-    model_config = ConfigDict(
-        populate_by_name=True,
-    )
-    role: Literal[
-        "user",
-        "agent",
-        "system",
-        "function",
-        "function_response",
-        "function_call",
-        "auto",
-    ]
-    """
-    The role of the message
-    """
-    content: str | list[str] | list[ChatMLTextContentPart | ChatMLImageContentPart]
-    """
-    The content parts of the message
-    """
-    name: str | None = None
-    """
-    Name
-    """
-    tool_calls: Annotated[
-        list[ChosenToolCall], Field([], json_schema_extra={"readOnly": True})
-    ]
-    """
-    Tool calls generated by the model.
-    """
-    created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
-    """
-    When this resource was created as UTC date-time
-    """
-    id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]
-
-
 class ChatMLTextContentPart(BaseModel):
     model_config = ConfigDict(
         populate_by_name=True,
@@ -159,7 +122,7 @@ class InputChatMLMessage(BaseModel):
     )
     role: Literal[
         "user",
-        "agent",
+        "assistant",
         "system",
         "function",
         "function_response",
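The role literal "agent" is replaced with "assistant", matching the ChatML/OpenAI role vocabulary. A minimal sketch of the validation effect, assuming pydantic v2 semantics (Msg is a stand-in for the generated models):

    from typing import Literal
    from pydantic import BaseModel, ValidationError

    class Msg(BaseModel):  # stand-in, not the generated class
        role: Literal["user", "assistant", "system"]

    Msg(role="assistant")      # accepted after this commit
    try:
        Msg(role="agent")      # previously valid, now rejected
    except ValidationError:
        print("role 'agent' is no longer accepted")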
6 changes: 3 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
@@ -54,7 +54,7 @@ def get_active_agent(self) -> Agent:
         """
         Get the active agent from the session data.
         """
-        requested_agent: UUID | None = self.settings.agent
+        requested_agent: UUID | None = self.settings and self.settings.agent
 
         if requested_agent:
             assert requested_agent in [agent.id for agent in self.agents], (
@@ -67,15 +67,15 @@ def get_active_agent(self) -> Agent:
         return self.agents[0]
 
     def merge_settings(self, chat_input: ChatInput) -> ChatSettings:
-        request_settings = ChatSettings.model_validate(chat_input)
+        request_settings = chat_input.model_dump(exclude_unset=True)
         active_agent = self.get_active_agent()
         default_settings = active_agent.default_settings
 
         self.settings = settings = ChatSettings(
             **{
                 "model": active_agent.model,
                 **default_settings.model_dump(),
-                **request_settings.model_dump(exclude_unset=True),
+                **request_settings,
             }
         )
 
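merge_settings now dumps only the fields the client actually set (exclude_unset=True) before spreading them over the agent defaults, so unset request fields no longer clobber defaults. A sketch of the precedence with hypothetical values; later spreads win:

    agent_model = "agent-default-model"                    # hypothetical
    default_settings = {"temperature": 0.7, "top_p": 1.0}  # agent defaults
    request_settings = {"temperature": 0.2}                # only fields the client sent

    merged = {"model": agent_model, **default_settings, **request_settings}
    assert merged == {"model": "agent-default-model", "temperature": 0.2, "top_p": 1.0}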
39 changes: 34 additions & 5 deletions agents-api/agents_api/common/utils/template.py
@@ -40,6 +40,28 @@ async def render_template_string(
     return rendered
 
 
+async def render_template_chatml(
+    messages: list[dict], variables: dict, check: bool = False
+) -> list[dict]:
+    # Parse template
+    # FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow?
+    templates = [jinja_env.from_string(msg["content"]) for msg in messages]
+
+    # If check is required, get required vars from template and validate variables
+    if check:
+        for template in templates:
+            schema = to_json_schema(infer(template))
+            validate(instance=variables, schema=schema)
+
+    # Render
+    rendered = [
+        ({**msg, "content": await template.render_async(**variables)})
+        for template, msg in zip(templates, messages)
+    ]
+
+    return rendered
+
+
 async def render_template_parts(
     template_strings: list[dict], variables: dict, check: bool = False
 ) -> list[dict]:
@@ -73,7 +95,7 @@ async def render_template_parts(
 
 
 async def render_template(
-    template_string: str | list[dict],
+    input: str | list[dict],
     variables: dict,
     check: bool = False,
     skip_vars: list[str] | None = None,
@@ -83,8 +105,15 @@ async def render_template(
         for name, val in variables.items()
        if not (skip_vars is not None and isinstance(name, str) and name in skip_vars)
     }
-    if isinstance(template_string, str):
-        return await render_template_string(template_string, variables, check)
 
-    elif isinstance(template_string, list):
-        return await render_template_parts(template_string, variables, check)
+    match input:
+        case str():
+            future = render_template_string(input, variables, check)
+
+        case [{"content": str()}, *_]:
+            future = render_template_chatml(input, variables, check)
+
+        case _:
+            future = render_template_parts(input, variables, check)
+
+    return await future
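render_template now dispatches structurally: a plain string goes to render_template_string, a list whose first element has a string "content" goes to the new render_template_chatml, and anything else falls through to render_template_parts. Hypothetical usage, assuming the module's jinja_env is configured:

    import asyncio
    from agents_api.common.utils.template import render_template

    async def demo() -> None:
        # str -> render_template_string
        greeting = await render_template("Hi {{ user }}!", {"user": "Ada"})

        # list of {"content": str} dicts -> render_template_chatml
        messages = [{"role": "user", "content": "Summarize {{ topic }}"}]
        rendered = await render_template(messages, {"topic": "hybrid search"})
        print(greeting, rendered)
        # -> Hi Ada! [{'role': 'user', 'content': 'Summarize hybrid search'}]

    asyncio.run(demo())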
3 changes: 3 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_hybrid.py
@@ -18,6 +18,9 @@ def dbsf_normalize(scores: list[float]) -> list[float]:
     Scores scaled using minmax scaler with our custom feature range
     (extremes indicated as 3 standard deviations from the mean)
     """
+    if len(scores) < 2:
+        return scores
+
     sd = stdev(scores)
     if sd == 0:
         return scores
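The new guard matters because statistics.stdev raises StatisticsError on fewer than two data points. A sketch of the whole normalization as the docstring describes it; only the guard and the sd check come from this diff, the 3-sigma min-max body is an assumption:

    from statistics import mean, stdev

    def dbsf_normalize_sketch(scores: list[float]) -> list[float]:
        # stdev() needs at least two points; a single score has nothing to scale
        if len(scores) < 2:
            return scores
        sd = stdev(scores)
        if sd == 0:  # constant scores: leave untouched
            return scores
        m = mean(scores)
        lo, hi = m - 3 * sd, m + 3 * sd  # assumed 3-sigma feature range
        return [min(max((s - lo) / (hi - lo), 0.0), 1.0) for s in scores]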
105 changes: 63 additions & 42 deletions agents-api/agents_api/routers/sessions/chat.py
@@ -7,14 +7,16 @@
 from ...autogen.openapi_model import (
     ChatInput,
     ChatResponse,
+    ChunkChatResponse,
     CreateEntryRequest,
     DocReference,
     History,
+    MessageChatResponse,
 )
-from ...clients.embed import embed
-from ...clients.litellm import acompletion
+from ...clients import embed, litellm
 from ...common.protocol.developers import Developer
 from ...common.protocol.sessions import ChatContext
+from ...common.utils.datetime import utcnow
 from ...common.utils.template import render_template
 from ...dependencies.developer_id import get_developer_data
 from ...models.docs.search_docs_hybrid import search_docs_hybrid
@@ -24,28 +26,14 @@
 from .router import router
 
 
-@router.post(
-    "/sessions/{session_id}/chat",
-    status_code=HTTP_201_CREATED,
-    tags=["sessions", "chat"],
-)
-async def chat(
-    developer: Annotated[Developer, Depends(get_developer_data)],
+async def get_messages(
+    *,
+    developer: Developer,
     session_id: UUID,
-    data: ChatInput,
-    background_tasks: BackgroundTasks,
-) -> ChatResponse:
-    # First get the chat context
-    chat_context: ChatContext = prepare_chat_context(
-        developer_id=developer.id,
-        session_id=session_id,
-    )
-    assert isinstance(chat_context, ChatContext)
-
-    # Merge the settings and prepare environment
-    chat_context.merge_settings(data)
-    settings: dict = chat_context.settings.model_dump()
-    env: dict = chat_context.get_chat_environment()
+    new_raw_messages: list[dict],
+    chat_context: ChatContext,
+):
+    assert len(new_raw_messages) > 0
 
     # Get the session history
     history: History = get_history(
@@ -62,10 +50,8 @@ async def chat(
         if entry.id not in {r.head for r in relations}
     ]
 
-    new_raw_messages = [msg.model_dump() for msg in data.messages]
-
     # Search matching docs
-    [query_embedding, *_] = await embed(
+    [query_embedding, *_] = await embed.embed(
         inputs=[
             f"{msg.get('name') or msg['role']}: {msg['content']}"
             for msg in new_raw_messages
@@ -82,39 +68,74 @@ async def chat(
         query_embedding=query_embedding,
     )
 
+    return past_messages, doc_references
+
+
+@router.post(
+    "/sessions/{session_id}/chat",
+    status_code=HTTP_201_CREATED,
+    tags=["sessions", "chat"],
+)
+async def chat(
+    developer: Annotated[Developer, Depends(get_developer_data)],
+    session_id: UUID,
+    data: ChatInput,
+    background_tasks: BackgroundTasks,
+) -> ChatResponse:
+    # First get the chat context
+    chat_context: ChatContext = prepare_chat_context(
+        developer_id=developer.id,
+        session_id=session_id,
+    )
+
+    # Merge the settings and prepare environment
+    chat_context.merge_settings(data)
+    settings: dict = chat_context.settings.model_dump()
+    env: dict = chat_context.get_chat_environment()
+    new_raw_messages = [msg.model_dump() for msg in data.messages]
+
+    # Render the messages
+    past_messages, doc_references = await get_messages(
+        developer=developer,
+        session_id=session_id,
+        new_raw_messages=new_raw_messages,
+        chat_context=chat_context,
+    )
+
     env["docs"] = doc_references
-    new_messages = render_template(new_raw_messages, variables=env)
+    new_messages = await render_template(new_raw_messages, variables=env)
     messages = past_messages + new_messages
 
     # Get the response from the model
-    model_response = await acompletion(
+    model_response = await litellm.acompletion(
         messages=messages,
         **settings,
         user=str(developer.id),
         tags=developer.tags,
     )
 
     # Save the input and the response to the session history
-    new_entries = [CreateEntryRequest(**msg) for msg in new_messages]
-    background_tasks.add_task(
-        create_entries,
-        developer_id=developer.id,
-        session_id=session_id,
-        data=new_entries,
-        mark_session_as_updated=True,
-    )
+    if data.save:
+        new_entries = [
+            CreateEntryRequest(**msg, source="api_request") for msg in new_messages
+        ]
+        background_tasks.add_task(
+            create_entries,
+            developer_id=developer.id,
+            session_id=session_id,
+            data=new_entries,
+            mark_session_as_updated=True,
+        )
 
     # Return the response
     response_json = model_response.model_dump()
+    response_json.pop("id", None)
 
-    chat_response: ChatResponse = ChatResponse(
-        **response_json,
+    chat_response_class = ChunkChatResponse if data.stream else MessageChatResponse
+    chat_response: ChatResponse = chat_response_class(
         id=uuid4(),
-        created_at=model_response.created,
+        created_at=utcnow(),
         jobs=[],
         docs=doc_references,
         usage=model_response.usage.model_dump(),
         choices=[choice.model_dump() for choice in model_response.choices],
     )
 
     return chat_response
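History and doc search move into get_messages, entries are persisted only when data.save is set, and the response class follows data.stream. A hypothetical client call against this route; the base URL is an assumption and authentication (resolved server-side by the get_developer_data dependency) is omitted:

    import httpx

    session_id = "11111111-1111-4111-8111-111111111111"  # placeholder UUID
    resp = httpx.post(
        f"http://localhost:8080/sessions/{session_id}/chat",  # assumed base URL
        json={
            "messages": [{"role": "user", "content": "Hello!"}],
            "save": True,     # persist request/response entries in the background
            "stream": False,  # False -> MessageChatResponse, True -> ChunkChatResponse
        },
    )
    assert resp.status_code == 201  # HTTP_201_CREATED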
6 changes: 3 additions & 3 deletions agents-api/poetry.lock
