Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-44763: Stop using db_session_dependency in UWS #177

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions src/vocutouts/uws/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
objects.
"""

from collections.abc import AsyncIterator
from typing import Annotated

from fastapi import Depends, Form, Request
from safir.arq import ArqMode, ArqQueue, MockArqQueue, RedisArqQueue
from safir.dependencies.db_session import db_session_dependency
from safir.database import create_async_session, create_database_engine
from safir.dependencies.logger import logger_dependency
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.ext.asyncio import AsyncEngine, async_scoped_session
from structlog.stdlib import BoundLogger

from .config import UWSConfig
Expand All @@ -32,7 +33,27 @@


class UWSFactory:
"""Build UWS components."""
"""Build UWS components.

Parameters
----------
config
UWS configuration.
arq
arq queue to use.
session
Database session.
result_store
Signed URL generator for results.
logger
Logger to use.

Attributes
----------
session
Database session. This is exposed primarily for the test suite. It
shouldn't be necessary for other code to use it directly.
"""

def __init__(
self,
Expand All @@ -43,9 +64,9 @@ def __init__(
result_store: ResultStore,
logger: BoundLogger,
) -> None:
self.session = session
self._config = config
self._arq = arq
self._session = session
self._result_store = result_store
self._logger = logger

Expand All @@ -65,7 +86,7 @@ def create_job_service(self) -> JobService:

def create_job_store(self) -> JobStore:
"""Create a new UWS job store."""
return JobStore(self._session)
return JobStore(self.session)

def create_templates(self) -> UWSTemplates:
"""Create a new XML renderer for responses."""
Expand All @@ -76,28 +97,33 @@ class UWSDependency:
"""Initializes UWS and provides a UWS factory as a dependency."""

def __init__(self) -> None:
self._arq: ArqQueue
self._config: UWSConfig
self._engine: AsyncEngine
self._session: async_scoped_session
self._result_store: ResultStore
self._arq: ArqQueue

async def __call__(
self,
session: Annotated[
async_scoped_session, Depends(db_session_dependency)
],
logger: Annotated[BoundLogger, Depends(logger_dependency)],
) -> UWSFactory:
return UWSFactory(
config=self._config,
arq=self._arq,
session=session,
result_store=self._result_store,
logger=logger,
)
self, logger: Annotated[BoundLogger, Depends(logger_dependency)]
) -> AsyncIterator[UWSFactory]:
try:
yield UWSFactory(
config=self._config,
arq=self._arq,
session=self._session,
result_store=self._result_store,
logger=logger,
)
finally:
# Following the recommendations in the SQLAlchemy documentation,
# each session is scoped to a single web request. However, this
# all uses the same async_scoped_session object, so should share
# an underlying engine and connection pool.
await self._session.remove()

async def aclose(self) -> None:
"""Shut down the UWS subsystem."""
await db_session_dependency.aclose()
await self._engine.dispose()

async def initialize(self, config: UWSConfig) -> None:
"""Initialize the UWS subsystem.
Expand All @@ -114,11 +140,12 @@ async def initialize(self, config: UWSConfig) -> None:
self._arq = await RedisArqQueue.initialize(settings)
else:
self._arq = MockArqQueue()
await db_session_dependency.initialize(
self._engine = create_database_engine(
config.database_url,
config.database_password,
isolation_level="REPEATABLE READ",
)
self._session = await create_async_session(self._engine)

def override_arq_queue(self, arq_queue: ArqQueue) -> None:
"""Change the arq used in subsequent invocations.
Expand Down
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from safir.arq import MockArqQueue
from safir.dependencies.db_session import db_session_dependency
from safir.testing.gcs import MockStorageClient, patch_google_storage
from safir.testing.slack import MockSlackWebhook, mock_slack_webhook

Expand Down Expand Up @@ -84,5 +83,5 @@ async def uws_factory(app: FastAPI) -> AsyncIterator[UWSFactory]:
already been initialized.
"""
logger = structlog.get_logger("vocutouts")
async for session in db_session_dependency():
yield await uws_dependency(session, logger)
async for factory in uws_dependency(logger):
yield factory
20 changes: 4 additions & 16 deletions tests/uws/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@
from fastapi import APIRouter, FastAPI
from httpx import ASGITransport, AsyncClient
from safir.arq import MockArqQueue
from safir.dependencies.db_session import db_session_dependency
from safir.middleware.ivoa import CaseInsensitiveQueryMiddleware
from safir.middleware.x_forwarded import XForwardedMiddleware
from safir.slack.webhook import SlackRouteErrorHandler
from safir.testing.gcs import MockStorageClient, patch_google_storage
from safir.testing.slack import MockSlackWebhook, mock_slack_webhook
from sqlalchemy.ext.asyncio import async_scoped_session
from structlog.stdlib import BoundLogger

from vocutouts.uws.app import UWSApplication
Expand Down Expand Up @@ -108,24 +106,14 @@ def runner(uws_factory: UWSFactory, arq_queue: MockArqQueue) -> MockJobRunner:
return MockJobRunner(uws_factory, arq_queue)


@pytest_asyncio.fixture
async def session(app: FastAPI) -> AsyncIterator[async_scoped_session]:
"""Return a database session with no transaction open.

Depends on the ``app`` fixture to ensure that the database layer has
already been initialized.
"""
async for session in db_session_dependency():
yield session


@pytest.fixture
def uws_config() -> UWSConfig:
return build_uws_config()


@pytest_asyncio.fixture
async def uws_factory(
session: async_scoped_session, logger: BoundLogger
) -> UWSFactory:
return await uws_dependency(session, logger)
app: FastAPI, logger: BoundLogger
) -> AsyncIterator[UWSFactory]:
async for factory in uws_dependency(logger):
yield factory
9 changes: 3 additions & 6 deletions tests/uws/job_list_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from safir.database import datetime_to_db
from safir.datetime import current_datetime, isodatetime
from sqlalchemy import update
from sqlalchemy.ext.asyncio import async_scoped_session

from vocutouts.uws.dependencies import UWSFactory
from vocutouts.uws.models import UWSJobParameter
Expand Down Expand Up @@ -81,9 +80,7 @@


@pytest.mark.asyncio
async def test_job_list(
client: AsyncClient, session: async_scoped_session, uws_factory: UWSFactory
) -> None:
async def test_job_list(client: AsyncClient, uws_factory: UWSFactory) -> None:
job_service = uws_factory.create_job_service()
jobs = [
await job_service.create(
Expand All @@ -108,7 +105,7 @@ async def test_job_list(

# Adjust the creation time of the jobs so that searches are more
# interesting.
async with session.begin():
async with uws_factory.session.begin():
for i, job in enumerate(jobs):
hours = (2 - i) * 2
creation = current_datetime() - timedelta(hours=hours)
Expand All @@ -117,7 +114,7 @@ async def test_job_list(
.where(SQLJob.id == int(job.job_id))
.values(creation_time=datetime_to_db(creation))
)
await session.execute(stmt)
await uws_factory.session.execute(stmt)
job.creation_time = creation

# Retrieve the job list and check it.
Expand Down