Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Limit free users to 50 executions and sessions #865

Merged
merged 12 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ LITELLM_POSTGRES_PASSWORD=<your_litellm_postgres_password>
LITELLM_MASTER_KEY=<your_litellm_master_key>
LITELLM_SALT_KEY=<your_litellm_salt_key>
LITELLM_REDIS_PASSWORD=<your_litellm_redis_password>
MAX_FREE_SESSIONS=50
MAX_FREE_EXECUTIONS=50

# LLM Providers
# --------------
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, StrictBool

from .Chat import ChatSettings
from .Common import JinjaTemplate
from .Tools import (
ChosenBash20241022,
ChosenComputer20241022,
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
BaseModel,
ConfigDict,
Field,
RootModel,
StrictBool,
)

Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)

from ..autogen.openapi_model import TransitionTarget
from ..common.protocol.remote import RemoteList
from ..common.protocol.tasks import ExecutionInput
from ..common.retry_policies import DEFAULT_RETRY_POLICY
from ..common.storage_handler import store_in_blob_store_if_large
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@

api_key_header_name: str = env.str("AGENTS_API_KEY_HEADER_NAME", default="X-Auth-Key")

max_free_sessions: int = env.int("MAX_FREE_SESSIONS", default=50)
max_free_executions: int = env.int("MAX_FREE_EXECUTIONS", default=50)

# Litellm API
# -----------
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/models/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ruff: noqa: F401, F403, F405

from .count_executions import count_executions
from .create_execution import create_execution
from .create_execution_transition import create_execution_transition
from .get_execution import get_execution
Expand Down
61 changes: 61 additions & 0 deletions agents-api/agents_api/models/execution/count_executions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(dict, one=True)
@cozo_query
@beartype
def count_executions(
*,
developer_id: UUID,
task_id: UUID,
) -> tuple[list[str], dict]:
count_query = """
input[task_id] <- [[to_uuid($task_id)]]

counter[count(id)] :=
input[task_id],
*executions {
task_id,
execution_id: id,
}

?[count] := counter[count]
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"tasks",
task_id=task_id,
parents=[("agents", "agent_id")],
),
count_query,
]

return (queries, {"task_id": str(task_id)})
1 change: 1 addition & 0 deletions agents-api/agents_api/models/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

# ruff: noqa: F401, F403, F405

from .count_sessions import count_sessions
from .create_or_update_session import create_or_update_session
from .create_session import create_session
from .delete_session import delete_session
Expand Down
64 changes: 64 additions & 0 deletions agents-api/agents_api/models/session/count_sessions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""This module contains functions for querying session data from the 'cozodb' database."""

from typing import Any, TypeVar
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(dict, one=True)
@cozo_query
@beartype
def count_sessions(
*,
developer_id: UUID,
) -> tuple[list[str], dict]:
"""
Counts sessions from the 'cozodb' database.

Parameters:
developer_id (UUID): The developer's ID to filter sessions by.
"""

count_query = """
input[developer_id] <- [[
to_uuid($developer_id),
]]

counter[count(id)] :=
input[developer_id],
*sessions{
developer_id,
session_id: id,
}

?[count] := counter[count]
"""

queries = [
verify_developer_id_query(developer_id),
count_query,
]

return (queries, {"developer_id": str(developer_id)})
16 changes: 14 additions & 2 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from typing import Annotated, Callable, Optional
from uuid import UUID, uuid4

from anthropic.types.beta.beta_message import BetaMessage
from fastapi import BackgroundTasks, Depends, Header
from fastapi import BackgroundTasks, Depends, Header, HTTPException, status
from starlette.status import HTTP_201_CREATED

