From abfb732123da089ea861292222a9c84be09011ee Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Mon, 20 Jan 2025 17:23:22 +0300 Subject: [PATCH 1/3] fix(agents-api): add model validation for agent and chat endpoints --- agents-api/agents_api/clients/litellm.py | 25 +++++++++++++++++++ .../agents_api/routers/agents/create_agent.py | 7 ++++-- .../routers/agents/create_or_update_agent.py | 7 ++++-- .../agents_api/routers/agents/patch_agent.py | 6 ++++- .../agents_api/routers/agents/update_agent.py | 8 ++++-- .../agents_api/routers/sessions/chat.py | 6 ++++- .../routers/utils/model_validation.py | 18 +++++++++++++ 7 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 agents-api/agents_api/routers/utils/model_validation.py diff --git a/agents-api/agents_api/clients/litellm.py b/agents-api/agents_api/clients/litellm.py index 904811a6c..d37cca6ac 100644 --- a/agents-api/agents_api/clients/litellm.py +++ b/agents-api/agents_api/clients/litellm.py @@ -109,3 +109,28 @@ async def aembedding( for item in embedding_list if len(item["embedding"]) >= dimensions ] + + +@beartype +async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]: + """ + Fetches the list of available models from the LiteLLM server. + + Returns: + list[dict]: A list of model information dictionaries + """ + import aiohttp + + headers = { + 'accept': 'application/json', + 'x-api-key': custom_api_key or litellm_master_key + } + + async with aiohttp.ClientSession() as session: + async with session.get( + url=f"{litellm_url}/models" if not custom_api_key else "/models", + headers=headers + ) as response: + response.raise_for_status() + data = await response.json() + return data["data"] diff --git a/agents-api/agents_api/routers/agents/create_agent.py b/agents-api/agents_api/routers/agents/create_agent.py index f630d5251..15588052e 100644 --- a/agents-api/agents_api/routers/agents/create_agent.py +++ b/agents-api/agents_api/routers/agents/create_agent.py @@ -11,14 +11,17 @@ from ...dependencies.developer_id import get_developer_id from ...queries.agents.create_agent import create_agent as create_agent_query from .router import router - +from ..utils.model_validation import validate_model @router.post("/agents", status_code=HTTP_201_CREATED, tags=["agents"]) async def create_agent( x_developer_id: Annotated[UUID, Depends(get_developer_id)], data: CreateAgentRequest, ) -> ResourceCreatedResponse: - # TODO: Validate model name + + if data.model: + await validate_model(data.model) + agent = await create_agent_query( developer_id=x_developer_id, data=data, diff --git a/agents-api/agents_api/routers/agents/create_or_update_agent.py b/agents-api/agents_api/routers/agents/create_or_update_agent.py index fd2fc124c..9817e63ae 100644 --- a/agents-api/agents_api/routers/agents/create_or_update_agent.py +++ b/agents-api/agents_api/routers/agents/create_or_update_agent.py @@ -13,7 +13,7 @@ create_or_update_agent as create_or_update_agent_query, ) from .router import router - +from ..utils.model_validation import validate_model @router.post("/agents/{agent_id}", status_code=HTTP_201_CREATED, tags=["agents"]) async def create_or_update_agent( @@ -21,7 +21,10 @@ async def create_or_update_agent( data: CreateOrUpdateAgentRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceCreatedResponse: - # TODO: Validate model name + + if data.model: + await validate_model(data.model) + agent = await create_or_update_agent_query( developer_id=x_developer_id, agent_id=agent_id, diff --git a/agents-api/agents_api/routers/agents/patch_agent.py b/agents-api/agents_api/routers/agents/patch_agent.py index bb7c16d5c..94aad35fd 100644 --- a/agents-api/agents_api/routers/agents/patch_agent.py +++ b/agents-api/agents_api/routers/agents/patch_agent.py @@ -8,7 +8,7 @@ from ...dependencies.developer_id import get_developer_id from ...queries.agents.patch_agent import patch_agent as patch_agent_query from .router import router - +from ..utils.model_validation import validate_model @router.patch( "/agents/{agent_id}", @@ -21,6 +21,10 @@ async def patch_agent( agent_id: UUID, data: PatchAgentRequest, ) -> ResourceUpdatedResponse: + + if data.model: + await validate_model(data.model) + return await patch_agent_query( agent_id=agent_id, developer_id=x_developer_id, diff --git a/agents-api/agents_api/routers/agents/update_agent.py b/agents-api/agents_api/routers/agents/update_agent.py index 608da0b20..34d431927 100644 --- a/agents-api/agents_api/routers/agents/update_agent.py +++ b/agents-api/agents_api/routers/agents/update_agent.py @@ -8,7 +8,7 @@ from ...dependencies.developer_id import get_developer_id from ...queries.agents.update_agent import update_agent as update_agent_query from .router import router - +from ..utils.model_validation import validate_model @router.put( "/agents/{agent_id}", @@ -20,7 +20,11 @@ async def update_agent( x_developer_id: Annotated[UUID, Depends(get_developer_id)], agent_id: UUID, data: UpdateAgentRequest, -) -> ResourceUpdatedResponse: + ) -> ResourceUpdatedResponse: + + if data.model: + await validate_model(data.model) + return await update_agent_query( developer_id=x_developer_id, agent_id=agent_id, diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index 3a3ce5e32..74b22a91c 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -25,7 +25,7 @@ from ...queries.sessions.count_sessions import count_sessions as count_sessions_query from .metrics import total_tokens_per_user from .router import router - +from ..utils.model_validation import validate_model COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" @@ -55,6 +55,10 @@ async def chat( Returns: ChatResponse: The chat response. """ + + if chat_input.model: + await validate_model(chat_input.model) + # check if the developer is paid if "paid" not in developer.tags: # get the session length diff --git a/agents-api/agents_api/routers/utils/model_validation.py b/agents-api/agents_api/routers/utils/model_validation.py new file mode 100644 index 000000000..31e93926d --- /dev/null +++ b/agents-api/agents_api/routers/utils/model_validation.py @@ -0,0 +1,18 @@ +from fastapi import HTTPException +from starlette.status import HTTP_400_BAD_REQUEST + +from ...clients.litellm import get_model_list + +async def validate_model(model_name: str) -> None: + """ + Validates if a given model name is available in LiteLLM. + Raises HTTPException if model is not available. + """ + models = await get_model_list() + available_models = [model["id"] for model in models] + + if model_name not in available_models: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=f"Model {model_name} not available. Available models: {available_models}" + ) \ No newline at end of file From 122dcfffaf8f0c53e7b8bd599b437c7cf9f42a6a Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Mon, 20 Jan 2025 18:17:49 +0300 Subject: [PATCH 2/3] chore(agents-api): add model validation tests --- agents-api/tests/fixtures.py | 9 ++++++++- agents-api/tests/test_model_validation.py | 24 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 agents-api/tests/test_model_validation.py diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 5b0ff68cc..e98ddf334 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -2,6 +2,7 @@ import random import string import sys +from unittest.mock import patch from uuid import UUID from agents_api.autogen.openapi_model import ( @@ -439,11 +440,17 @@ async def test_tool( ) return tool +SAMPLE_MODELS = [ + {"id": "gpt-4"}, + {"id": "gpt-3.5-turbo"}, + {"id": "gpt-4o-mini"}, +] @fixture(scope="global") def client(_dsn=pg_dsn): with TestClient(app=app) as client: - yield client + with patch("agents_api.routers.utils.model_validation.get_model_list", return_value=SAMPLE_MODELS): + yield client @fixture(scope="global") diff --git a/agents-api/tests/test_model_validation.py b/agents-api/tests/test_model_validation.py new file mode 100644 index 000000000..4690b9d28 --- /dev/null +++ b/agents-api/tests/test_model_validation.py @@ -0,0 +1,24 @@ +from unittest.mock import patch +from agents_api.routers.utils.model_validation import validate_model +from ward import test, raises +from fastapi import HTTPException +from tests.fixtures import SAMPLE_MODELS + +@test("validate_model: succeeds when model is available") +async def _(): + # Use async context manager for patching + with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models: + mock_get_models.return_value = SAMPLE_MODELS + await validate_model("gpt-4o-mini") + mock_get_models.assert_called_once() + +@test("validate_model: fails when model is unavailable") +async def _(): + with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models: + mock_get_models.return_value = SAMPLE_MODELS + with raises(HTTPException) as exc: + await validate_model("non-existent-model") + + assert exc.raised.status_code == 400 + assert "Model non-existent-model not available" in exc.raised.detail + mock_get_models.assert_called_once() \ No newline at end of file From 828ac957b71ecc2f7ae7f2d09d2dbf3b8bc7e308 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Mon, 20 Jan 2025 18:19:45 +0300 Subject: [PATCH 3/3] lint --- agents-api/agents_api/clients/litellm.py | 27 +++++++++---------- .../agents_api/routers/agents/create_agent.py | 5 ++-- .../routers/agents/create_or_update_agent.py | 5 ++-- .../agents_api/routers/agents/patch_agent.py | 5 ++-- .../agents_api/routers/agents/update_agent.py | 5 ++-- .../agents_api/routers/sessions/chat.py | 3 ++- .../routers/utils/model_validation.py | 5 ++-- agents-api/tests/fixtures.py | 2 ++ agents-api/tests/test_model_validation.py | 10 ++++--- 9 files changed, 39 insertions(+), 28 deletions(-) diff --git a/agents-api/agents_api/clients/litellm.py b/agents-api/agents_api/clients/litellm.py index d37cca6ac..5eb0e556d 100644 --- a/agents-api/agents_api/clients/litellm.py +++ b/agents-api/agents_api/clients/litellm.py @@ -1,6 +1,7 @@ from functools import wraps from typing import Literal +import aiohttp from beartype import beartype from litellm import acompletion as _acompletion from litellm import aembedding as _aembedding @@ -115,22 +116,20 @@ async def aembedding( async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]: """ Fetches the list of available models from the LiteLLM server. - + Returns: list[dict]: A list of model information dictionaries """ - import aiohttp - + headers = { - 'accept': 'application/json', - 'x-api-key': custom_api_key or litellm_master_key + "accept": "application/json", + "x-api-key": custom_api_key or litellm_master_key } - - async with aiohttp.ClientSession() as session: - async with session.get( - url=f"{litellm_url}/models" if not custom_api_key else "/models", - headers=headers - ) as response: - response.raise_for_status() - data = await response.json() - return data["data"] + + async with aiohttp.ClientSession() as session, session.get( + url=f"{litellm_url}/models" if not custom_api_key else "/models", + headers=headers + ) as response: + response.raise_for_status() + data = await response.json() + return data["data"] diff --git a/agents-api/agents_api/routers/agents/create_agent.py b/agents-api/agents_api/routers/agents/create_agent.py index 15588052e..99476cd87 100644 --- a/agents-api/agents_api/routers/agents/create_agent.py +++ b/agents-api/agents_api/routers/agents/create_agent.py @@ -10,15 +10,16 @@ ) from ...dependencies.developer_id import get_developer_id from ...queries.agents.create_agent import create_agent as create_agent_query -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.post("/agents", status_code=HTTP_201_CREATED, tags=["agents"]) async def create_agent( x_developer_id: Annotated[UUID, Depends(get_developer_id)], data: CreateAgentRequest, ) -> ResourceCreatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/agents/create_or_update_agent.py b/agents-api/agents_api/routers/agents/create_or_update_agent.py index 9817e63ae..b81bc72ca 100644 --- a/agents-api/agents_api/routers/agents/create_or_update_agent.py +++ b/agents-api/agents_api/routers/agents/create_or_update_agent.py @@ -12,8 +12,9 @@ from ...queries.agents.create_or_update_agent import ( create_or_update_agent as create_or_update_agent_query, ) -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.post("/agents/{agent_id}", status_code=HTTP_201_CREATED, tags=["agents"]) async def create_or_update_agent( @@ -21,7 +22,7 @@ async def create_or_update_agent( data: CreateOrUpdateAgentRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceCreatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/agents/patch_agent.py b/agents-api/agents_api/routers/agents/patch_agent.py index 94aad35fd..ff30811fb 100644 --- a/agents-api/agents_api/routers/agents/patch_agent.py +++ b/agents-api/agents_api/routers/agents/patch_agent.py @@ -7,8 +7,9 @@ from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...dependencies.developer_id import get_developer_id from ...queries.agents.patch_agent import patch_agent as patch_agent_query -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.patch( "/agents/{agent_id}", @@ -21,7 +22,7 @@ async def patch_agent( agent_id: UUID, data: PatchAgentRequest, ) -> ResourceUpdatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/agents/update_agent.py b/agents-api/agents_api/routers/agents/update_agent.py index 34d431927..a2732a51a 100644 --- a/agents-api/agents_api/routers/agents/update_agent.py +++ b/agents-api/agents_api/routers/agents/update_agent.py @@ -7,8 +7,9 @@ from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...dependencies.developer_id import get_developer_id from ...queries.agents.update_agent import update_agent as update_agent_query -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.put( "/agents/{agent_id}", @@ -21,7 +22,7 @@ async def update_agent( agent_id: UUID, data: UpdateAgentRequest, ) -> ResourceUpdatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index 74b22a91c..2a58653a5 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -23,9 +23,10 @@ from ...queries.chat.prepare_chat_context import prepare_chat_context from ...queries.entries.create_entries import create_entries from ...queries.sessions.count_sessions import count_sessions as count_sessions_query +from ..utils.model_validation import validate_model from .metrics import total_tokens_per_user from .router import router -from ..utils.model_validation import validate_model + COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" diff --git a/agents-api/agents_api/routers/utils/model_validation.py b/agents-api/agents_api/routers/utils/model_validation.py index 31e93926d..6f99f49f4 100644 --- a/agents-api/agents_api/routers/utils/model_validation.py +++ b/agents-api/agents_api/routers/utils/model_validation.py @@ -3,6 +3,7 @@ from ...clients.litellm import get_model_list + async def validate_model(model_name: str) -> None: """ Validates if a given model name is available in LiteLLM. @@ -10,9 +11,9 @@ async def validate_model(model_name: str) -> None: """ models = await get_model_list() available_models = [model["id"] for model in models] - + if model_name not in available_models: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, detail=f"Model {model_name} not available. Available models: {available_models}" - ) \ No newline at end of file + ) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index e98ddf334..e052df859 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -440,12 +440,14 @@ async def test_tool( ) return tool + SAMPLE_MODELS = [ {"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}, {"id": "gpt-4o-mini"}, ] + @fixture(scope="global") def client(_dsn=pg_dsn): with TestClient(app=app) as client: diff --git a/agents-api/tests/test_model_validation.py b/agents-api/tests/test_model_validation.py index 4690b9d28..40dfce359 100644 --- a/agents-api/tests/test_model_validation.py +++ b/agents-api/tests/test_model_validation.py @@ -1,9 +1,12 @@ from unittest.mock import patch + from agents_api.routers.utils.model_validation import validate_model -from ward import test, raises from fastapi import HTTPException +from ward import raises, test + from tests.fixtures import SAMPLE_MODELS + @test("validate_model: succeeds when model is available") async def _(): # Use async context manager for patching @@ -12,13 +15,14 @@ async def _(): await validate_model("gpt-4o-mini") mock_get_models.assert_called_once() + @test("validate_model: fails when model is unavailable") async def _(): with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models: mock_get_models.return_value = SAMPLE_MODELS with raises(HTTPException) as exc: await validate_model("non-existent-model") - + assert exc.raised.status_code == 400 assert "Model non-existent-model not available" in exc.raised.detail - mock_get_models.assert_called_once() \ No newline at end of file + mock_get_models.assert_called_once()