Skip to content

Commit

Permalink
fix(agents-api): Remove anthropic client from chat & other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
HamadaSalhab committed Nov 20, 2024
1 parent 98c037e commit 0fe1cec
Showing 1 changed file with 33 additions and 149 deletions.
182 changes: 33 additions & 149 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
@@ -1,152 +1,28 @@
import json
from datetime import datetime
from typing import Annotated, Callable, Optional
from typing import Annotated, Optional
from uuid import UUID, uuid4

from anthropic import AsyncAnthropic
from anthropic.types.beta.beta_message import BetaMessage
from fastapi import BackgroundTasks, Depends, Header
from litellm import ChatCompletionMessageToolCall, Function, Message
from litellm.types.utils import Choices, ModelResponse
from starlette.status import HTTP_201_CREATED

from ...activities.task_steps.prompt_step import format_tool
from ...autogen.openapi_model import (
ChatInput,
ChatResponse,
ChunkChatResponse,
CreateEntryRequest,
MessageChatResponse,
)
from ...autogen.Tools import Tool
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...common.utils.datetime import utcnow
from ...common.utils.template import render_template
from ...dependencies.developer_id import get_developer_data
from ...env import anthropic_api_key
from ...models.chat.gather_messages import gather_messages
from ...models.chat.prepare_chat_context import prepare_chat_context
from ...models.entry.create_entries import create_entries
from .metrics import total_tokens_per_user
from .router import router

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"


