diff --git a/src/vocutouts/uws/dependencies.py b/src/vocutouts/uws/dependencies.py index e67406d..68908cc 100644 --- a/src/vocutouts/uws/dependencies.py +++ b/src/vocutouts/uws/dependencies.py @@ -6,13 +6,14 @@ objects. """ +from collections.abc import AsyncIterator from typing import Annotated from fastapi import Depends, Form, Request from safir.arq import ArqMode, ArqQueue, MockArqQueue, RedisArqQueue -from safir.dependencies.db_session import db_session_dependency +from safir.database import create_async_session, create_database_engine from safir.dependencies.logger import logger_dependency -from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import AsyncEngine, async_scoped_session from structlog.stdlib import BoundLogger from .config import UWSConfig @@ -32,7 +33,27 @@ class UWSFactory: - """Build UWS components.""" + """Build UWS components. + + Parameters + ---------- + config + UWS configuration. + arq + arq queue to use. + session + Database session. + result_store + Signed URL generator for results. + logger + Logger to use. + + Attributes + ---------- + session + Database session. This is exposed primarily for the test suite. It + shouldn't be necessary for other code to use it directly. + """ def __init__( self, @@ -43,9 +64,9 @@ def __init__( result_store: ResultStore, logger: BoundLogger, ) -> None: + self.session = session self._config = config self._arq = arq - self._session = session self._result_store = result_store self._logger = logger @@ -65,7 +86,7 @@ def create_job_service(self) -> JobService: def create_job_store(self) -> JobStore: """Create a new UWS job store.""" - return JobStore(self._session) + return JobStore(self.session) def create_templates(self) -> UWSTemplates: """Create a new XML renderer for responses.""" @@ -76,28 +97,33 @@ class UWSDependency: """Initializes UWS and provides a UWS factory as a dependency.""" def __init__(self) -> None: + self._arq: ArqQueue self._config: UWSConfig + self._engine: AsyncEngine + self._session: async_scoped_session self._result_store: ResultStore - self._arq: ArqQueue async def __call__( - self, - session: Annotated[ - async_scoped_session, Depends(db_session_dependency) - ], - logger: Annotated[BoundLogger, Depends(logger_dependency)], - ) -> UWSFactory: - return UWSFactory( - config=self._config, - arq=self._arq, - session=session, - result_store=self._result_store, - logger=logger, - ) + self, logger: Annotated[BoundLogger, Depends(logger_dependency)] + ) -> AsyncIterator[UWSFactory]: + try: + yield UWSFactory( + config=self._config, + arq=self._arq, + session=self._session, + result_store=self._result_store, + logger=logger, + ) + finally: + # Following the recommendations in the SQLAlchemy documentation, + # each session is scoped to a single web request. However, this + # all uses the same async_scoped_session object, so should share + # an underlying engine and connection pool. + await self._session.remove() async def aclose(self) -> None: """Shut down the UWS subsystem.""" - await db_session_dependency.aclose() + await self._engine.dispose() async def initialize(self, config: UWSConfig) -> None: """Initialize the UWS subsystem. @@ -114,11 +140,12 @@ async def initialize(self, config: UWSConfig) -> None: self._arq = await RedisArqQueue.initialize(settings) else: self._arq = MockArqQueue() - await db_session_dependency.initialize( + self._engine = create_database_engine( config.database_url, config.database_password, isolation_level="REPEATABLE READ", ) + self._session = await create_async_session(self._engine) def override_arq_queue(self, arq_queue: ArqQueue) -> None: """Change the arq used in subsequent invocations. diff --git a/tests/conftest.py b/tests/conftest.py index b929c4b..9615efb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,6 @@ from fastapi import FastAPI from httpx import ASGITransport, AsyncClient from safir.arq import MockArqQueue -from safir.dependencies.db_session import db_session_dependency from safir.testing.gcs import MockStorageClient, patch_google_storage from safir.testing.slack import MockSlackWebhook, mock_slack_webhook @@ -84,5 +83,5 @@ async def uws_factory(app: FastAPI) -> AsyncIterator[UWSFactory]: already been initialized. """ logger = structlog.get_logger("vocutouts") - async for session in db_session_dependency(): - yield await uws_dependency(session, logger) + async for factory in uws_dependency(logger): + yield factory diff --git a/tests/uws/conftest.py b/tests/uws/conftest.py index 315efb3..b05a7e6 100644 --- a/tests/uws/conftest.py +++ b/tests/uws/conftest.py @@ -14,13 +14,11 @@ from fastapi import APIRouter, FastAPI from httpx import ASGITransport, AsyncClient from safir.arq import MockArqQueue -from safir.dependencies.db_session import db_session_dependency from safir.middleware.ivoa import CaseInsensitiveQueryMiddleware from safir.middleware.x_forwarded import XForwardedMiddleware from safir.slack.webhook import SlackRouteErrorHandler from safir.testing.gcs import MockStorageClient, patch_google_storage from safir.testing.slack import MockSlackWebhook, mock_slack_webhook -from sqlalchemy.ext.asyncio import async_scoped_session from structlog.stdlib import BoundLogger from vocutouts.uws.app import UWSApplication @@ -108,17 +106,6 @@ def runner(uws_factory: UWSFactory, arq_queue: MockArqQueue) -> MockJobRunner: return MockJobRunner(uws_factory, arq_queue) -@pytest_asyncio.fixture -async def session(app: FastAPI) -> AsyncIterator[async_scoped_session]: - """Return a database session with no transaction open. - - Depends on the ``app`` fixture to ensure that the database layer has - already been initialized. - """ - async for session in db_session_dependency(): - yield session - - @pytest.fixture def uws_config() -> UWSConfig: return build_uws_config() @@ -126,6 +113,7 @@ def uws_config() -> UWSConfig: @pytest_asyncio.fixture async def uws_factory( - session: async_scoped_session, logger: BoundLogger -) -> UWSFactory: - return await uws_dependency(session, logger) + app: FastAPI, logger: BoundLogger +) -> AsyncIterator[UWSFactory]: + async for factory in uws_dependency(logger): + yield factory diff --git a/tests/uws/job_list_test.py b/tests/uws/job_list_test.py index 7c95828..1da7533 100644 --- a/tests/uws/job_list_test.py +++ b/tests/uws/job_list_test.py @@ -13,7 +13,6 @@ from safir.database import datetime_to_db from safir.datetime import current_datetime, isodatetime from sqlalchemy import update -from sqlalchemy.ext.asyncio import async_scoped_session from vocutouts.uws.dependencies import UWSFactory from vocutouts.uws.models import UWSJobParameter @@ -81,9 +80,7 @@ @pytest.mark.asyncio -async def test_job_list( - client: AsyncClient, session: async_scoped_session, uws_factory: UWSFactory -) -> None: +async def test_job_list(client: AsyncClient, uws_factory: UWSFactory) -> None: job_service = uws_factory.create_job_service() jobs = [ await job_service.create( @@ -108,7 +105,7 @@ async def test_job_list( # Adjust the creation time of the jobs so that searches are more # interesting. - async with session.begin(): + async with uws_factory.session.begin(): for i, job in enumerate(jobs): hours = (2 - i) * 2 creation = current_datetime() - timedelta(hours=hours) @@ -117,7 +114,7 @@ async def test_job_list( .where(SQLJob.id == int(job.job_id)) .values(creation_time=datetime_to_db(creation)) ) - await session.execute(stmt) + await uws_factory.session.execute(stmt) job.creation_time = creation # Retrieve the job list and check it.