Skip to content

Commit

Permalink
Merge pull request #1069 from julep-ai/x/model-validation
Browse files Browse the repository at this point in the history
fix(agents-api): add model validation for agent and chat endpoints
  • Loading branch information
Ahmad-mtos authored Jan 20, 2025
2 parents 175b9ae + 828ac95 commit c19b2ad
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 4 deletions.
24 changes: 24 additions & 0 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from typing import Literal

import aiohttp
from beartype import beartype
from litellm import acompletion as _acompletion
from litellm import aembedding as _aembedding
Expand Down Expand Up @@ -109,3 +110,26 @@ async def aembedding(
for item in embedding_list
if len(item["embedding"]) >= dimensions
]


@beartype
async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]:
"""
Fetches the list of available models from the LiteLLM server.
Returns:
list[dict]: A list of model information dictionaries
"""

headers = {
"accept": "application/json",
"x-api-key": custom_api_key or litellm_master_key
}

async with aiohttp.ClientSession() as session, session.get(
url=f"{litellm_url}/models" if not custom_api_key else "/models",
headers=headers
) as response:
response.raise_for_status()
data = await response.json()
return data["data"]
6 changes: 5 additions & 1 deletion agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.create_agent import create_agent as create_agent_query
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -18,7 +19,10 @@ async def create_agent(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
data: CreateAgentRequest,
) -> ResourceCreatedResponse:
# TODO: Validate model name

if data.model:
await validate_model(data.model)

agent = await create_agent_query(
developer_id=x_developer_id,
data=data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ...queries.agents.create_or_update_agent import (
create_or_update_agent as create_or_update_agent_query,
)
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -21,7 +22,10 @@ async def create_or_update_agent(
data: CreateOrUpdateAgentRequest,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
# TODO: Validate model name

if data.model:
await validate_model(data.model)

agent = await create_or_update_agent_query(
developer_id=x_developer_id,
agent_id=agent_id,
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/routers/agents/patch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.patch_agent import patch_agent as patch_agent_query
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -21,6 +22,10 @@ async def patch_agent(
agent_id: UUID,
data: PatchAgentRequest,
) -> ResourceUpdatedResponse:

if data.model:
await validate_model(data.model)

return await patch_agent_query(
agent_id=agent_id,
developer_id=x_developer_id,
Expand Down
7 changes: 6 additions & 1 deletion agents-api/agents_api/routers/agents/update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
from ...dependencies.developer_id import get_developer_id
from ...queries.agents.update_agent import update_agent as update_agent_query
from ..utils.model_validation import validate_model
from .router import router


Expand All @@ -20,7 +21,11 @@ async def update_agent(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
agent_id: UUID,
data: UpdateAgentRequest,
) -> ResourceUpdatedResponse:
) -> ResourceUpdatedResponse:

if data.model:
await validate_model(data.model)

return await update_agent_query(
developer_id=x_developer_id,
agent_id=agent_id,
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ...queries.chat.prepare_chat_context import prepare_chat_context
from ...queries.entries.create_entries import create_entries
from ...queries.sessions.count_sessions import count_sessions as count_sessions_query
from ..utils.model_validation import validate_model
from .metrics import total_tokens_per_user
from .router import router

Expand Down Expand Up @@ -55,6 +56,10 @@ async def chat(
Returns:
ChatResponse: The chat response.
"""

if chat_input.model:
await validate_model(chat_input.model)

# check if the developer is paid
if "paid" not in developer.tags:
# get the session length
Expand Down
19 changes: 19 additions & 0 deletions agents-api/agents_api/routers/utils/model_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from fastapi import HTTPException
from starlette.status import HTTP_400_BAD_REQUEST

from ...clients.litellm import get_model_list


async def validate_model(model_name: str) -> None:
"""
Validates if a given model name is available in LiteLLM.
Raises HTTPException if model is not available.
"""
models = await get_model_list()
available_models = [model["id"] for model in models]

if model_name not in available_models:
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
detail=f"Model {model_name} not available. Available models: {available_models}"
)
11 changes: 10 additions & 1 deletion agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
import string
import sys
from unittest.mock import patch
from uuid import UUID

from agents_api.autogen.openapi_model import (
Expand Down Expand Up @@ -440,10 +441,18 @@ async def test_tool(
return tool


SAMPLE_MODELS = [
{"id": "gpt-4"},
{"id": "gpt-3.5-turbo"},
{"id": "gpt-4o-mini"},
]


@fixture(scope="global")
def client(_dsn=pg_dsn):
with TestClient(app=app) as client:
yield client
with patch("agents_api.routers.utils.model_validation.get_model_list", return_value=SAMPLE_MODELS):
yield client


@fixture(scope="global")
Expand Down
28 changes: 28 additions & 0 deletions agents-api/tests/test_model_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from unittest.mock import patch

from agents_api.routers.utils.model_validation import validate_model
from fastapi import HTTPException
from ward import raises, test

from tests.fixtures import SAMPLE_MODELS


@test("validate_model: succeeds when model is available")
async def _():
# Use async context manager for patching
with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models:
mock_get_models.return_value = SAMPLE_MODELS
await validate_model("gpt-4o-mini")
mock_get_models.assert_called_once()


@test("validate_model: fails when model is unavailable")
async def _():
with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models:
mock_get_models.return_value = SAMPLE_MODELS
with raises(HTTPException) as exc:
await validate_model("non-existent-model")

assert exc.raised.status_code == 400
assert "Model non-existent-model not available" in exc.raised.detail
mock_get_models.assert_called_once()

0 comments on commit c19b2ad

Please sign in to comment.