def _to_claude_messages(messages: list[dict]) -> list[dict]:
    """Convert chat messages into the message shape sent to the Anthropic API.

    Only ``assistant`` and ``user`` messages are kept; every other role is
    dropped.

    Args:
        messages: Chat messages as dicts with at least ``role`` and
            ``content`` keys. Assistant messages may carry ``tool_calls``
            whose ``function.arguments`` is a JSON string.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    claude_messages = []
    for msg in messages:
        role = msg["role"]

        # FIXME: return the tool call ids (save assistant message in entries as json dump)
        if role == "assistant":
            content = [
                {
                    # Fall back to a placeholder for any empty content
                    # (``[]``, ``None`` or ``""``), not just ``[]``.
                    "text": msg.get("content") or "Let's do this action",
                    "type": "text",
                },
            ]
            content.extend(
                {
                    "id": f"{tool_call['id']}",
                    "input": json.loads(tool_call["function"]["arguments"]),
                    "name": tool_call["function"]["name"],
                    "type": "tool_use",
                }
                # ``tool_calls`` may be present but ``None``.
                for tool_call in msg.get("tool_calls") or []
            )
        elif role == "user":
            content = msg["content"]
            try:
                parsed = json.loads(content)
            except (TypeError, ValueError):
                # Not a JSON string (plain text, or already structured).
                parsed = None
            # Only a JSON *list* is treated as structured content; plain text
            # that happens to parse as JSON (e.g. "42", "true") stays text.
            if isinstance(parsed, list):
                content = parsed
        else:
            # Skip messages that are not assistant or user
            continue

        claude_messages.append({"role": role, "content": content})

    return claude_messages


async def request_anthropic(
    messages: list[dict], formatted_tools: list[dict], settings: dict
) -> ModelResponse:
    """Call the Anthropic beta messages API directly and wrap the result.

    Args:
        messages: Chat messages; only ``assistant`` and ``user`` roles are
            forwarded.
        formatted_tools: Tool definitions; only those whose ``type`` is one of
            ``computer_20241022``, ``bash_20241022`` or
            ``text_editor_20241022`` are sent.
        settings: Chat settings; only ``max_tokens`` is read (default 1024).

    Returns:
        A ``ModelResponse`` with a single choice. When Claude stops with
        ``tool_use`` the choice carries the tool calls and
        ``finish_reason="tool_calls"``; otherwise it carries the text and
        ``finish_reason="stop"``.

    Raises:
        AssertionError: If Claude did not stop for tool use and returned no
            text block.
    """
    anthropic_tool_types = (
        "computer_20241022",
        "bash_20241022",
        "text_editor_20241022",
    )
    filtered_tools = [
        tool for tool in formatted_tools if tool["type"] in anthropic_tool_types
    ]
    claude_messages = _to_claude_messages(messages)

    # Use Anthropic API directly; the context manager closes the client's
    # HTTP resources once the request is done.
    async with AsyncAnthropic(api_key=anthropic_api_key) as client:
        claude_response: BetaMessage = await client.beta.messages.create(
            model="claude-3-5-sonnet-20241022",
            messages=claude_messages,
            tools=filtered_tools,
            max_tokens=settings.get("max_tokens", 1024),
            betas=[COMPUTER_USE_BETA_FLAG],
        )

    # Convert Claude response to litellm format
    text_block = next(
        (block for block in claude_response.content if block.type == "text"),
        None,
    )

    if claude_response.stop_reason == "tool_use":
        choice = Choices(
            message=Message(
                role="assistant",
                content=text_block.text if text_block else None,
                tool_calls=[
                    ChatCompletionMessageToolCall(
                        # Keep Claude's own id so a later tool result can be
                        # matched to this call.
                        id=block.id,
                        type="function",
                        function=Function(
                            name=block.name,
                            arguments=block.input,
                        ),
                    )
                    for block in claude_response.content
                    if block.type == "tool_use"
                ],
            ),
            finish_reason="tool_calls",
        )
    else:
        # Raised explicitly (not via ``assert``) so the check still runs
        # under ``python -O``; the exception type is unchanged.
        if not text_block:
            raise AssertionError(
                "Claude should always return a text block for stop_reason=stop"
            )
        choice = Choices(
            message=Message(
                role="assistant",
                content=text_block.text,
            ),
            finish_reason="stop",
        )

    usage = claude_response.usage
    return ModelResponse(
        id=claude_response.id,
        choices=[choice],
        created=int(datetime.now().timestamp()),
        model=claude_response.model,
        object="text_completion",
        usage={"total_tokens": usage.input_tokens + usage.output_tokens},
    )


@router.post(
"/sessions/{session_id}/chat",
Expand Down Expand Up @@ -185,7 +61,8 @@ async def chat(

# Merge the settings and prepare environment
chat_context.merge_settings(chat_input)
settings: dict = chat_context.settings.model_dump(mode="json", exclude_none=True)
settings: dict = chat_context.settings.model_dump(
mode="json", exclude_none=True)

# Get the past messages and doc references
past_messages, doc_references = await gather_messages(
Expand Down Expand Up @@ -218,7 +95,8 @@ async def chat(
past_messages = system_messages + past_messages

# Render the incoming messages
new_raw_messages = [msg.model_dump(mode="json") for msg in chat_input.messages]
new_raw_messages = [msg.model_dump(mode="json")
for msg in chat_input.messages]

if chat_context.session.render_templates:
new_messages = await render_template(new_raw_messages, variables=env)
Expand All @@ -230,9 +108,27 @@ async def chat(

# Get the tools
tools = settings.get("tools") or chat_context.get_active_tools()
tools = [tool.model_dump(mode="json") for tool in tools]

# Convert anthropic tools to `function`
for tool in tools:
if tool.get("type") == "computer_20241022":
tool["function"] = {
"name": tool["name"],
"parameters": tool.pop("computer_20241022"),
}

# Format tools for litellm
formatted_tools = [format_tool(tool) for tool in tools]
elif tool.get("type") == "bash_20241022":
tool["function"] = {
"name": tool["name"],
"parameters": tool.pop("bash_20241022"),
}

elif tool.get("type") == "text_editor_20241022":
tool["function"] = {
"name": tool["name"],
"parameters": tool.pop("text_editor_20241022"),
}

# FIXME: Truncate chat messages in the chat context
# SCRUM-7
Expand All @@ -250,28 +146,16 @@ async def chat(
for m in messages
]

# Check if using Claude model and has specific tool types
is_claude_model = settings["model"].lower().startswith("claude-3.5")
has_special_tools = any(
tool["type"] in ["computer_20241022", "bash_20241022", "text_editor_20241022"]
for tool in formatted_tools
# Get the response from the model
model_response = await litellm.acompletion(
messages=messages,
tools=tools or None,
user=str(developer.id), # For tracking usage
tags=developer.tags, # For filtering models in litellm
custom_api_key=x_custom_api_key,
**settings,
)

if is_claude_model and has_special_tools:
model_response = await request_anthropic(messages, formatted_tools, settings)
else:
# FIXME: hardcoded tool to a None value as the tool calls are not implemented yet
formatted_tools = None
# Use litellm for other models
model_response = await litellm.acompletion(
messages=messages,
tools=formatted_tools or None,
user=str(developer.id),
tags=developer.tags,
custom_api_key=x_custom_api_key,
**settings,
)

# Save the input and the response to the session history
if chat_input.save:
new_entries = [
Expand Down

0 comments on commit 0fe1cec

Please sign in to comment.