Skip to content

Commit c290aa2

Browse files
committed
🔨 fix tool error
1 parent 4f630cd commit c290aa2

File tree

1 file changed

+83
-116
lines changed

1 file changed

+83
-116
lines changed

codeinterpreterapi/agents/functions_agent.py

+83-116
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
22
import json
3-
from dataclasses import dataclass
43
from json import JSONDecodeError
54
from typing import Any, List, Optional, Sequence, Tuple, Union
65

76
from langchain.agents import BaseSingleActionAgent
7+
from langchain.agents.agent import AgentOutputParser
8+
from langchain.agents.format_scratchpad.openai_functions import (
9+
format_to_openai_functions,
10+
)
811
from langchain.callbacks.base import BaseCallbackManager
912
from langchain.callbacks.manager import Callbacks # type: ignore
1013
from langchain.chat_models.openai import ChatOpenAI
@@ -18,131 +21,91 @@
1821
from langchain.schema import (
1922
AgentAction,
2023
AgentFinish,
21-
BasePromptTemplate,
22-
OutputParserException,
23-
)
24-
from langchain.schema.language_model import BaseLanguageModel
25-
from langchain.schema.messages import (
2624
AIMessage,
2725
BaseMessage,
28-
FunctionMessage,
26+
BasePromptTemplate,
27+
OutputParserException,
2928
SystemMessage,
3029
)
31-
from langchain.tools import BaseTool
32-
from langchain.tools.convert_to_openai import format_tool_to_openai_function
33-
34-
35-
@dataclass
36-
class _FunctionsAgentAction(AgentAction):
37-
message_log: List[BaseMessage]
30+
from langchain.schema.agent import AgentActionMessageLog
31+
from langchain.schema.language_model import BaseLanguageModel
32+
from langchain.schema.output import ChatGeneration, Generation
33+
from langchain.tools.base import BaseTool
34+
from langchain.tools.render import format_tool_to_openai_function
3835

3936

