Skip to content

Commit

Permalink
chore: Refactor the evaluator and add tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
whiterabbit1983 committed Feb 26, 2025
1 parent ea7b342 commit 8e299d8
Show file tree
Hide file tree
Showing 3 changed files with 575 additions and 38 deletions.
11 changes: 7 additions & 4 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ...queries.entries.create_entries import create_entries
from ...queries.sessions.count_sessions import count_sessions as count_sessions_query
from ..utils.model_validation import validate_model
from ..utils.tools import eval_tool_calls
from ..utils.tools import tool_calls_evaluator
from .metrics import total_tokens_per_user
from .router import router

Expand Down Expand Up @@ -203,9 +203,12 @@ async def chat(
"tags": developer.tags,
"custom_api_key": x_custom_api_key,
}
model_response = await eval_tool_calls(
litellm.acompletion, {"system"}, developer.id, **{**settings, **params}
)
evaluator = tool_calls_evaluator(tool_types={"system"}, developer_id=developer.id)
acompletion = evaluator(litellm.acompletion)
model_response = await acompletion(**{
**settings,
**params,
})

# Save the input and the response to the session history
if chat_input.save:
Expand Down
78 changes: 44 additions & 34 deletions agents-api/agents_api/routers/utils/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from collections.abc import Awaitable, Callable
from functools import partial
from functools import partial, wraps
from typing import Any
from uuid import UUID

Expand Down Expand Up @@ -118,17 +118,19 @@ async def call_tool(developer_id: UUID, tool_name: str, arguments: dict):

connection_pool = getattr(app.state, "postgres_pool", None)
tool_handler = partial(tool_handler, connection_pool=connection_pool)
arguments["developer_id"] = str(developer_id)
arguments["developer_id"] = developer_id

# Convert all UUIDs to UUID objects
uuid_fields = ["agent_id", "user_id", "task_id", "session_id", "doc_id"]
for field in uuid_fields:
if field in arguments:
arguments[field] = UUID(arguments[field])
fld = arguments[field]
if isinstance(fld, str):
arguments[field] = UUID(fld)

parts = tool_name.split(".")
if len(parts) < MIN_TOOL_NAME_SEGMENTS:
msg = f"wrong syste tool name: {tool_name}"
msg = f"invalid system tool name: {tool_name}"
raise NameError(msg)

resource, subresource, operation = parts[0], None, parts[-1]
Expand Down Expand Up @@ -220,39 +222,47 @@ async def call_tool(developer_id: UUID, tool_name: str, arguments: dict):
return await tool_handler(**arguments)


async def eval_tool_calls(
func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]],
def tool_calls_evaluator(
*,
tool_types: set[str],
developer_id: UUID,
**kwargs,
):
response: ModelResponse | CustomStreamWrapper | None = None
done = False
while not done:
response: ModelResponse | CustomStreamWrapper = await func(**kwargs)
if not response.choices or not response.choices[0].message.tool_calls:
def decor(
func: Callable[..., Awaitable[ModelResponse | CustomStreamWrapper]],
):
@wraps(func)
async def wrapper(**kwargs):
response: ModelResponse | CustomStreamWrapper | None = None
done = False
while not done:
response: ModelResponse | CustomStreamWrapper = await func(**kwargs)
if not response.choices or not response.choices[0].message.tool_calls:
return response

# TODO: add streaming response handling
for tool in response.choices[0].message.tool_calls:
if tool.type not in tool_types:
done = True
continue

done = False
# call a tool
tool_name = tool.function.name
tool_args = json.loads(tool.function.arguments)
tool_response = await call_tool(developer_id, tool_name, tool_args)

# append result to messages from previous step
messages: list = kwargs.get("messages", [])
messages.append({
"tool_call_id": tool.id,
"role": "tool",
"name": tool_name,
"content": tool_response,
})
kwargs["messages"] = messages

return response

# TODO: add streaming response handling
for tool in response.choices[0].message.tool_calls:
if tool.type not in tool_types:
done = True
continue
return wrapper

done = False
# call a tool
tool_name = tool.function.name
tool_args = json.loads(tool.function.arguments)
tool_response = await call_tool(developer_id, tool_name, tool_args)

# append result to messages from previous step
messages: list = kwargs.get("messages", [])
messages.append({
"tool_call_id": tool.id,
"role": "tool",
"name": tool_name,
"content": tool_response,
})
kwargs["messages"] = messages

return response
return decor
Loading

0 comments on commit 8e299d8

Please sign in to comment.