1
1
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
2
2
import json
3
- from dataclasses import dataclass
4
3
from json import JSONDecodeError
5
4
from typing import Any , List , Optional , Sequence , Tuple , Union
6
5
7
6
from langchain .agents import BaseSingleActionAgent
7
+ from langchain .agents .agent import AgentOutputParser
8
+ from langchain .agents .format_scratchpad .openai_functions import (
9
+ format_to_openai_functions ,
10
+ )
8
11
from langchain .callbacks .base import BaseCallbackManager
9
12
from langchain .callbacks .manager import Callbacks # type: ignore
10
13
from langchain .chat_models .openai import ChatOpenAI
18
21
from langchain .schema import (
19
22
AgentAction ,
20
23
AgentFinish ,
21
- BasePromptTemplate ,
22
- OutputParserException ,
23
- )
24
- from langchain .schema .language_model import BaseLanguageModel
25
- from langchain .schema .messages import (
26
24
AIMessage ,
27
25
BaseMessage ,
28
- FunctionMessage ,
26
+ BasePromptTemplate ,
27
+ OutputParserException ,
29
28
SystemMessage ,
30
29
)
31
- from langchain .tools import BaseTool
32
- from langchain .tools .convert_to_openai import format_tool_to_openai_function
33
-
34
-
35
- @dataclass
36
- class _FunctionsAgentAction (AgentAction ):
37
- message_log : List [BaseMessage ]
30
+ from langchain .schema .agent import AgentActionMessageLog
31
+ from langchain .schema .language_model import BaseLanguageModel
32
+ from langchain .schema .output import ChatGeneration , Generation
33
+ from langchain .tools .base import BaseTool
34
+ from langchain .tools .render import format_tool_to_openai_function
38
35
39
36
40
- def _convert_agent_action_to_messages (
41
- agent_action : AgentAction , observation : str
42
- ) -> List [BaseMessage ]:
43
- """Convert an agent action to a message.
37
+ class OpenAIFunctionsAgentOutputParser (AgentOutputParser ):
38
+ """Parses a message into agent action/finish.
44
39
45
- This code is used to reconstruct the original AI message from the agent action.
40
+ Is meant to be used with OpenAI models, as it relies on the specific
41
+ function_call parameter from OpenAI to convey what tools to use.
46
42
47
- Args:
48
- agent_action: Agent action to convert .
43
+ If a function_call parameter is passed, then that is used to get
44
+ the tool and tool input .
49
45
50
- Returns:
51
- AIMessage that corresponds to the original tool invocation.
52
- """
53
- if isinstance (agent_action , _FunctionsAgentAction ):
54
- return agent_action .message_log + [
55
- _create_function_message (agent_action , observation )
56
- ]
57
- else :
58
- return [AIMessage (content = agent_action .log )]
59
-
60
-
61
- def _create_function_message (
62
- agent_action : AgentAction , observation : str
63
- ) -> FunctionMessage :
64
- """Convert agent action and observation into a function message.
65
- Args:
66
- agent_action: the tool invocation request from the agent
67
- observation: the result of the tool invocation
68
- Returns:
69
- FunctionMessage that corresponds to the original tool invocation
70
- """
71
- if not isinstance (observation , str ):
72
- try :
73
- content = json .dumps (observation , ensure_ascii = False )
74
- except Exception :
75
- content = str (observation )
76
- else :
77
- content = observation
78
- return FunctionMessage (
79
- name = agent_action .tool ,
80
- content = content ,
81
- )
82
-
83
-
84
- def _format_intermediate_steps (
85
- intermediate_steps : List [Tuple [AgentAction , str ]],
86
- ) -> List [BaseMessage ]:
87
- """Format intermediate steps.
88
- Args:
89
- intermediate_steps: Steps the LLM has taken to date, along with observations
90
- Returns:
91
- list of messages to send to the LLM for the next prediction
46
+ If one is not passed, then the AIMessage is assumed to be the final output.
92
47
"""
93
- messages = []
94
48
95
- for intermediate_step in intermediate_steps :
96
- agent_action , observation = intermediate_step
97
- messages .extend (_convert_agent_action_to_messages (agent_action , observation ))
98
-
99
- return messages
100
-
101
-
102
- def _parse_ai_message (message : BaseMessage ) -> Union [AgentAction , AgentFinish ]:
103
- """Parse an AI message."""
104
- if not isinstance (message , AIMessage ):
105
- raise TypeError (f"Expected an AI message got { type (message )} " )
106
-
107
- function_call = message .additional_kwargs .get ("function_call" , {})
108
-
109
- if function_call :
110
- function_name = function_call ["name" ]
111
- try :
112
- _tool_input = json .loads (function_call ["arguments" ])
113
- except JSONDecodeError :
114
- if function_name == "python" :
115
- code = function_call ["arguments" ]
116
- _tool_input = {
117
- "code" : code ,
118
- }
49
+ @property
50
+ def _type (self ) -> str :
51
+ return "openai-functions-agent"
52
+
53
+ @staticmethod
54
+ def _parse_ai_message (message : BaseMessage ) -> Union [AgentAction , AgentFinish ]:
55
+ """Parse an AI message."""
56
+ if not isinstance (message , AIMessage ):
57
+ raise TypeError (f"Expected an AI message got { type (message )} " )
58
+
59
+ function_call = message .additional_kwargs .get ("function_call" , {})
60
+
61
+ if function_call :
62
+ function_name = function_call ["name" ]
63
+ try :
64
+ _tool_input = json .loads (function_call ["arguments" ])
65
+ except JSONDecodeError :
66
+ if function_name == "python" :
67
+ code = function_call ["arguments" ]
68
+ _tool_input = {
69
+ "code" : code ,
70
+ }
71
+ else :
72
+ raise OutputParserException (
73
+ f"Could not parse tool input: { function_call } because "
74
+ f"the `arguments` is not valid JSON."
75
+ )
76
+
77
+ # HACK HACK HACK:
78
+ # The code that encodes tool input into Open AI uses a special variable
79
+ # name called `__arg1` to handle old style tools that do not expose a
80
+ # schema and expect a single string argument as an input.
81
+ # We unpack the argument here if it exists.
82
+ # Open AI does not support passing in a JSON array as an argument.
83
+ if "__arg1" in _tool_input :
84
+ tool_input = _tool_input ["__arg1" ]
119
85
else :
120
- raise OutputParserException (
121
- f"Could not parse tool input: { function_call } because "
122
- f"the `arguments` is not valid JSON."
123
- )
124
-
125
- # HACK HACK HACK:
126
- # The code that encodes tool input into Open AI uses a special variable
127
- # name called `__arg1` to handle old style tools that do not expose a
128
- # schema and expect a single string argument as an input.
129
- # We unpack the argument here if it exists.
130
- # Open AI does not support passing in a JSON array as an argument.
131
- if "__arg1" in _tool_input :
132
- tool_input = _tool_input ["__arg1" ]
133
- else :
134
- tool_input = _tool_input
135
-
136
- content_msg = "responded: {content}\n " if message .content else "\n "
86
+ tool_input = _tool_input
87
+
88
+ content_msg = f"responded: { message .content } \n " if message .content else "\n "
89
+ log = f"\n Invoking: `{ function_name } ` with `{ tool_input } `\n { content_msg } \n "
90
+ return AgentActionMessageLog (
91
+ tool = function_name ,
92
+ tool_input = tool_input ,
93
+ log = log ,
94
+ message_log = [message ],
95
+ )
137
96
138
- return _FunctionsAgentAction (
139
- tool = function_name ,
140
- tool_input = tool_input ,
141
- log = f"\n Invoking: `{ function_name } ` with `{ tool_input } `\n { content_msg } \n " ,
142
- message_log = [message ],
97
+ return AgentFinish (
98
+ return_values = {"output" : message .content }, log = message .content
143
99
)
144
100
145
- return AgentFinish (return_values = {"output" : message .content }, log = message .content )
101
+ def parse_result (self , result : List [Generation ]) -> Union [AgentAction , AgentFinish ]:
102
+ if not isinstance (result [0 ], ChatGeneration ):
103
+ raise ValueError ("This output parser only works on ChatGeneration output" )
104
+ message = result [0 ].message
105
+ return self ._parse_ai_message (message )
106
+
107
+ def parse (self , text : str ) -> Union [AgentAction , AgentFinish ]:
108
+ raise ValueError ("Can only parse messages" )
146
109
147
110
148
111
class OpenAIFunctionsAgent (BaseSingleActionAgent ):
@@ -206,7 +169,7 @@ def plan(
206
169
Returns:
207
170
Action specifying what tool to use.
208
171
"""
209
- agent_scratchpad = _format_intermediate_steps (intermediate_steps )
172
+ agent_scratchpad = format_to_openai_functions (intermediate_steps )
210
173
selected_inputs = {
211
174
k : kwargs [k ] for k in self .prompt .input_variables if k != "agent_scratchpad"
212
175
}
@@ -224,7 +187,9 @@ def plan(
224
187
messages ,
225
188
callbacks = callbacks ,
226
189
)
227
- agent_decision = _parse_ai_message (predicted_message )
190
+ agent_decision = OpenAIFunctionsAgentOutputParser ._parse_ai_message (
191
+ predicted_message
192
+ )
228
193
return agent_decision
229
194
230
195
async def aplan (
@@ -243,7 +208,7 @@ async def aplan(
243
208
Returns:
244
209
Action specifying what tool to use.
245
210
"""
246
- agent_scratchpad = _format_intermediate_steps (intermediate_steps )
211
+ agent_scratchpad = format_to_openai_functions (intermediate_steps )
247
212
selected_inputs = {
248
213
k : kwargs [k ] for k in self .prompt .input_variables if k != "agent_scratchpad"
249
214
}
@@ -253,7 +218,9 @@ async def aplan(
253
218
predicted_message = await self .llm .apredict_messages (
254
219
messages , functions = self .functions , callbacks = callbacks
255
220
)
256
- agent_decision = _parse_ai_message (predicted_message )
221
+ agent_decision = OpenAIFunctionsAgentOutputParser ._parse_ai_message (
222
+ predicted_message
223
+ )
257
224
return agent_decision
258
225
259
226
def return_stopped_response (
@@ -339,7 +306,7 @@ def from_llm_and_tools(
339
306
extra_prompt_messages = extra_prompt_messages ,
340
307
system_message = system_message ,
341
308
)
342
- return cls ( # type: ignore
309
+ return cls (
343
310
llm = llm ,
344
311
prompt = prompt ,
345
312
tools = tools ,
0 commit comments