From 61ab49b419351996bee9b85fd339c9ab8fd17b9c Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Fri, 21 Feb 2025 15:48:49 +0100 Subject: [PATCH] Preserve chat history when using 'run' method --- autogen/agentchat/conversable_agent.py | 24 +++++++------- test/agentchat/test_conversable_agent.py | 41 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 3465bcff86..a14654abc8 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -242,6 +242,7 @@ def __init__( else (lambda x: content_str(x.get("content")) == "TERMINATE") ) self.silent = silent + self.run_executor: Optional[ConversableAgent] = None # Take a copy to avoid modifying the given dict if isinstance(llm_config, dict): @@ -3301,7 +3302,7 @@ def get_total_usage(self) -> Union[None, dict[str, int]]: return self.client.total_usage_summary @contextmanager - def _create_executor( + def _create_or_get_executor( self, executor_kwargs: Optional[dict[str, Any]] = None, tools: Optional[Union[Tool, Iterable[Tool]]] = None, @@ -3323,19 +3324,20 @@ def _create_executor( if "is_termination_msg" not in executor_kwargs: executor_kwargs["is_termination_msg"] = lambda x: (x["content"] is not None) and "TERMINATE" in x["content"] - executor = ConversableAgent( - name=agent_name, - human_input_mode=agent_human_input_mode, - **executor_kwargs, - ) - try: + if not self.run_executor: + self.run_executor = ConversableAgent( + name=agent_name, + human_input_mode=agent_human_input_mode, + **executor_kwargs, + ) + tools = [] if tools is None else tools tools = [tools] if isinstance(tools, Tool) else tools for tool in tools: - tool.register_for_execution(executor) + tool.register_for_execution(self.run_executor) tool.register_for_llm(self) - yield executor + yield self.run_executor finally: if tools is not None: for tool in tools: @@ -3368,7 +3370,7 @@ def run( clear_history: whether to clear the chat history. user_input: the user will be asked for input at their turn. """ - with self._create_executor( + with self._create_or_get_executor( executor_kwargs=executor_kwargs, tools=tools, agent_name="user", @@ -3418,7 +3420,7 @@ async def a_run( clear_history: whether to clear the chat history. user_input: the user will be asked for input at their turn. """ - with self._create_executor( + with self._create_or_get_executor( executor_kwargs=executor_kwargs, tools=tools, agent_name="user", diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index b8541e985a..90765b77f0 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1788,6 +1788,47 @@ def sample_tool_func(my_prop: str) -> str: assert "tool2" in tool_schemas +def test_create_or_get_executor(mock_credentials: Credentials): + agent = ConversableAgent(name="agent", llm_config=mock_credentials.llm_config) + executor_agent = None + + def log_result(result: str): + return "I have logged the result." + + tool = Tool( + name="log_result", + description="Logs the result of the task.", + func_or_tool=log_result, + ) + + expected_tools = [ + { + "type": "function", + "function": { + "description": "Logs the result of the task.", + "name": "log_result", + "parameters": { + "type": "object", + "properties": {"result": {"type": "string", "description": "result"}}, + "required": ["result"], + }, + }, + } + ] + executor_agent = None + for _ in range(2): + with agent._create_or_get_executor( + tools=[tool], + ) as executor: + if not executor_agent: + executor_agent = executor + else: + assert executor_agent == executor + assert isinstance(executor_agent, ConversableAgent) + assert agent.llm_config["tools"] == expected_tools + assert len(executor_agent.function_map.keys()) == 1 + + if __name__ == "__main__": # test_trigger() # test_context()