Skip to content

Commit

Permalink
Added listing, sorting, filtering and ordering of agents
Browse files Browse the repository at this point in the history
  • Loading branch information
Swiftyos committed Jul 31, 2024
1 parent 122f544 commit d1187b5
Show file tree
Hide file tree
Showing 7 changed files with 641 additions and 62 deletions.
137 changes: 137 additions & 0 deletions rnd/market/market/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import prisma.models
import prisma.types
from prisma.errors import PrismaError
from fuzzywuzzy import fuzz


class AgentQueryError(Exception):
"""Custom exception for agent query errors"""

pass


async def get_agents(
page: int = 1,
page_size: int = 10,
name: str | None = None,
keyword: str | None = None,
category: str | None = None,
description: str | None = None,
description_threshold: int = 60,
sort_by: str = "createdAt",
sort_order: str = "desc",
):
"""
Retrieve a list of agents from the database based on the provided filters and pagination parameters.
Args:
page (int, optional): The page number to retrieve. Defaults to 1.
page_size (int, optional): The number of agents per page. Defaults to 10.
name (str, optional): Filter agents by name. Defaults to None.
keyword (str, optional): Filter agents by keyword. Defaults to None.
category (str, optional): Filter agents by category. Defaults to None.
description (str, optional): Filter agents by description. Defaults to None.
description_threshold (int, optional): The minimum fuzzy search threshold for the description. Defaults to 60.
sort_by (str, optional): The field to sort the agents by. Defaults to "createdAt".
sort_order (str, optional): The sort order ("asc" or "desc"). Defaults to "desc".
Returns:
dict: A dictionary containing the list of agents, total count, current page number, page size, and total number of pages.
"""
try:
# Define the base query
query = {}

# Add optional filters
if name:
query["name"] = {"contains": name, "mode": "insensitive"}
if keyword:
query["keywords"] = {"has": keyword}
if category:
query["categories"] = {"has": category}

# Define sorting
order = {sort_by: sort_order}

# Calculate pagination
skip = (page - 1) * page_size

# Execute the query
try:
agents = await prisma.models.Agents.prisma().find_many(
where=query, # type: ignore
order=order, # type: ignore
skip=skip,
take=page_size,
)
except PrismaError as e:
raise AgentQueryError(f"Database query failed: {str(e)}")

# Apply fuzzy search on description if provided
if description:
try:
filtered_agents = []
for agent in agents:
if (
agent.description
and fuzz.partial_ratio(
description.lower(), agent.description.lower()
)
>= description_threshold
):
filtered_agents.append(agent)
agents = filtered_agents
except AttributeError as e:
raise AgentQueryError(f"Error during fuzzy search: {str(e)}")

# Get total count for pagination info
total_count = len(agents)

return {
"agents": agents,
"total_count": total_count,
"page": page,
"page_size": page_size,
"total_pages": (total_count + page_size - 1) // page_size,
}

except AgentQueryError as e:
# Log the error or handle it as needed
raise e
except ValueError as e:
raise AgentQueryError(f"Invalid input parameter: {str(e)}")
except Exception as e:
# Catch any other unexpected exceptions
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")


async def get_agent_details(agent_id: str, version: int | None = None):
"""
Retrieve agent details from the database.
Args:
agent_id (str): The ID of the agent.
version (int | None, optional): The version of the agent. Defaults to None.
Returns:
dict: The agent details.
Raises:
AgentQueryError: If the agent is not found or if there is an error querying the database.
"""
try:
query = {"id": agent_id}
if version is not None:
query["version"] = version # type: ignore

agent = await prisma.models.Agents.prisma().find_first(where=query) # type: ignore

if not agent:
raise AgentQueryError("Agent not found")

return agent

except PrismaError as e:
raise AgentQueryError(f"Database query failed: {str(e)}")
except Exception as e:
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
77 changes: 77 additions & 0 deletions rnd/market/market/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pydantic
import typing


class AgentResponse(pydantic.BaseModel):
"""
Represents a response from an agent.
Attributes:
id (str): The ID of the agent.
name (str, optional): The name of the agent.
description (str, optional): The description of the agent.
author (str, optional): The author of the agent.
keywords (list[str]): The keywords associated with the agent.
categories (list[str]): The categories the agent belongs to.
version (int): The version of the agent.
createdAt (str): The creation date of the agent.
updatedAt (str): The last update date of the agent.
"""

