From 0fe1cecce07046141ff00b10604175ca8f3ccde7 Mon Sep 17 00:00:00 2001 From: HamadaSalhab Date: Wed, 20 Nov 2024 13:22:12 +0300 Subject: [PATCH] fix(agents-api): Remove anthropic client from chat & other fixes --- .../agents_api/routers/sessions/chat.py | 182 ++++-------------- 1 file changed, 33 insertions(+), 149 deletions(-) diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index f6e7807b4..206c972b4 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -1,16 +1,9 @@ -import json -from datetime import datetime -from typing import Annotated, Callable, Optional +from typing import Annotated, Optional from uuid import UUID, uuid4 -from anthropic import AsyncAnthropic -from anthropic.types.beta.beta_message import BetaMessage from fastapi import BackgroundTasks, Depends, Header -from litellm import ChatCompletionMessageToolCall, Function, Message -from litellm.types.utils import Choices, ModelResponse from starlette.status import HTTP_201_CREATED -from ...activities.task_steps.prompt_step import format_tool from ...autogen.openapi_model import ( ChatInput, ChatResponse, @@ -18,135 +11,18 @@ CreateEntryRequest, MessageChatResponse, ) -from ...autogen.Tools import Tool from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext from ...common.utils.datetime import utcnow from ...common.utils.template import render_template from ...dependencies.developer_id import get_developer_data -from ...env import anthropic_api_key from ...models.chat.gather_messages import gather_messages from ...models.chat.prepare_chat_context import prepare_chat_context from ...models.entry.create_entries import create_entries from .metrics import total_tokens_per_user from .router import router -COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" - - -async def request_anthropic( - messages: list[dict], formatted_tools: list[dict], settings: dict -) -> ModelResponse: - # Use Anthropic API directly - client = AsyncAnthropic(api_key=anthropic_api_key) - - # Filter tools for specific types - filtered_tools = [ - tool - for tool in formatted_tools - if tool["type"] - in ["computer_20241022", "bash_20241022", "text_editor_20241022"] - ] - - # Format messages for Claude - claude_messages = [] - for msg in messages: - # Skip messages that are not assistant or user - if msg["role"] not in ["assistant", "user"]: - continue - - # FIXME: return the tool call ids (save assistant message in entries as json dump) - # Transform the message content and tool calls - if msg["role"] == "assistant": - transformed_content = [ - { - "text": "Let's do this action" - if msg["content"] == [] - else msg["content"], - "type": "text", - } - ] - transformed_content.extend( - { - "id": f"{tool_call['id']}", - "input": json.loads(tool_call["function"]["arguments"]), - "name": tool_call["function"]["name"], - "type": "tool_use", - } - for tool_call in msg.get("tool_calls", []) - ) - claude_message = { - "role": msg["role"], - "content": transformed_content, - } - elif msg["role"] == "user": - try: - transformed_content = json.loads(msg["content"]) - except Exception: - transformed_content = msg["content"] - claude_message = {"role": msg["role"], "content": transformed_content} - - claude_messages.append(claude_message) - # Call Claude API - claude_response: BetaMessage = await client.beta.messages.create( - model="claude-3-5-sonnet-20241022", - messages=claude_messages, - tools=filtered_tools, - max_tokens=settings.get("max_tokens", 1024), - betas=[COMPUTER_USE_BETA_FLAG], - ) - # Convert Claude response to litellm format - text_block = next( - (block for block in claude_response.content if block.type == "text"), - None, - ) - - if claude_response.stop_reason == "tool_use": - choice = Choices( - message=Message( - role="assistant", - content=text_block.text if text_block else None, - tool_calls=[ - ChatCompletionMessageToolCall( - type="function", - function=Function( - name=block.name, - arguments=block.input, - ), - ) - for block in claude_response.content - if block.type == "tool_use" - ], - ), - finish_reason="tool_calls", - ) - else: - assert ( - text_block - ), "Claude should always return a text block for stop_reason=stop" - choice = Choices( - message=Message( - role="assistant", - content=text_block.text, - ), - finish_reason="stop", - ) - - model_response = ModelResponse( - id=claude_response.id, - choices=[choice], - created=int(datetime.now().timestamp()), - model=claude_response.model, - object="text_completion", - usage={ - "total_tokens": claude_response.usage.input_tokens - + claude_response.usage.output_tokens - }, - ) - - return model_response - @router.post( "/sessions/{session_id}/chat", @@ -185,7 +61,8 @@ async def chat( # Merge the settings and prepare environment chat_context.merge_settings(chat_input) - settings: dict = chat_context.settings.model_dump(mode="json", exclude_none=True) + settings: dict = chat_context.settings.model_dump( + mode="json", exclude_none=True) # Get the past messages and doc references past_messages, doc_references = await gather_messages( @@ -218,7 +95,8 @@ async def chat( past_messages = system_messages + past_messages # Render the incoming messages - new_raw_messages = [msg.model_dump(mode="json") for msg in chat_input.messages] + new_raw_messages = [msg.model_dump(mode="json") + for msg in chat_input.messages] if chat_context.session.render_templates: new_messages = await render_template(new_raw_messages, variables=env) @@ -230,9 +108,27 @@ async def chat( # Get the tools tools = settings.get("tools") or chat_context.get_active_tools() + tools = [tool.model_dump(mode="json") for tool in tools] + + # Convert anthropic tools to `function` + for tool in tools: + if tool.get("type") == "computer_20241022": + tool["function"] = { + "name": tool["name"], + "parameters": tool.pop("computer_20241022"), + } - # Format tools for litellm - formatted_tools = [format_tool(tool) for tool in tools] + elif tool.get("type") == "bash_20241022": + tool["function"] = { + "name": tool["name"], + "parameters": tool.pop("bash_20241022"), + } + + elif tool.get("type") == "text_editor_20241022": + tool["function"] = { + "name": tool["name"], + "parameters": tool.pop("text_editor_20241022"), + } # FIXME: Truncate chat messages in the chat context # SCRUM-7 @@ -250,28 +146,16 @@ async def chat( for m in messages ] - # Check if using Claude model and has specific tool types - is_claude_model = settings["model"].lower().startswith("claude-3.5") - has_special_tools = any( - tool["type"] in ["computer_20241022", "bash_20241022", "text_editor_20241022"] - for tool in formatted_tools + # Get the response from the model + model_response = await litellm.acompletion( + messages=messages, + tools=tools or None, + user=str(developer.id), # For tracking usage + tags=developer.tags, # For filtering models in litellm + custom_api_key=x_custom_api_key, + **settings, ) - if is_claude_model and has_special_tools: - model_response = await request_anthropic(messages, formatted_tools, settings) - else: - # FIXME: hardcoded tool to a None value as the tool calls are not implemented yet - formatted_tools = None - # Use litellm for other models - model_response = await litellm.acompletion( - messages=messages, - tools=formatted_tools or None, - user=str(developer.id), - tags=developer.tags, - custom_api_key=x_custom_api_key, - **settings, - ) - # Save the input and the response to the session history if chat_input.save: new_entries = [