from ...autogen.openapi_model import (
Expand All @@ -19,9 +18,11 @@
from ...common.utils.datetime import utcnow
from ...common.utils.template import render_template
from ...dependencies.developer_id import get_developer_data
from ...env import max_free_sessions
from ...models.chat.gather_messages import gather_messages
from ...models.chat.prepare_chat_context import prepare_chat_context
from ...models.entry.create_entries import create_entries
from ...models.session.count_sessions import count_sessions as count_sessions_query
from .metrics import total_tokens_per_user
from .router import router

Expand Down Expand Up @@ -54,6 +55,17 @@ async def chat(
ChatResponse: The chat response.
"""

# check if the developer is paid
if "paid" not in developer.tags:
# get the session length
sessions = count_sessions_query(developer_id=developer.id)
session_length = sessions["count"]
if session_length > max_free_sessions:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Session length exceeded the free tier limit",
)

if chat_input.stream:
raise NotImplementedError("Streaming is not yet implemented")

Expand Down
22 changes: 22 additions & 0 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
UpdateExecutionRequest,
)
from ...clients.temporal import run_task_execution_workflow
from ...common.protocol.developers import Developer
from ...dependencies.developer_id import get_developer_id
from ...env import max_free_executions
from ...models.developer.get_developer import get_developer
from ...models.execution.count_executions import (
count_executions as count_executions_query,
)
from ...models.execution.create_execution import (
create_execution as create_execution_query,
)
Expand Down Expand Up @@ -113,6 +119,22 @@ async def create_task_execution(

raise

# get developer data
developer: Developer = get_developer(developer_id=x_developer_id)

# # check if the developer is paid
if "paid" not in developer.tags:
executions = count_executions_query(
developer_id=x_developer_id, task_id=task_id
)

execution_count = executions["count"]
if execution_count > max_free_executions:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Execution count exceeded the free tier limit",
)

execution, handle = await start_execution(
developer_id=x_developer_id,
task_id=task_id,
Expand Down
3 changes: 3 additions & 0 deletions agents-api/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ x--shared-environment: &shared-environment
S3_ENDPOINT: ${S3_ENDPOINT:-http://seaweedfs:8333}
S3_ACCESS_KEY: ${S3_ACCESS_KEY}
S3_SECRET_KEY: ${S3_SECRET_KEY}
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY}
MAX_FREE_SESSIONS: ${MAX_FREE_SESSIONS:-50}
MAX_FREE_EXECUTIONS: ${MAX_FREE_EXECUTIONS:-50}

x--base-agents-api: &base-agents-api
image: julepai/agents-api:${TAG:-dev}
Expand Down
18 changes: 18 additions & 0 deletions agents-api/tests/test_execution_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CreateTransitionRequest,
Execution,
)
from agents_api.models.execution.count_executions import count_executions
from agents_api.models.execution.create_execution import create_execution
from agents_api.models.execution.create_execution_transition import (
create_execution_transition,
Expand Down Expand Up @@ -91,6 +92,23 @@ def _(
assert result[0].status == "queued"


@test("model: count executions")
def _(
client=cozo_client,
developer_id=test_developer_id,
execution=test_execution,
task=test_task,
):
result = count_executions(
developer_id=developer_id,
task_id=task.id,
client=client,
)

assert isinstance(result, dict)
assert result["count"] > 0


@test("model: create execution transition")
def _(client=cozo_client, developer_id=test_developer_id, execution=test_execution):
result = create_execution_transition(
Expand Down
12 changes: 12 additions & 0 deletions agents-api/tests/test_session_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CreateSessionRequest,
Session,
)
from agents_api.models.session.count_sessions import count_sessions
from agents_api.models.session.create_or_update_session import create_or_update_session
from agents_api.models.session.create_session import create_session
from agents_api.models.session.delete_session import delete_session
Expand Down Expand Up @@ -120,6 +121,17 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
assert len(result) > 0


@test("model: count sessions")
def _(client=cozo_client, developer_id=test_developer_id, session=test_session):
result = count_sessions(
developer_id=developer_id,
client=client,
)

assert isinstance(result, dict)
assert result["count"] > 0


@test("model: create or update session")
def _(
client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user
Expand Down
Loading
Loading