Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad-mtos committed Jan 20, 2025
1 parent 122dcff commit 828ac95
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 28 deletions.
27 changes: 13 additions & 14 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from typing import Literal

import aiohttp
from beartype import beartype
from litellm import acompletion as _acompletion
from litellm import aembedding as _aembedding
Expand Down Expand Up @@ -115,22 +116,20 @@ async def aembedding(
async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]:
"""
Fetches the list of available models from the LiteLLM server.
Returns:
list[dict]: A list of model information dictionaries
"""
import aiohttp


headers = {
'accept': 'application/json',
'x-api-key': custom_api_key or litellm_master_key
"accept": "application/json",
"x-api-key": custom_api_key or litellm_master_key
}

async with aiohttp.ClientSession() as session:
async with session.get(
url=f"{litellm_url}/models" if not custom_api_key else "/models",
headers=headers
) as response:
response.raise_for_status()
data = await response.json()
return data["data"]

async with aiohttp.ClientSession() as session, session.get(
url=f"{litellm_url}/models" if not custom_api_key else "/models",
headers=headers
) as response:
response.raise_for_status()
data = await response.json()
return data["data"]
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
)
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.create_agent import create_agent as create_agent_query
from .router import router
from ..utils.model_validation import validate_model
from .router import router


@router.post("/agents", status_code=HTTP_201_CREATED, tags=["agents"])
async def create_agent(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
data: CreateAgentRequest,
) -> ResourceCreatedResponse:

if data.model:
await validate_model(data.model)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
from ...queries.agents.create_or_update_agent import (
create_or_update_agent as create_or_update_agent_query,
)
from .router import router
from ..utils.model_validation import validate_model
from .router import router


@router.post("/agents/{agent_id}", status_code=HTTP_201_CREATED, tags=["agents"])
async def create_or_update_agent(
agent_id: UUID,
data: CreateOrUpdateAgentRequest,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> ResourceCreatedResponse:

if data.model:
await validate_model(data.model)

Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/agents/patch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.patch_agent import patch_agent as patch_agent_query
from .router import router
from ..utils.model_validation import validate_model
from .router import router


@router.patch(
"/agents/{agent_id}",
Expand All @@ -21,7 +22,7 @@ async def patch_agent(
agent_id: UUID,
data: PatchAgentRequest,
) -> ResourceUpdatedResponse:

if data.model:
await validate_model(data.model)

Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/agents/update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.update_agent import update_agent as update_agent_query
from .router import router
from ..utils.model_validation import validate_model
from .router import router


@router.put(
"/agents/{agent_id}",
Expand All @@ -21,7 +22,7 @@ async def update_agent(
agent_id: UUID,
data: UpdateAgentRequest,
) -> ResourceUpdatedResponse:

if data.model:
await validate_model(data.model)

Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from ...queries.chat.prepare_chat_context import prepare_chat_context
from ...queries.entries.create_entries import create_entries
from ...queries.sessions.count_sessions import count_sessions as count_sessions_query
from ..utils.model_validation import validate_model
from .metrics import total_tokens_per_user
from .router import router
from ..utils.model_validation import validate_model

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"


Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/utils/model_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

from ...clients.litellm import get_model_list


async def validate_model(model_name: str) -> None:
"""
Validates if a given model name is available in LiteLLM.
Raises HTTPException if model is not available.
"""
models = await get_model_list()
available_models = [model["id"] for model in models]

if model_name not in available_models:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail=f"Model {model_name} not available. Available models: {available_models}"
)
)
2 changes: 2 additions & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,14 @@ async def test_tool(
)
return tool


SAMPLE_MODELS = [
{"id": "gpt-4"},
{"id": "gpt-3.5-turbo"},
{"id": "gpt-4o-mini"},
]


@fixture(scope="global")
def client(_dsn=pg_dsn):
with TestClient(app=app) as client:
Expand Down
10 changes: 7 additions & 3 deletions agents-api/tests/test_model_validation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from unittest.mock import patch

from agents_api.routers.utils.model_validation import validate_model
from ward import test, raises
from fastapi import HTTPException
from ward import raises, test

from tests.fixtures import SAMPLE_MODELS


@test("validate_model: succeeds when model is available")
async def _():
# Use async context manager for patching
Expand All @@ -12,13 +15,14 @@ async def _():
await validate_model("gpt-4o-mini")
mock_get_models.assert_called_once()


@test("validate_model: fails when model is unavailable")
async def _():
with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models:
mock_get_models.return_value = SAMPLE_MODELS
with raises(HTTPException) as exc:
await validate_model("non-existent-model")

assert exc.raised.status_code == 400
assert "Model non-existent-model not available" in exc.raised.detail
mock_get_models.assert_called_once()
mock_get_models.assert_called_once()

0 comments on commit 828ac95

Please sign in to comment.