Skip to content

Commit

Permalink
Refactor LLM tests to use parameterized fixtures with dynamic marking (#530)
Browse files Browse the repository at this point in the history

* Refactor LLM tests

* Fix tests

* polishing

* polishing

---------

Co-authored-by: Davor Runje <[email protected]>
  • Loading branch information
rjambrecic and davorrunje authored Jan 17, 2025
1 parent e40436e commit af1e565
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 156 deletions.
59 changes: 31 additions & 28 deletions test/agentchat/test_agent_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import json
import sqlite3
import uuid
from typing import Any, Generator, Optional

import pytest
from _pytest.mark import ParameterSet

import autogen
import autogen.runtime_logging

from ..conftest import Credentials
from ..conftest import Credentials, credentials_all_llms

TEACHER_MESSAGE = """
You are roleplaying a math teacher, and your job is to help your students with linear algebra.
Expand Down Expand Up @@ -41,7 +43,7 @@


@pytest.fixture(scope="function")
def db_connection():
def db_connection() -> Generator[Optional[sqlite3.Connection], Any, None]:
autogen.runtime_logging.start(config={"dbname": ":memory:"})
con = autogen.runtime_logging.get_connection()
con.row_factory = sqlite3.Row
Expand All @@ -50,7 +52,9 @@ def db_connection():
autogen.runtime_logging.stop()


def _test_two_agents_logging(credentials: Credentials, db_connection, row_classes=["AzureOpenAI", "OpenAI"]) -> None:
def _test_two_agents_logging(
credentials: Credentials, db_connection: Generator[Optional[sqlite3.Connection], Any, None], row_classes: list[str]
) -> None:
cur = db_connection.cursor()

teacher = autogen.AssistantAgent(
Expand Down Expand Up @@ -169,19 +173,25 @@ def _test_two_agents_logging(credentials: Credentials, db_connection, row_classe
assert row["timestamp"], "timestamp is empty"


@pytest.mark.gemini
def test_two_agents_logging_gemini(credentials_gemini_pro: Credentials, db_connection) -> None:
_test_two_agents_logging(credentials_gemini_pro, db_connection, row_classes=["GeminiClient"])
# Runs the two-agent logging scenario once per entry of `credentials_all_llms`.
# The parameter is passed directly (not `indirect`), so the test resolves the
# credentials fixture itself and can inspect the marker attached to the entry.
@pytest.mark.parametrize("credentials_fixture", credentials_all_llms)
def test_two_agents_logging(
# Each parametrized value is the *name* of a credentials fixture (a string).
credentials_fixture: str,
request: pytest.FixtureRequest,
# The value injected here is what the `db_connection` fixture provides;
# `_test_two_agents_logging` calls `.cursor()` on it.
db_connection: Optional[sqlite3.Connection],
) -> None:
# Look the credentials fixture up by name at run time.
credentials = request.getfixturevalue(credentials_fixture)
# Determine the client classes based on the markers applied to the current test
applied_markers = [mark.name for mark in request.node.iter_markers()]
# The first matching provider marker wins: gemini, then anthropic, then openai.
if "gemini" in applied_markers:
row_classes = ["GeminiClient"]
elif "anthropic" in applied_markers:
row_classes = ["AnthropicClient"]
elif "openai" in applied_markers:
row_classes = ["AzureOpenAI", "OpenAI"]
else:
# No provider marker on this test item: fail loudly instead of guessing.
raise ValueError("Unknown client class")


@pytest.mark.anthropic
def test_two_agents_logging_anthropic(credentials_anthropic_claude_sonnet: Credentials, db_connection) -> None:
_test_two_agents_logging(credentials_anthropic_claude_sonnet, db_connection, row_classes=["AnthropicClient"])


@pytest.mark.openai
def test_two_agents_logging(credentials: Credentials, db_connection):
_test_two_agents_logging(credentials, db_connection)
_test_two_agents_logging(credentials, db_connection, row_classes)


def _test_groupchat_logging(credentials: Credentials, credentials2: Credentials, db_connection):
Expand Down Expand Up @@ -255,16 +265,9 @@ def _test_groupchat_logging(credentials: Credentials, credentials2: Credentials,
assert rows[0]["id"] == 1 and rows[0]["version_number"] == 1


@pytest.mark.gemini
def test_groupchat_logging_gemini(credentials_gemini_pro: Credentials, db_connection):
_test_groupchat_logging(credentials_gemini_pro, credentials_gemini_pro, db_connection)


@pytest.mark.anthropic
def test_groupchat_logging_anthropic(credentials_anthropic_claude_sonnet: Credentials, db_connection):
_test_groupchat_logging(credentials_anthropic_claude_sonnet, credentials_anthropic_claude_sonnet, db_connection)


@pytest.mark.openai
def test_groupchat_logging(credentials_gpt_4o: Credentials, credentials_gpt_4o_mini: Credentials, db_connection):
_test_groupchat_logging(credentials_gpt_4o, credentials_gpt_4o_mini, db_connection)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
def test_groupchat_logging(
    credentials: Credentials,
    db_connection: Optional[sqlite3.Connection],
) -> None:
    """Run the group-chat logging scenario once per entry of ``credentials_all_llms``.

    Args:
        credentials: Credentials parametrized indirectly through the
            ``credentials`` fixture, one provider per test item.
        db_connection: Value provided by the ``db_connection`` fixture. It is
            annotated as the connection itself, not as a generator, because
            pytest injects what a fixture yields rather than the fixture
            function's generator object.
    """
    # The same credentials are passed for both credential slots of the helper.
    _test_groupchat_logging(credentials, credentials, db_connection)
20 changes: 6 additions & 14 deletions test/agentchat/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from autogen.agentchat import AssistantAgent, UserProxyAgent

from ..conftest import Credentials
from ..conftest import Credentials, credentials_all_llms

here = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -55,19 +55,11 @@ def _test_ai_user_proxy_agent(credentials: Credentials) -> None:
print("Result summary:", res.summary)


@pytest.mark.gemini
def test_ai_user_proxy_agent_gemini(credentials_gemini_pro: Credentials) -> None:
_test_ai_user_proxy_agent(credentials_gemini_pro)


@pytest.mark.anthropic
def test_ai_user_proxy_agent_anthropic(credentials_anthropic_claude_sonnet: Credentials) -> None:
_test_ai_user_proxy_agent(credentials_anthropic_claude_sonnet)


@pytest.mark.openai
def test_ai_user_proxy_agent(credentials_gpt_4o_mini: Credentials) -> None:
_test_ai_user_proxy_agent(credentials_gpt_4o_mini)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
def test_ai_user_proxy_agent(credentials: Credentials) -> None:
    """Run the AI user-proxy scenario once per entry of ``credentials_all_llms``.

    ``credentials`` is parametrized indirectly through the ``credentials`` fixture.
    """
    _test_ai_user_proxy_agent(credentials)


@pytest.mark.openai
Expand Down
44 changes: 11 additions & 33 deletions test/agentchat/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import autogen

from ..conftest import Credentials
from ..conftest import Credentials, credentials_all_llms


def get_market_news(ind, ind_upper):
Expand Down Expand Up @@ -88,22 +88,11 @@ async def _test_async_groupchat(credentials: Credentials):
assert len(user_proxy.chat_messages) > 0


@pytest.mark.openai
@pytest.mark.asyncio
async def test_async_groupchat(credentials_gpt_4o_mini: Credentials):
await _test_async_groupchat(credentials_gpt_4o_mini)


@pytest.mark.gemini
@pytest.mark.asyncio
async def test_async_groupchat_gemini(credentials_gemini_pro: Credentials):
await _test_async_groupchat(credentials_gemini_pro)


@pytest.mark.anthropic
@pytest.mark.asyncio
async def test_async_groupchat_anthropic(credentials_anthropic_claude_sonnet: Credentials):
await _test_async_groupchat(credentials_anthropic_claude_sonnet)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
@pytest.mark.asyncio
async def test_async_groupchat(credentials: Credentials) -> None:
    """Run the async group-chat scenario once per entry of ``credentials_all_llms``.

    Args:
        credentials: Credentials parametrized indirectly through the
            ``credentials`` fixture, one provider per test item.
    """
    # `_test_async_groupchat` is a coroutine function. Calling it without
    # `await` only creates a coroutine object that is never run, so the test
    # would pass without executing any of the helper's assertions.
    await _test_async_groupchat(credentials)


async def _test_stream(credentials: Credentials):
Expand Down Expand Up @@ -172,19 +161,8 @@ async def add_data_reply(recipient, messages, sender, config):
# print("Chat summary and cost:", res.summary, res.cost)


@pytest.mark.openai
@pytest.mark.asyncio
async def test_stream(credentials_gpt_4o_mini: Credentials):
await _test_stream(credentials_gpt_4o_mini)


@pytest.mark.gemini
@pytest.mark.asyncio
async def test_stream_gemini(credentials_gemini_pro: Credentials):
await _test_stream(credentials_gemini_pro)


@pytest.mark.anthropic
@pytest.mark.asyncio
async def test_stream_anthropic(credentials_anthropic_claude_sonnet: Credentials):
await _test_stream(credentials_anthropic_claude_sonnet)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
@pytest.mark.asyncio
async def test_stream(credentials: Credentials) -> None:
    """Run the streaming scenario once per entry of ``credentials_all_llms``.

    Args:
        credentials: Credentials parametrized indirectly through the
            ``credentials`` fixture, one provider per test item.
    """
    # `_test_stream` is a coroutine function; it must be awaited, otherwise
    # the coroutine is created and discarded and the test passes vacuously.
    await _test_stream(credentials)
44 changes: 11 additions & 33 deletions test/agentchat/test_async_get_human_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import autogen

from ..conftest import Credentials
from ..conftest import Credentials, credentials_all_llms


async def _test_async_get_human_input(credentials: Credentials) -> None:
Expand Down Expand Up @@ -40,22 +40,11 @@ async def _test_async_get_human_input(credentials: Credentials) -> None:
print("Human input:", res.human_input)


@pytest.mark.openai
@pytest.mark.asyncio
async def test_async_get_human_input(credentials_gpt_4o_mini: Credentials) -> None:
await _test_async_get_human_input(credentials_gpt_4o_mini)


@pytest.mark.gemini
@pytest.mark.asyncio
async def test_async_get_human_input_gemini(credentials_gemini_pro: Credentials) -> None:
await _test_async_get_human_input(credentials_gemini_pro)


@pytest.mark.anthropic
@pytest.mark.asyncio
async def test_async_get_human_input_anthropic(credentials_anthropic_claude_sonnet: Credentials) -> None:
await _test_async_get_human_input(credentials_anthropic_claude_sonnet)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
@pytest.mark.asyncio
async def test_async_get_human_input(credentials: Credentials) -> None:
    """Run the async human-input scenario once per entry of ``credentials_all_llms``.

    Args:
        credentials: Credentials parametrized indirectly through the
            ``credentials`` fixture, one provider per test item.
    """
    # `_test_async_get_human_input` is a coroutine function; without `await`
    # its body never executes and the test would pass without checking anything.
    await _test_async_get_human_input(credentials)


async def _test_async_max_turn(credentials: Credentials):
Expand Down Expand Up @@ -86,19 +75,8 @@ async def _test_async_max_turn(credentials: Credentials):
)


@pytest.mark.openai
@pytest.mark.asyncio
async def test_async_max_turn(credentials_gpt_4o_mini: Credentials):
await _test_async_max_turn(credentials_gpt_4o_mini)


@pytest.mark.gemini
@pytest.mark.asyncio
async def test_async_max_turn_gemini(credentials_gemini_pro: Credentials):
await _test_async_max_turn(credentials_gemini_pro)


@pytest.mark.anthropic
@pytest.mark.asyncio
async def test_async_max_turn_anthropic(credentials_anthropic_claude_sonnet: Credentials):
await _test_async_max_turn(credentials_anthropic_claude_sonnet)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
@pytest.mark.asyncio
async def test_async_max_turn(credentials: Credentials) -> None:
    """Run the async max-turn scenario once per entry of ``credentials_all_llms``.

    Args:
        credentials: Credentials parametrized indirectly through the
            ``credentials`` fixture, one provider per test item.
    """
    # `_test_async_max_turn` is a coroutine function; it must be awaited for
    # the scenario (and its assertions) to actually run.
    await _test_async_max_turn(credentials)
21 changes: 7 additions & 14 deletions test/agentchat/test_chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent, initiate_chats
from autogen.agentchat.chat import _post_process_carryover_item

from ..conftest import Credentials
from ..conftest import Credentials, credentials_all_llms


@pytest.fixture
Expand Down Expand Up @@ -542,19 +542,12 @@ def currency_calculator(
print(res.summary, res.cost, res.chat_history)


@pytest.mark.openai
def test_chats_w_func(credentials_gpt_4o_mini: Credentials, tasks_work_dir: str):
_test_chats_w_func(credentials_gpt_4o_mini, tasks_work_dir)


@pytest.mark.gemini
def test_chats_w_func_gemini(credentials_gemini_pro: Credentials, tasks_work_dir: str):
_test_chats_w_func(credentials_gemini_pro, tasks_work_dir)


@pytest.mark.anthropic
def test_chats_w_func_anthropic(credentials_anthropic_claude_sonnet: Credentials, tasks_work_dir: str):
_test_chats_w_func(credentials_anthropic_claude_sonnet, tasks_work_dir)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
def test_chats_w_func(credentials: Credentials, tasks_work_dir: str) -> None:
    """Run the chats-with-function scenario once per entry of ``credentials_all_llms``.

    ``credentials`` is parametrized indirectly through the ``credentials``
    fixture; ``tasks_work_dir`` is forwarded to the helper unchanged.
    """
    _test_chats_w_func(credentials, tasks_work_dir)


@pytest.mark.openai
Expand Down
23 changes: 6 additions & 17 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from autogen.agentchat.conversable_agent import register_function
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from ..conftest import Credentials
from ..conftest import Credentials, credentials_all_llms

here = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -1044,22 +1044,11 @@ def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]
stopwatch_mock.assert_called_once_with(num_seconds="2")


@pytest.mark.openai
@pytest.mark.asyncio
async def test_function_registration_e2e_async(credentials_gpt_4o: Credentials) -> None:
await _test_function_registration_e2e_async(credentials_gpt_4o)


@pytest.mark.gemini
@pytest.mark.asyncio
async def test_function_registration_e2e_async_gemini(credentials_gemini_pro: Credentials) -> None:
await _test_function_registration_e2e_async(credentials_gemini_pro)


@pytest.mark.anthropic
@pytest.mark.asyncio
async def test_function_registration_e2e_async_anthropic(credentials_anthropic_claude_sonnet: Credentials) -> None:
await _test_function_registration_e2e_async(credentials_anthropic_claude_sonnet)
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
@pytest.mark.asyncio
async def test_function_registration_e2e_async(credentials: Credentials) -> None:
    """Run the async function-registration e2e scenario per entry of ``credentials_all_llms``.

    Args:
        credentials: Credentials parametrized indirectly through the
            ``credentials`` fixture, one provider per test item.
    """
    # NOTE(review): the helper's definition is outside this view; it is
    # treated as a coroutine function because the per-provider wrappers this
    # test replaces awaited it. Without `await` the coroutine is never run
    # and the test would pass without executing the scenario.
    await _test_function_registration_e2e_async(credentials)


@pytest.mark.openai
Expand Down
25 changes: 8 additions & 17 deletions test/agentchat/test_dependancy_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from autogen.agentchat import ConversableAgent, UserProxyAgent
from autogen.tools import BaseContext, ChatContext, Depends

from ..conftest import Credentials
from ..conftest import Credentials, credentials_all_llms


class MyContext(BaseContext, BaseModel):
Expand Down Expand Up @@ -234,20 +234,11 @@ def login(user: Annotated[UserContext, Depends(user)]) -> str:
"Login successful.",
)

@pytest.mark.openai
@pytest.mark.parametrize("credentials", credentials_all_llms, indirect=True)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_end2end(self, credentials_gpt_4o_mini, is_async: bool) -> None:
self._test_end2end(credentials_gpt_4o_mini, is_async)

@pytest.mark.gemini
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_end2end_gemini(self, credentials_gemini_pro, is_async: bool) -> None:
self._test_end2end(credentials_gemini_pro, is_async)

@pytest.mark.anthropic
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_end2end_anthropic(self, credentials_anthropic_claude_sonnet, is_async: bool) -> None:
self._test_end2end(credentials_anthropic_claude_sonnet, is_async)
def test_end2end(self, credentials: Credentials, is_async: bool) -> None:
    """Forward the parametrized credentials and ``is_async`` flag to ``_test_end2end``."""
    self._test_end2end(credentials, is_async)
16 changes: 16 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,19 @@ def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
# https://docs.pytest.org/en/stable/reference/exit-codes.html
if exitstatus == 5:
session.exitstatus = 0


# Parameter sets shared by LLM tests: each entry carries the *name* of a
# credentials fixture together with the provider marker for that test item.
credentials_all_llms = [
    pytest.param(_fixture.__name__, marks=_mark)
    for _fixture, _mark in (
        (credentials_gpt_4o_mini, pytest.mark.openai),
        (credentials_gemini_pro, pytest.mark.gemini),
        (credentials_anthropic_claude_sonnet, pytest.mark.anthropic),
    )
]

0 comments on commit af1e565

Please sign in to comment.