-
Notifications
You must be signed in to change notification settings - Fork 45.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added listing, sorting, filtering and ordering of agents
- Loading branch information
Showing
7 changed files
with
641 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import prisma.models | ||
import prisma.types | ||
from prisma.errors import PrismaError | ||
from fuzzywuzzy import fuzz | ||
|
||
|
||
class AgentQueryError(Exception): | ||
"""Custom exception for agent query errors""" | ||
|
||
pass | ||
|
||
|
||
async def get_agents( | ||
page: int = 1, | ||
page_size: int = 10, | ||
name: str | None = None, | ||
keyword: str | None = None, | ||
category: str | None = None, | ||
description: str | None = None, | ||
description_threshold: int = 60, | ||
sort_by: str = "createdAt", | ||
sort_order: str = "desc", | ||
): | ||
""" | ||
Retrieve a list of agents from the database based on the provided filters and pagination parameters. | ||
Args: | ||
page (int, optional): The page number to retrieve. Defaults to 1. | ||
page_size (int, optional): The number of agents per page. Defaults to 10. | ||
name (str, optional): Filter agents by name. Defaults to None. | ||
keyword (str, optional): Filter agents by keyword. Defaults to None. | ||
category (str, optional): Filter agents by category. Defaults to None. | ||
description (str, optional): Filter agents by description. Defaults to None. | ||
description_threshold (int, optional): The minimum fuzzy search threshold for the description. Defaults to 60. | ||
sort_by (str, optional): The field to sort the agents by. Defaults to "createdAt". | ||
sort_order (str, optional): The sort order ("asc" or "desc"). Defaults to "desc". | ||
Returns: | ||
dict: A dictionary containing the list of agents, total count, current page number, page size, and total number of pages. | ||
""" | ||
try: | ||
# Define the base query | ||
query = {} | ||
|
||
# Add optional filters | ||
if name: | ||
query["name"] = {"contains": name, "mode": "insensitive"} | ||
if keyword: | ||
query["keywords"] = {"has": keyword} | ||
if category: | ||
query["categories"] = {"has": category} | ||
|
||
# Define sorting | ||
order = {sort_by: sort_order} | ||
|
||
# Calculate pagination | ||
skip = (page - 1) * page_size | ||
|
||
# Execute the query | ||
try: | ||
agents = await prisma.models.Agents.prisma().find_many( | ||
where=query, # type: ignore | ||
order=order, # type: ignore | ||
skip=skip, | ||
take=page_size, | ||
) | ||
except PrismaError as e: | ||
raise AgentQueryError(f"Database query failed: {str(e)}") | ||
|
||
# Apply fuzzy search on description if provided | ||
if description: | ||
try: | ||
filtered_agents = [] | ||
for agent in agents: | ||
if ( | ||
agent.description | ||
and fuzz.partial_ratio( | ||
description.lower(), agent.description.lower() | ||
) | ||
>= description_threshold | ||
): | ||
filtered_agents.append(agent) | ||
agents = filtered_agents | ||
except AttributeError as e: | ||
raise AgentQueryError(f"Error during fuzzy search: {str(e)}") | ||
|
||
# Get total count for pagination info | ||
total_count = len(agents) | ||
|
||
return { | ||
"agents": agents, | ||
"total_count": total_count, | ||
"page": page, | ||
"page_size": page_size, | ||
"total_pages": (total_count + page_size - 1) // page_size, | ||
} | ||
|
||
except AgentQueryError as e: | ||
# Log the error or handle it as needed | ||
raise e | ||
except ValueError as e: | ||
raise AgentQueryError(f"Invalid input parameter: {str(e)}") | ||
except Exception as e: | ||
# Catch any other unexpected exceptions | ||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}") | ||
|
||
|
||
async def get_agent_details(agent_id: str, version: int | None = None): | ||
""" | ||
Retrieve agent details from the database. | ||
Args: | ||
agent_id (str): The ID of the agent. | ||
version (int | None, optional): The version of the agent. Defaults to None. | ||
Returns: | ||
dict: The agent details. | ||
Raises: | ||
AgentQueryError: If the agent is not found or if there is an error querying the database. | ||
""" | ||
try: | ||
query = {"id": agent_id} | ||
if version is not None: | ||
query["version"] = version # type: ignore | ||
|
||
agent = await prisma.models.Agents.prisma().find_first(where=query) # type: ignore | ||
|
||
if not agent: | ||
raise AgentQueryError("Agent not found") | ||
|
||
return agent | ||
|
||
except PrismaError as e: | ||
raise AgentQueryError(f"Database query failed: {str(e)}") | ||
except Exception as e: | ||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import pydantic | ||
import typing | ||
|
||
|
||
class AgentResponse(pydantic.BaseModel): | ||
""" | ||
Represents a response from an agent. | ||
Attributes: | ||
id (str): The ID of the agent. | ||
name (str, optional): The name of the agent. | ||
description (str, optional): The description of the agent. | ||
author (str, optional): The author of the agent. | ||
keywords (list[str]): The keywords associated with the agent. | ||
categories (list[str]): The categories the agent belongs to. | ||
version (int): The version of the agent. | ||
createdAt (str): The creation date of the agent. | ||
updatedAt (str): The last update date of the agent. | ||
""" | ||
|
||
id: str | ||
name: typing.Optional[str] | ||
description: typing.Optional[str] | ||
author: typing.Optional[str] | ||
keywords: list[str] | ||
categories: list[str] | ||
version: int | ||
createdAt: str | ||
updatedAt: str | ||
|
||
|
||
class AgentListResponse(pydantic.BaseModel): | ||
""" | ||
Represents a response containing a list of agents. | ||
Attributes: | ||
agents (list[AgentResponse]): The list of agents. | ||
total_count (int): The total count of agents. | ||
page (int): The current page number. | ||
page_size (int): The number of agents per page. | ||
total_pages (int): The total number of pages. | ||
""" | ||
|
||
agents: list[AgentResponse] | ||
total_count: int | ||
page: int | ||
page_size: int | ||
total_pages: int | ||
|
||
|
||
class AgentDetailResponse(pydantic.BaseModel): | ||
""" | ||
Represents the response data for an agent detail. | ||
Attributes: | ||
id (str): The ID of the agent. | ||
name (Optional[str]): The name of the agent. | ||
description (Optional[str]): The description of the agent. | ||
author (Optional[str]): The author of the agent. | ||
keywords (List[str]): The keywords associated with the agent. | ||
categories (List[str]): The categories the agent belongs to. | ||
version (int): The version of the agent. | ||
createdAt (str): The creation date of the agent. | ||
updatedAt (str): The last update date of the agent. | ||
graph (Dict[str, Any]): The graph data of the agent. | ||
""" | ||
|
||
id: str | ||
name: typing.Optional[str] | ||
description: typing.Optional[str] | ||
author: typing.Optional[str] | ||
keywords: list[str] | ||
categories: list[str] | ||
version: int | ||
createdAt: str | ||
updatedAt: str | ||
graph: dict[str, typing.Any] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,102 @@ | ||
from fastapi import APIRouter | ||
from fastapi import APIRouter, HTTPException, Query, Path | ||
from typing import Optional | ||
from market.db import get_agents, get_agent_details, AgentQueryError | ||
import market.model | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.get("/agents", response_model=market.model.AgentListResponse) | ||
async def list_agents( | ||
page: int = Query(1, ge=1, description="Page number"), | ||
page_size: int = Query(10, ge=1, le=100, description="Number of items per page"), | ||
name: Optional[str] = Query(None, description="Filter by agent name"), | ||
keyword: Optional[str] = Query(None, description="Filter by keyword"), | ||
category: Optional[str] = Query(None, description="Filter by category"), | ||
description: Optional[str] = Query(None, description="Fuzzy search in description"), | ||
description_threshold: int = Query( | ||
60, ge=0, le=100, description="Fuzzy search threshold" | ||
), | ||
sort_by: str = Query("createdAt", description="Field to sort by"), | ||
sort_order: str = Query("desc", description="Sort order (asc or desc)"), | ||
): | ||
""" | ||
Retrieve a list of agents based on the provided filters. | ||
Args: | ||
page (int): Page number (default: 1). | ||
page_size (int): Number of items per page (default: 10, min: 1, max: 100). | ||
name (str, optional): Filter by agent name. | ||
keyword (str, optional): Filter by keyword. | ||
category (str, optional): Filter by category. | ||
description (str, optional): Fuzzy search in description. | ||
description_threshold (int): Fuzzy search threshold (default: 60, min: 0, max: 100). | ||
sort_by (str): Field to sort by (default: "createdAt"). | ||
sort_order (str): Sort order (asc or desc) (default: "desc"). | ||
Returns: | ||
market.model.AgentListResponse: A response containing the list of agents and pagination information. | ||
Raises: | ||
HTTPException: If there is a client error (status code 400) or an unexpected error (status code 500). | ||
""" | ||
try: | ||
result = await get_agents( | ||
page=page, | ||
page_size=page_size, | ||
name=name, | ||
keyword=keyword, | ||
category=category, | ||
description=description, | ||
description_threshold=description_threshold, | ||
sort_by=sort_by, | ||
sort_order=sort_order, | ||
) | ||
|
||
# Convert the result to the response model | ||
agents = [ | ||
market.model.AgentResponse(**agent.dict()) for agent in result["agents"] | ||
] | ||
|
||
return market.model.AgentListResponse( | ||
agents=agents, | ||
total_count=result["total_count"], | ||
page=result["page"], | ||
page_size=result["page_size"], | ||
total_pages=result["total_pages"], | ||
) | ||
|
||
except AgentQueryError as e: | ||
raise HTTPException(status_code=400, detail=str(e)) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail="An unexpected error occurred") | ||
|
||
|
||
@router.get("/agents/{agent_id}", response_model=market.model.AgentDetailResponse) | ||
async def get_agent_details_endpoint( | ||
agent_id: str = Path(..., description="The ID of the agent to retrieve"), | ||
version: Optional[int] = Query(None, description="Specific version of the agent"), | ||
): | ||
""" | ||
Retrieve details of a specific agent. | ||
Args: | ||
agent_id (str): The ID of the agent to retrieve. | ||
version (Optional[int]): Specific version of the agent (default: None). | ||
Returns: | ||
market.model.AgentDetailResponse: The response containing the agent details. | ||
Raises: | ||
HTTPException: If the agent is not found or an unexpected error occurs. | ||
""" | ||
try: | ||
agent = await get_agent_details(agent_id, version) | ||
return market.model.AgentDetailResponse(**agent.model_dump()) | ||
|
||
except AgentQueryError as e: | ||
raise HTTPException(status_code=404, detail=str(e)) | ||
except Exception as e: | ||
raise HTTPException( | ||
status_code=500, detail=f"An unexpected error occurred: {str(e)}" | ||
) |
Oops, something went wrong.