diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 1812c56600..75720f9a36 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -163,8 +163,6 @@ def __init__( code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config ) - self._validate_name(name) - self._name = name # a dictionary of conversations, default value is list if chat_messages is None: self._oai_messages = defaultdict(list) @@ -190,6 +188,8 @@ def __init__( ) from e self._validate_llm_config(llm_config) + self._validate_name(name) + self._name = name if logging_enabled(): log_new_agent(self, locals()) @@ -285,6 +285,15 @@ def __init__( } def _validate_name(self, name: str) -> None: + if not self.llm_config or "config_list" not in self.llm_config or len(self.llm_config["config_list"]) == 0: + return + + config_list = self.llm_config.get("config_list") + # The validation is currently done only for openai endpoints + # (other ones do not have the issue with whitespace in the name) + if "api_type" in config_list[0] and config_list[0]["api_type"] != "openai": + return + # Validation for name using regex to detect any whitespace if re.search(r"\s", name): raise ValueError(f"The name of the agent cannot contain any whitespace. The name provided is: '{name}'") diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 1e6fdd1a1f..29c652c20c 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -8,6 +8,7 @@ import inspect import logging +import re import sys import uuid import warnings @@ -289,6 +290,33 @@ def _format_content(content: str) -> str: for choice in choices ] + @staticmethod + def _is_agent_name_error_message(message: str) -> bool: + pattern = re.compile(r"Invalid 'messages\[\d+\]\.name': string does not match pattern.") + return True if pattern.match(message) else False + + @staticmethod + def _handle_openai_bad_request_error(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any): + try: + return func(*args, **kwargs) + except openai.BadRequestError as e: + response_json = e.response.json() + # Check if the error message is related to the agent name. If so, raise a ValueError with a more informative message. + if "error" in response_json and "message" in response_json["error"]: + if OpenAIClient._is_agent_name_error_message(response_json["error"]["message"]): + error_message = ( + f"This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.\n" + "Please ensure that your agent name follows the correct format and doesn't include any unsupported characters.\n" + "Check the agent name and try again.\n" + f"Here is the full BadRequestError from openai:\n{e.message}." + ) + raise ValueError(error_message) + + raise e + + return wrapper + def create(self, params: dict[str, Any]) -> ChatCompletion: """Create a completion for a given config using openai's client. @@ -313,6 +341,8 @@ def _create_or_parse(*args, **kwargs): else: completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined] create_or_parse = completions.create + # Wrap _create_or_parse with exception handling + create_or_parse = OpenAIClient._handle_openai_bad_request_error(create_or_parse) # needs to be updated when the o3 is released to generalize is_o1 = "model" in params and params["model"].startswith("o1") diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 2430df22d8..c343140243 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -40,12 +40,23 @@ def conversable_agent(): @pytest.mark.parametrize("name", ["agent name", "agent_name ", " agent\nname", " agent\tname"]) -def test_conversable_agent_name_with_white_space_raises_error(name: str) -> None: +def test_conversable_agent_name_with_white_space( + name: str, + mock_credentials: Credentials, +) -> None: + agent = ConversableAgent(name=name) + assert agent.name == name + + llm_config = mock_credentials.llm_config with pytest.raises( ValueError, match=f"The name of the agent cannot contain any whitespace. The name provided is: '{name}'", ): - ConversableAgent(name=name) + ConversableAgent(name=name, llm_config=llm_config) + + llm_config["config_list"][0]["api_type"] = "something-else" + agent = ConversableAgent(name=name, llm_config=llm_config) + assert agent.name == name def test_sync_trigger(): @@ -1481,6 +1492,34 @@ def test_handle_carryover(): assert proc_content_empty_carryover == content, "Incorrect carryover processing" +@pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) +def test_conversable_agent_with_whitespaces_in_name_end2end( + credentials_from_test_param: Credentials, + request: pytest.FixtureRequest, +) -> None: + agent = ConversableAgent( + name="first_agent", + llm_config=credentials_from_test_param.llm_config, + ) + + user_proxy = UserProxyAgent( + name="user proxy", + human_input_mode="NEVER", + ) + + # Get the parameter name request node + current_llm = request.node.callspec.id + if "gpt_4" in current_llm: + with pytest.raises( + ValueError, + match="This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.", + ): + user_proxy.initiate_chat(agent, message="Hello, how are you?", max_turns=2) + # anthropic and gemini will not raise an error if agent name contains whitespaces + else: + user_proxy.initiate_chat(agent, message="Hello, how are you?", max_turns=2) + + @pytest.mark.openai def test_context_variables(): # Test initialization with context_variables diff --git a/test/oai/test_client.py b/test/oai/test_client.py index 5f8217c18f..774be3c335 100755 --- a/test/oai/test_client.py +++ b/test/oai/test_client.py @@ -10,6 +10,7 @@ import shutil import time from collections.abc import Generator +from unittest.mock import MagicMock import pytest @@ -290,6 +291,57 @@ def test_cache(credentials_gpt_4o_mini: Credentials): assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED))) +class TestOpenAIClientBadRequestsError: + def test_is_agent_name_error_message(self) -> None: + assert OpenAIClient._is_agent_name_error_message("Invalid 'messages[0].something") is False + for i in range(5): + error_message = f"Invalid 'messages[{i}].name': string does not match pattern. Expected a string that matches the pattern ..." + assert OpenAIClient._is_agent_name_error_message(error_message) is True + + @pytest.mark.parametrize( + "error_message, raise_new_error", + [ + ( + "Invalid 'messages[0].name': string does not match pattern. Expected a string that matches the pattern ...", + True, + ), + ( + "Invalid 'messages[1].name': string does not match pattern. Expected a string that matches the pattern ...", + True, + ), + ( + "Invalid 'messages[0].something': string does not match pattern. Expected a string that matches the pattern ...", + False, + ), + ], + ) + def test_handle_openai_bad_request_error(self, error_message: str, raise_new_error: bool) -> None: + def raise_bad_request_error(error_message: str) -> None: + mock_response = MagicMock() + mock_response.json.return_value = { + "error": { + "message": error_message, + } + } + body = {"error": {"message": "Bad Request error occurred"}} + raise openai.BadRequestError("Bad Request", response=mock_response, body=body) + + # Function raises BadRequestError + with pytest.raises(openai.BadRequestError): + raise_bad_request_error(error_message=error_message) + + wrapped_raise_bad_request_error = OpenAIClient._handle_openai_bad_request_error(raise_bad_request_error) + if raise_new_error: + with pytest.raises( + ValueError, + match="This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.", + ): + wrapped_raise_bad_request_error(error_message=error_message) + else: + with pytest.raises(openai.BadRequestError): + wrapped_raise_bad_request_error(error_message=error_message) + + class TestO1: @pytest.fixture def mock_oai_client(self, mock_credentials: Credentials) -> OpenAIClient: