From 828ac957b71ecc2f7ae7f2d09d2dbf3b8bc7e308 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Mon, 20 Jan 2025 18:19:45 +0300 Subject: [PATCH] lint --- agents-api/agents_api/clients/litellm.py | 27 +++++++++---------- .../agents_api/routers/agents/create_agent.py | 5 ++-- .../routers/agents/create_or_update_agent.py | 5 ++-- .../agents_api/routers/agents/patch_agent.py | 5 ++-- .../agents_api/routers/agents/update_agent.py | 5 ++-- .../agents_api/routers/sessions/chat.py | 3 ++- .../routers/utils/model_validation.py | 5 ++-- agents-api/tests/fixtures.py | 2 ++ agents-api/tests/test_model_validation.py | 10 ++++--- 9 files changed, 39 insertions(+), 28 deletions(-) diff --git a/agents-api/agents_api/clients/litellm.py b/agents-api/agents_api/clients/litellm.py index d37cca6ac..5eb0e556d 100644 --- a/agents-api/agents_api/clients/litellm.py +++ b/agents-api/agents_api/clients/litellm.py @@ -1,6 +1,7 @@ from functools import wraps from typing import Literal +import aiohttp from beartype import beartype from litellm import acompletion as _acompletion from litellm import aembedding as _aembedding @@ -115,22 +116,20 @@ async def aembedding( async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]: """ Fetches the list of available models from the LiteLLM server. - + Returns: list[dict]: A list of model information dictionaries """ - import aiohttp - + headers = { - 'accept': 'application/json', - 'x-api-key': custom_api_key or litellm_master_key + "accept": "application/json", + "x-api-key": custom_api_key or litellm_master_key } - - async with aiohttp.ClientSession() as session: - async with session.get( - url=f"{litellm_url}/models" if not custom_api_key else "/models", - headers=headers - ) as response: - response.raise_for_status() - data = await response.json() - return data["data"] + + async with aiohttp.ClientSession() as session, session.get( + url=f"{litellm_url}/models" if not custom_api_key else "/models", + headers=headers + ) as response: + response.raise_for_status() + data = await response.json() + return data["data"] diff --git a/agents-api/agents_api/routers/agents/create_agent.py b/agents-api/agents_api/routers/agents/create_agent.py index 15588052e..99476cd87 100644 --- a/agents-api/agents_api/routers/agents/create_agent.py +++ b/agents-api/agents_api/routers/agents/create_agent.py @@ -10,15 +10,16 @@ ) from ...dependencies.developer_id import get_developer_id from ...queries.agents.create_agent import create_agent as create_agent_query -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.post("/agents", status_code=HTTP_201_CREATED, tags=["agents"]) async def create_agent( x_developer_id: Annotated[UUID, Depends(get_developer_id)], data: CreateAgentRequest, ) -> ResourceCreatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/agents/create_or_update_agent.py b/agents-api/agents_api/routers/agents/create_or_update_agent.py index 9817e63ae..b81bc72ca 100644 --- a/agents-api/agents_api/routers/agents/create_or_update_agent.py +++ b/agents-api/agents_api/routers/agents/create_or_update_agent.py @@ -12,8 +12,9 @@ from ...queries.agents.create_or_update_agent import ( create_or_update_agent as create_or_update_agent_query, ) -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.post("/agents/{agent_id}", status_code=HTTP_201_CREATED, tags=["agents"]) async def create_or_update_agent( @@ -21,7 +22,7 @@ async def create_or_update_agent( data: CreateOrUpdateAgentRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceCreatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/agents/patch_agent.py b/agents-api/agents_api/routers/agents/patch_agent.py index 94aad35fd..ff30811fb 100644 --- a/agents-api/agents_api/routers/agents/patch_agent.py +++ b/agents-api/agents_api/routers/agents/patch_agent.py @@ -7,8 +7,9 @@ from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...dependencies.developer_id import get_developer_id from ...queries.agents.patch_agent import patch_agent as patch_agent_query -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.patch( "/agents/{agent_id}", @@ -21,7 +22,7 @@ async def patch_agent( agent_id: UUID, data: PatchAgentRequest, ) -> ResourceUpdatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/agents/update_agent.py b/agents-api/agents_api/routers/agents/update_agent.py index 34d431927..a2732a51a 100644 --- a/agents-api/agents_api/routers/agents/update_agent.py +++ b/agents-api/agents_api/routers/agents/update_agent.py @@ -7,8 +7,9 @@ from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...dependencies.developer_id import get_developer_id from ...queries.agents.update_agent import update_agent as update_agent_query -from .router import router from ..utils.model_validation import validate_model +from .router import router + @router.put( "/agents/{agent_id}", @@ -21,7 +22,7 @@ async def update_agent( agent_id: UUID, data: UpdateAgentRequest, ) -> ResourceUpdatedResponse: - + if data.model: await validate_model(data.model) diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index 74b22a91c..2a58653a5 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -23,9 +23,10 @@ from ...queries.chat.prepare_chat_context import prepare_chat_context from ...queries.entries.create_entries import create_entries from ...queries.sessions.count_sessions import count_sessions as count_sessions_query +from ..utils.model_validation import validate_model from .metrics import total_tokens_per_user from .router import router -from ..utils.model_validation import validate_model + COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" diff --git a/agents-api/agents_api/routers/utils/model_validation.py b/agents-api/agents_api/routers/utils/model_validation.py index 31e93926d..6f99f49f4 100644 --- a/agents-api/agents_api/routers/utils/model_validation.py +++ b/agents-api/agents_api/routers/utils/model_validation.py @@ -3,6 +3,7 @@ from ...clients.litellm import get_model_list + async def validate_model(model_name: str) -> None: """ Validates if a given model name is available in LiteLLM. @@ -10,9 +11,9 @@ async def validate_model(model_name: str) -> None: """ models = await get_model_list() available_models = [model["id"] for model in models] - + if model_name not in available_models: raise HTTPException( status_code=HTTP_400_BAD_REQUEST, detail=f"Model {model_name} not available. Available models: {available_models}" - ) \ No newline at end of file + ) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index e98ddf334..e052df859 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -440,12 +440,14 @@ async def test_tool( ) return tool + SAMPLE_MODELS = [ {"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}, {"id": "gpt-4o-mini"}, ] + @fixture(scope="global") def client(_dsn=pg_dsn): with TestClient(app=app) as client: diff --git a/agents-api/tests/test_model_validation.py b/agents-api/tests/test_model_validation.py index 4690b9d28..40dfce359 100644 --- a/agents-api/tests/test_model_validation.py +++ b/agents-api/tests/test_model_validation.py @@ -1,9 +1,12 @@ from unittest.mock import patch + from agents_api.routers.utils.model_validation import validate_model -from ward import test, raises from fastapi import HTTPException +from ward import raises, test + from tests.fixtures import SAMPLE_MODELS + @test("validate_model: succeeds when model is available") async def _(): # Use async context manager for patching @@ -12,13 +15,14 @@ async def _(): await validate_model("gpt-4o-mini") mock_get_models.assert_called_once() + @test("validate_model: fails when model is unavailable") async def _(): with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models: mock_get_models.return_value = SAMPLE_MODELS with raises(HTTPException) as exc: await validate_model("non-existent-model") - + assert exc.raised.status_code == 400 assert "Model non-existent-model not available" in exc.raised.detail - mock_get_models.assert_called_once() \ No newline at end of file + mock_get_models.assert_called_once()