40-
def _convert_agent_action_to_messages(
41-
agent_action: AgentAction, observation: str
42-
) -> List[BaseMessage]:
43-
"""Convert an agent action to a message.
37+
class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
38+
"""Parses a message into agent action/finish.
4439
45-
This code is used to reconstruct the original AI message from the agent action.
40+
Is meant to be used with OpenAI models, as it relies on the specific
41+
function_call parameter from OpenAI to convey what tools to use.
4642
47-
Args:
48-
agent_action: Agent action to convert.
43+
If a function_call parameter is passed, then that is used to get
44+
the tool and tool input.
4945
50-
Returns:
51-
AIMessage that corresponds to the original tool invocation.
52-
"""
53-
if isinstance(agent_action, _FunctionsAgentAction):
54-
return agent_action.message_log + [
55-
_create_function_message(agent_action, observation)
56-
]
57-
else:
58-
return [AIMessage(content=agent_action.log)]
59-
60-
61-
def _create_function_message(
62-
agent_action: AgentAction, observation: str
63-
) -> FunctionMessage:
64-
"""Convert agent action and observation into a function message.
65-
Args:
66-
agent_action: the tool invocation request from the agent
67-
observation: the result of the tool invocation
68-
Returns:
69-
FunctionMessage that corresponds to the original tool invocation
70-
"""
71-
if not isinstance(observation, str):
72-
try:
73-
content = json.dumps(observation, ensure_ascii=False)
74-
except Exception:
75-
content = str(observation)
76-
else:
77-
content = observation
78-
return FunctionMessage(
79-
name=agent_action.tool,
80-
content=content,
81-
)
82-
83-
84-
def _format_intermediate_steps(
85-
intermediate_steps: List[Tuple[AgentAction, str]],
86-
) -> List[BaseMessage]:
87-
"""Format intermediate steps.
88-
Args:
89-
intermediate_steps: Steps the LLM has taken to date, along with observations
90-
Returns:
91-
list of messages to send to the LLM for the next prediction
46+
If one is not passed, then the AIMessage is assumed to be the final output.
9247
"""
93-
messages = []
9448

95-
for intermediate_step in intermediate_steps:
96-
agent_action, observation = intermediate_step
97-
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
98-
99-
return messages
100-
101-
102-
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
103-
"""Parse an AI message."""
104-
if not isinstance(message, AIMessage):
105-
raise TypeError(f"Expected an AI message got {type(message)}")
106-
107-
function_call = message.additional_kwargs.get("function_call", {})
108-
109-
if function_call:
110-
function_name = function_call["name"]
111-
try:
112-
_tool_input = json.loads(function_call["arguments"])
113-
except JSONDecodeError:
114-
if function_name == "python":
115-
code = function_call["arguments"]
116-
_tool_input = {
117-
"code": code,
118-
}
49+
@property
50+
def _type(self) -> str:
51+
return "openai-functions-agent"
52+
53+
@staticmethod
54+
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
55+
"""Parse an AI message."""
56+
if not isinstance(message, AIMessage):
57+
raise TypeError(f"Expected an AI message got {type(message)}")
58+
59+
function_call = message.additional_kwargs.get("function_call", {})
60+
61+
if function_call:
62+
function_name = function_call["name"]
63+
try:
64+
_tool_input = json.loads(function_call["arguments"])
65+
except JSONDecodeError:
66+
if function_name == "python":
67+
code = function_call["arguments"]
68+
_tool_input = {
69+
"code": code,
70+
}
71+
else:
72+
raise OutputParserException(
73+
f"Could not parse tool input: {function_call} because "
74+
f"the `arguments` is not valid JSON."
75+
)
76+
77+
# HACK HACK HACK:
78+
# The code that encodes tool input into Open AI uses a special variable
79+
# name called `__arg1` to handle old style tools that do not expose a
80+
# schema and expect a single string argument as an input.
81+
# We unpack the argument here if it exists.
82+
# Open AI does not support passing in a JSON array as an argument.
83+
if "__arg1" in _tool_input:
84+
tool_input = _tool_input["__arg1"]
11985
else:
120-
raise OutputParserException(
121-
f"Could not parse tool input: {function_call} because "
122-
f"the `arguments` is not valid JSON."
123-
)
124-
125-
# HACK HACK HACK:
126-
# The code that encodes tool input into Open AI uses a special variable
127-
# name called `__arg1` to handle old style tools that do not expose a
128-
# schema and expect a single string argument as an input.
129-
# We unpack the argument here if it exists.
130-
# Open AI does not support passing in a JSON array as an argument.
131-
if "__arg1" in _tool_input:
132-
tool_input = _tool_input["__arg1"]
133-
else:
134-
tool_input = _tool_input
135-
136-
content_msg = "responded: {content}\n" if message.content else "\n"
86+
tool_input = _tool_input
87+
88+
content_msg = f"responded: {message.content}\n" if message.content else "\n"
89+
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
90+
return AgentActionMessageLog(
91+
tool=function_name,
92+
tool_input=tool_input,
93+
log=log,
94+
message_log=[message],
95+
)
13796

138-
return _FunctionsAgentAction(
139-
tool=function_name,
140-
tool_input=tool_input,
141-
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n",
142-
message_log=[message],
97+
return AgentFinish(
98+
return_values={"output": message.content}, log=message.content
14399
)
144100

145-
return AgentFinish(return_values={"output": message.content}, log=message.content)
101+
def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]:
102+
if not isinstance(result[0], ChatGeneration):
103+
raise ValueError("This output parser only works on ChatGeneration output")
104+
message = result[0].message
105+
return self._parse_ai_message(message)
106+
107+
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
108+
raise ValueError("Can only parse messages")
146109

147110

148111
class OpenAIFunctionsAgent(BaseSingleActionAgent):
@@ -206,7 +169,7 @@ def plan(
206169
Returns:
207170
Action specifying what tool to use.
208171
"""
209-
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
172+
agent_scratchpad = format_to_openai_functions(intermediate_steps)
210173
selected_inputs = {
211174
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
212175
}
@@ -224,7 +187,9 @@ def plan(
224187
messages,
225188
callbacks=callbacks,
226189
)
227-
agent_decision = _parse_ai_message(predicted_message)
190+
agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message(
191+
predicted_message
192+
)
228193
return agent_decision
229194

230195
async def aplan(
@@ -243,7 +208,7 @@ async def aplan(
243208
Returns:
244209
Action specifying what tool to use.
245210
"""
246-
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
211+
agent_scratchpad = format_to_openai_functions(intermediate_steps)
247212
selected_inputs = {
248213
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
249214
}
@@ -253,7 +218,9 @@ async def aplan(
253218
predicted_message = await self.llm.apredict_messages(
254219
messages, functions=self.functions, callbacks=callbacks
255220
)
256-
agent_decision = _parse_ai_message(predicted_message)
221+
agent_decision = OpenAIFunctionsAgentOutputParser._parse_ai_message(
222+
predicted_message
223+
)
257224
return agent_decision
258225

259226
def return_stopped_response(
@@ -339,7 +306,7 @@ def from_llm_and_tools(
339306
extra_prompt_messages=extra_prompt_messages,
340307
system_message=system_message,
341308
)
342-
return cls( # type: ignore
309+
return cls(
343310
llm=llm,
344311
prompt=prompt,
345312
tools=tools,

0 commit comments

Comments
 (0)