Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Agents name validation and error handling #583

Merged
merged 11 commits into from
Jan 22, 2025
13 changes: 11 additions & 2 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ def __init__(
code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config
)

self._validate_name(name)
self._name = name
# a dictionary of conversations, default value is list
if chat_messages is None:
self._oai_messages = defaultdict(list)
Expand All @@ -190,6 +188,8 @@ def __init__(
) from e

self._validate_llm_config(llm_config)
self._validate_name(name)
self._name = name

if logging_enabled():
log_new_agent(self, locals())
Expand Down Expand Up @@ -285,6 +285,15 @@ def __init__(
}

def _validate_name(self, name: str) -> None:
if not self.llm_config or "config_list" not in self.llm_config or len(self.llm_config["config_list"]) == 0:
return

config_list = self.llm_config.get("config_list")
# The validation is currently done only for openai endpoints
# (other ones do not have the issue with whitespace in the name)
if "api_type" in config_list[0] and config_list[0]["api_type"] != "openai":
return

# Validation for name using regex to detect any whitespace
if re.search(r"\s", name):
raise ValueError(f"The name of the agent cannot contain any whitespace. The name provided is: '{name}'")
Expand Down
30 changes: 30 additions & 0 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import inspect
import logging
import re
import sys
import uuid
import warnings
Expand Down Expand Up @@ -289,6 +290,33 @@ def _format_content(content: str) -> str:
for choice in choices
]

@staticmethod
def _is_agent_name_error_message(message: str) -> bool:
pattern = re.compile(r"Invalid 'messages\[\d+\]\.name': string does not match pattern.")
return True if pattern.match(message) else False

@staticmethod
def _handle_openai_bad_request_error(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any):
try:
return func(*args, **kwargs)
except openai.BadRequestError as e:
response_json = e.response.json()
# Check if the error message is related to the agent name. If so, raise a ValueError with a more informative message.
if "error" in response_json and "message" in response_json["error"]:
if OpenAIClient._is_agent_name_error_message(response_json["error"]["message"]):
error_message = (
f"This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.\n"
"Please ensure that your agent name follows the correct format and doesn't include any unsupported characters.\n"
"Check the agent name and try again.\n"
f"Here is the full BadRequestError from openai:\n{e.message}."
)
raise ValueError(error_message)

raise e

return wrapper

def create(self, params: dict[str, Any]) -> ChatCompletion:
"""Create a completion for a given config using openai's client.

Expand All @@ -313,6 +341,8 @@ def _create_or_parse(*args, **kwargs):
else:
completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
create_or_parse = completions.create
# Wrap _create_or_parse with exception handling
create_or_parse = OpenAIClient._handle_openai_bad_request_error(create_or_parse)

# needs to be updated when the o3 is released to generalize
is_o1 = "model" in params and params["model"].startswith("o1")
Expand Down
43 changes: 41 additions & 2 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,23 @@ def conversable_agent():


@pytest.mark.parametrize("name", ["agent name", "agent_name ", " agent\nname", " agent\tname"])
def test_conversable_agent_name_with_white_space_raises_error(name: str) -> None:
def test_conversable_agent_name_with_white_space(
name: str,
mock_credentials: Credentials,
) -> None:
agent = ConversableAgent(name=name)
assert agent.name == name

llm_config = mock_credentials.llm_config
with pytest.raises(
ValueError,
match=f"The name of the agent cannot contain any whitespace. The name provided is: '{name}'",
):
ConversableAgent(name=name)
ConversableAgent(name=name, llm_config=llm_config)

llm_config["config_list"][0]["api_type"] = "something-else"
agent = ConversableAgent(name=name, llm_config=llm_config)
assert agent.name == name


def test_sync_trigger():
Expand Down Expand Up @@ -1481,6 +1492,34 @@ def test_handle_carryover():
assert proc_content_empty_carryover == content, "Incorrect carryover processing"


@pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True)
def test_conversable_agent_with_whitespaces_in_name_end2end(
credentials_from_test_param: Credentials,
request: pytest.FixtureRequest,
) -> None:
agent = ConversableAgent(
name="first_agent",
llm_config=credentials_from_test_param.llm_config,
)

user_proxy = UserProxyAgent(
name="user proxy",
human_input_mode="NEVER",
)

# Get the parameter name request node
current_llm = request.node.callspec.id
if "gpt_4" in current_llm:
with pytest.raises(
ValueError,
match="This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.",
):
user_proxy.initiate_chat(agent, message="Hello, how are you?", max_turns=2)
# anthropic and gemini will not raise an error if agent name contains whitespaces
else:
user_proxy.initiate_chat(agent, message="Hello, how are you?", max_turns=2)


@pytest.mark.openai
def test_context_variables():
# Test initialization with context_variables
Expand Down
52 changes: 52 additions & 0 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import shutil
import time
from collections.abc import Generator
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -290,6 +291,57 @@ def test_cache(credentials_gpt_4o_mini: Credentials):
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))


class TestOpenAIClientBadRequestsError:
def test_is_agent_name_error_message(self) -> None:
assert OpenAIClient._is_agent_name_error_message("Invalid 'messages[0].something") is False
for i in range(5):
error_message = f"Invalid 'messages[{i}].name': string does not match pattern. Expected a string that matches the pattern ..."
assert OpenAIClient._is_agent_name_error_message(error_message) is True

@pytest.mark.parametrize(
"error_message, raise_new_error",
[
(
"Invalid 'messages[0].name': string does not match pattern. Expected a string that matches the pattern ...",
True,
),
(
"Invalid 'messages[1].name': string does not match pattern. Expected a string that matches the pattern ...",
True,
),
(
"Invalid 'messages[0].something': string does not match pattern. Expected a string that matches the pattern ...",
False,
),
],
)
def test_handle_openai_bad_request_error(self, error_message: str, raise_new_error: bool) -> None:
def raise_bad_request_error(error_message: str) -> None:
mock_response = MagicMock()
mock_response.json.return_value = {
"error": {
"message": error_message,
}
}
body = {"error": {"message": "Bad Request error occurred"}}
raise openai.BadRequestError("Bad Request", response=mock_response, body=body)

# Function raises BadRequestError
with pytest.raises(openai.BadRequestError):
raise_bad_request_error(error_message=error_message)

wrapped_raise_bad_request_error = OpenAIClient._handle_openai_bad_request_error(raise_bad_request_error)
if raise_new_error:
with pytest.raises(
ValueError,
match="This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.",
):
wrapped_raise_bad_request_error(error_message=error_message)
else:
with pytest.raises(openai.BadRequestError):
wrapped_raise_bad_request_error(error_message=error_message)


class TestO1:
@pytest.fixture
def mock_oai_client(self, mock_credentials: Credentials) -> OpenAIClient:
Expand Down
Loading