id: str
name: typing.Optional[str]
description: typing.Optional[str]
author: typing.Optional[str]
keywords: list[str]
categories: list[str]
version: int
createdAt: str
updatedAt: str


class AgentListResponse(pydantic.BaseModel):
"""
Represents a response containing a list of agents.
Attributes:
agents (list[AgentResponse]): The list of agents.
total_count (int): The total count of agents.
page (int): The current page number.
page_size (int): The number of agents per page.
total_pages (int): The total number of pages.
"""

agents: list[AgentResponse]
total_count: int
page: int
page_size: int
total_pages: int


class AgentDetailResponse(pydantic.BaseModel):
"""
Represents the response data for an agent detail.
Attributes:
id (str): The ID of the agent.
name (Optional[str]): The name of the agent.
description (Optional[str]): The description of the agent.
author (Optional[str]): The author of the agent.
keywords (List[str]): The keywords associated with the agent.
categories (List[str]): The categories the agent belongs to.
version (int): The version of the agent.
createdAt (str): The creation date of the agent.
updatedAt (str): The last update date of the agent.
graph (Dict[str, Any]): The graph data of the agent.
"""

id: str
name: typing.Optional[str]
description: typing.Optional[str]
author: typing.Optional[str]
keywords: list[str]
categories: list[str]
version: int
createdAt: str
updatedAt: str
graph: dict[str, typing.Any]
101 changes: 100 additions & 1 deletion rnd/market/market/routes/agents.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,102 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, Query, Path
from typing import Optional
from market.db import get_agents, get_agent_details, AgentQueryError
import market.model

router = APIRouter()


@router.get("/agents", response_model=market.model.AgentListResponse)
async def list_agents(
page: int = Query(1, ge=1, description="Page number"),
page_size: int = Query(10, ge=1, le=100, description="Number of items per page"),
name: Optional[str] = Query(None, description="Filter by agent name"),
keyword: Optional[str] = Query(None, description="Filter by keyword"),
category: Optional[str] = Query(None, description="Filter by category"),
description: Optional[str] = Query(None, description="Fuzzy search in description"),
description_threshold: int = Query(
60, ge=0, le=100, description="Fuzzy search threshold"
),
sort_by: str = Query("createdAt", description="Field to sort by"),
sort_order: str = Query("desc", description="Sort order (asc or desc)"),
):
"""
Retrieve a list of agents based on the provided filters.
Args:
page (int): Page number (default: 1).
page_size (int): Number of items per page (default: 10, min: 1, max: 100).
name (str, optional): Filter by agent name.
keyword (str, optional): Filter by keyword.
category (str, optional): Filter by category.
description (str, optional): Fuzzy search in description.
description_threshold (int): Fuzzy search threshold (default: 60, min: 0, max: 100).
sort_by (str): Field to sort by (default: "createdAt").
sort_order (str): Sort order (asc or desc) (default: "desc").
Returns:
market.model.AgentListResponse: A response containing the list of agents and pagination information.
Raises:
HTTPException: If there is a client error (status code 400) or an unexpected error (status code 500).
"""
try:
result = await get_agents(
page=page,
page_size=page_size,
name=name,
keyword=keyword,
category=category,
description=description,
description_threshold=description_threshold,
sort_by=sort_by,
sort_order=sort_order,
)

# Convert the result to the response model
agents = [
market.model.AgentResponse(**agent.dict()) for agent in result["agents"]
]

return market.model.AgentListResponse(
agents=agents,
total_count=result["total_count"],
page=result["page"],
page_size=result["page_size"],
total_pages=result["total_pages"],
)

except AgentQueryError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail="An unexpected error occurred")


@router.get("/agents/{agent_id}", response_model=market.model.AgentDetailResponse)
async def get_agent_details_endpoint(
agent_id: str = Path(..., description="The ID of the agent to retrieve"),
version: Optional[int] = Query(None, description="Specific version of the agent"),
):
"""
Retrieve details of a specific agent.
Args:
agent_id (str): The ID of the agent to retrieve.
version (Optional[int]): Specific version of the agent (default: None).
Returns:
market.model.AgentDetailResponse: The response containing the agent details.
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
try:
agent = await get_agent_details(agent_id, version)
return market.model.AgentDetailResponse(**agent.model_dump())

except AgentQueryError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
)
Loading

0 comments on commit d1187b5

Please sign in to comment.