Skip to content

Commit

Permalink
Merge pull request #1038 from julep-ai/x/connection-pool
Browse files Browse the repository at this point in the history
hotfix(agents-api): Fixes for connection pool getting exhausted
  • Loading branch information
creatorrr authored Jan 10, 2025
2 parents cf9ffab + e1a1bc8 commit a5f4132
Show file tree
Hide file tree
Showing 15 changed files with 59 additions and 85 deletions.
2 changes: 2 additions & 0 deletions agents-api/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
notebooks/

# Local database files
temporal.db
*.bak
Expand Down
26 changes: 0 additions & 26 deletions agents-api/agents_api/activities/container.py

This file was deleted.

8 changes: 3 additions & 5 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
from beartype import beartype
from temporalio import activity

from ..app import lifespan
from ..app import app
from ..autogen.openapi_model import BaseIntegrationDef
from ..clients import integrations
from ..common.exceptions.tools import IntegrationExecutionException
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries import tools
from .container import container


@lifespan(container)
@beartype
async def execute_integration(
context: StepContext,
Expand All @@ -40,15 +38,15 @@ async def execute_integration(
agent_id=agent_id,
task_id=task_id,
arg_type="args",
connection_pool=container.state.postgres_pool,
connection_pool=app.state.postgres_pool,
)

merged_tool_setup = await tools.get_tool_args_from_metadata(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
arg_type="setup",
connection_pool=container.state.postgres_pool,
connection_pool=app.state.postgres_pool,
)

arguments = merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments
Expand Down
13 changes: 9 additions & 4 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any
from uuid import UUID

Expand All @@ -6,7 +7,7 @@
from fastapi.background import BackgroundTasks
from temporalio import activity

from ..app import app, lifespan
from ..app import app
from ..autogen.openapi_model import (
ChatInput,
CreateDocRequest,
Expand All @@ -21,19 +22,20 @@
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries import developers
from .container import container
from .utils import get_handler


@lifespan(app, container) # Both are needed because we are using the routes
@beartype
async def execute_system(
context: StepContext,
system: SystemDef,
) -> Any:
"""Execute a system call with the appropriate handler and transformed arguments."""

arguments: dict[str, Any] = system.arguments or {}

connection_pool = getattr(app.state, "postgres_pool", None)

if not isinstance(context.execution_input, ExecutionInput):
msg = "Expected ExecutionInput type for context.execution_input"
raise TypeError(msg)
Expand All @@ -54,7 +56,9 @@ async def execute_system(
arguments[field] = UUID(arguments[field])

try:
# Partial with connection pool
handler = get_handler(system)
handler = partial(handler, connection_pool=connection_pool)

# Transform arguments for doc-related operations (except create and search
# as we're calling the endpoint function rather than the model method)
Expand Down Expand Up @@ -87,8 +91,9 @@ async def execute_system(
if system.operation == "chat" and system.resource == "session":
developer = await developers.get_developer(
developer_id=arguments["developer_id"],
connection_pool=container.state.postgres_pool,
connection_pool=connection_pool,
)

session_id = arguments.get("session_id")
x_custom_api_key = arguments.get("x_custom_api_key", None)
chat_input = ChatInput(**arguments)
Expand Down
6 changes: 2 additions & 4 deletions agents-api/agents_api/activities/task_steps/pg_query_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from temporalio import activity

from ... import queries
from ...app import lifespan
from ...app import app
from ...env import pg_dsn
from ..container import container


@activity.defn
@lifespan(container)
@beartype
async def pg_query_step(
query_name: str,
Expand All @@ -21,4 +19,4 @@ async def pg_query_step(

module = getattr(queries, module_name)
query = getattr(module, name)
return await query(**values, connection_pool=container.state.postgres_pool)
return await query(**values, connection_pool=app.state.postgres_pool)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import HTTPException
from temporalio import activity

from ...app import lifespan
from ...app import app
from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...clients.temporal import get_workflow_handle
from ...common.protocol.tasks import ExecutionInput, StepContext
Expand All @@ -14,10 +14,8 @@
from ...queries.executions.create_execution_transition import (
create_execution_transition,
)
from ..container import container


@lifespan(container)
@beartype
async def transition_step(
context: StepContext,
Expand Down Expand Up @@ -50,7 +48,7 @@ async def transition_step(
execution_id=context.execution_input.execution.id,
data=transition_info,
task_token=transition_info.task_token,
connection_pool=container.state.postgres_pool,
connection_pool=app.state.postgres_pool,
)

except Exception as e:
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,8 @@ def get_handler(system: SystemDef) -> Callable:
from ..queries.users.get_user import get_user as get_user_query
from ..queries.users.list_users import list_users as list_users_query
from ..queries.users.update_user import update_user as update_user_query

# FIXME: Do not use routes directly;
from ..routers.docs.create_doc import create_agent_doc, create_user_doc
from ..routers.docs.search_docs import search_agent_docs, search_user_docs
from ..routers.sessions.chat import chat
Expand Down
57 changes: 24 additions & 33 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scalar_fastapi import get_scalar_api_reference

from .clients.pg import create_db_pool
from .env import api_prefix, hostname, protocol, public_port
from .env import api_prefix, hostname, pool_max_size, protocol, public_port


class State(Protocol):
Expand All @@ -22,56 +22,47 @@ class ObjectWithState(Protocol):
state: State


pool = None


# TODO: This currently doesn't use env.py, we should move to using them
@asynccontextmanager
async def lifespan(*containers: FastAPI | ObjectWithState):
async def lifespan(container: FastAPI | ObjectWithState):
# INIT POSTGRES #
pg_dsn = os.environ.get("PG_DSN")

global pool
if not pool:
pool = await create_db_pool(pg_dsn)
pool = await create_db_pool(pg_dsn, max_size=pool_max_size)

for container in containers:
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = pool
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = pool

# INIT S3 #
s3_access_key = os.environ.get("S3_ACCESS_KEY")
s3_secret_key = os.environ.get("S3_SECRET_KEY")
s3_endpoint = os.environ.get("S3_ENDPOINT")

for container in containers:
if hasattr(container, "state") and not getattr(container.state, "s3_client", None):
session = get_session()
container.state.s3_client = await session.create_client(
"s3",
aws_access_key_id=s3_access_key,
aws_secret_access_key=s3_secret_key,
endpoint_url=s3_endpoint,
).__aenter__()
if hasattr(container, "state") and not getattr(container.state, "s3_client", None):
session = get_session()
container.state.s3_client = await session.create_client(
"s3",
aws_access_key_id=s3_access_key,
aws_secret_access_key=s3_secret_key,
endpoint_url=s3_endpoint,
).__aenter__()

try:
yield
finally:
# # CLOSE POSTGRES #
# for container in containers:
# if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
# pool = getattr(container.state, "postgres_pool", None)
# if pool:
# await pool.close()
# container.state.postgres_pool = None
# CLOSE POSTGRES #
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
pool = getattr(container.state, "postgres_pool", None)
if pool:
await pool.close()
container.state.postgres_pool = None

# CLOSE S3 #
for container in containers:
if hasattr(container, "state") and getattr(container.state, "s3_client", None):
s3_client = getattr(container.state, "s3_client", None)
if s3_client:
await s3_client.close()
container.state.s3_client = None
if hasattr(container, "state") and getattr(container.state, "s3_client", None):
s3_client = getattr(container.state, "s3_client", None)
if s3_client:
await s3_client.close()
container.state.s3_client = None


app: FastAPI = FastAPI(
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/clients/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ async def _init_conn(conn):
)


async def create_db_pool(dsn: str | None = None):
return await asyncpg.create_pool(dsn if dsn is not None else pg_dsn, init=_init_conn)
async def create_db_pool(dsn: str | None = None, **kwargs):
return await asyncpg.create_pool(
dsn if dsn is not None else pg_dsn, init=_init_conn, **kwargs
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
It utilizes the environs library for environment variable parsing.
"""

import multiprocessing
import random
from pprint import pprint
from typing import Any
Expand Down Expand Up @@ -60,6 +61,7 @@
summarization_model_name: str = env.str("SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo")

query_timeout: float = env.float("QUERY_TIMEOUT", default=90.0)
pool_max_size: int = env.int("POOL_MAX_SIZE", default=multiprocessing.cpu_count())


# Auth
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, Any
from uuid import UUID

from fastapi import Depends
Expand All @@ -15,6 +15,7 @@ async def create_user_doc(
user_id: UUID,
data: CreateDocRequest,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
connection_pool: Any = None, # FIXME: Placeholder that should be removed
) -> ResourceCreatedResponse:
"""
Creates a new document for a user.
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def search_user_docs(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
search_params: (TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest),
user_id: UUID,
connection_pool: Any = None, # FIXME: Placeholder that should be removed
) -> DocSearchResponse:
"""
Searches for documents associated with a specific user.
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, Any
from uuid import UUID

from fastapi import BackgroundTasks, Depends, Header, HTTPException, status
Expand Down Expand Up @@ -40,6 +40,7 @@ async def chat(
chat_input: ChatInput,
background_tasks: BackgroundTasks,
x_custom_api_key: str | None = Header(None, alias="X-Custom-Api-Key"),
connection_pool: Any = None, # FIXME: Placeholder that should be removed
) -> ChatResponse:
"""
Initiates a chat session.
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from tenacity import after_log, retry, retry_if_exception_type, wait_fixed

from ..app import app, lifespan
from ..clients import temporal
from .worker import create_worker

Expand All @@ -35,8 +36,9 @@ async def main():
client = await temporal.get_client_with_metrics()
worker = create_worker(client)

# Start the worker to listen for and process tasks
await worker.run()
async with lifespan(app):
# Start the worker to listen for and process tasks
await worker.run()


if __name__ == "__main__":
Expand Down
3 changes: 0 additions & 3 deletions memory-store/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@ services:
memory-store:
image: timescale/timescaledb-ha:pg17

# For timescaledb specific options,
# See: https://github.com/timescale/timescaledb-docker?tab=readme-ov-file#notes-on-timescaledb-tune
environment:
- POSTGRES_PASSWORD=${MEMORY_STORE_PASSWORD:-postgres}
- OPENAI_API_KEY=${OPENAI_API_KEY:?OPENAI_API_KEY is required}
- TS_TUNE_MAX_CONNS=${TS_TUNE_MAX_CONNS:-1000}
ports:
- "5432:5432"
volumes:
Expand Down

0 comments on commit a5f4132

Please sign in to comment.