Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use pure ASGI middlewares #145

Merged
merged 12 commits into from
Jun 25, 2024
87 changes: 54 additions & 33 deletions fastapi_sqla/async_sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from typing import Annotated

import structlog
from fastapi import Depends, Request
from fastapi import Depends, Request, Response
from fastapi.responses import PlainTextResponse
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as SqlaAsyncSession
from sqlalchemy.orm.session import sessionmaker
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from fastapi_sqla import aws_aurora_support, aws_rds_iam_support
from fastapi_sqla.sqla import _DEFAULT_SESSION_KEY, Base, new_engine
Expand Down Expand Up @@ -88,9 +89,7 @@ async def open_session(
await session.close()


async def add_session_to_request(
request: Request, call_next, key: str = _DEFAULT_SESSION_KEY
):
class AsyncSessionMiddleware:
"""Middleware which injects a new sqla async session into every request.

Handles creation of session, as well as commit, rollback, and closing of session.
Expand All @@ -108,36 +107,58 @@ async def add_session_to_request(
async def get_users(session: fastapi_sqla.AsyncSession):
return await session.execute(...) # use your session here
"""
async with open_session(key) as session:
setattr(request.state, f"{_ASYNC_REQUEST_SESSION_KEY}_{key}", session)
response = await call_next(request)

is_dirty = bool(session.dirty or session.deleted or session.new)

# try to commit after response, so that we can return a proper 500 response
# and not raise a true internal server error
if response.status_code < 400:
try:
await session.commit()
except Exception:
logger.exception("commit failed, returning http error")
response = PlainTextResponse(
content="Internal Server Error", status_code=500
)

if response.status_code >= 400:
# If ever a route handler returns an http exception, we do not want the
# session opened by current context manager to commit anything in db.
if is_dirty:
# optimistically only log if there were uncommitted changes
logger.warning(
"http error, rolling back possibly uncommitted changes",
status_code=response.status_code,
)
# since this is no-op if session is not dirty, we can always call it
await session.rollback()

return response
def __init__(self, app: ASGIApp, key: str = _DEFAULT_SESSION_KEY) -> None:
self.app = app
self.key = key

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)

async with open_session(self.key) as session:
request = Request(scope=scope, receive=receive, send=send)
setattr(request.state, f"{_ASYNC_REQUEST_SESSION_KEY}_{self.key}", session)

async def send_wrapper(message: Message) -> None:
if message["type"] != "http.response.start":
return await send(message)

response: Response | None = None
status_code = message["status"]
is_dirty = bool(session.dirty or session.deleted or session.new)

# try to commit after response, so that we can return a proper 500
# and not raise a true internal server error
if status_code < 400:
try:
await session.commit()
except Exception:
logger.exception("commit failed, returning http error")
status_code = 500
response = PlainTextResponse(
content="Internal Server Error", status_code=500
)

if status_code >= 400:
# If ever a route handler returns an http exception,
# we do not want the current session to commit anything in db.
if is_dirty:
# optimistically only log if there were uncommitted changes
logger.warning(
"http error, rolling back possibly uncommitted changes",
status_code=status_code,
)
# since this is no-op if the session is not dirty,
# we can always call it
await session.rollback()

if response:
return await response(scope, receive, send)

return await send(message)

await self.app(scope, receive, send_wrapper)


class AsyncSessionDependency:
Expand Down
16 changes: 4 additions & 12 deletions fastapi_sqla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ def setup_middlewares(app: FastAPI):
engines = {key: sqla.new_engine(key) for key in engine_keys}
for key, engine in engines.items():
if not _is_async_dialect(engine):
app.middleware("http")(
functools.partial(sqla.add_session_to_request, key=key)
)
app.add_middleware(sqla.SessionMiddleware, key=key)
else:
app.middleware("http")(
functools.partial(async_sqla.add_session_to_request, key=key)
)
app.add_middleware(async_sqla.AsyncSessionMiddleware, key=key)


@deprecated(
Expand All @@ -54,16 +50,12 @@ def setup(app: FastAPI):
for key, engine in engines.items():
if not _is_async_dialect(engine):
app.add_event_handler("startup", functools.partial(sqla.startup, key=key))
app.middleware("http")(
functools.partial(sqla.add_session_to_request, key=key)
)
app.add_middleware(sqla.SessionMiddleware, key=key)
else:
app.add_event_handler(
"startup", functools.partial(async_sqla.startup, key=key)
)
app.middleware("http")(
functools.partial(async_sqla.add_session_to_request, key=key)
)
app.add_middleware(async_sqla.AsyncSessionMiddleware, key=key)


def _get_engine_keys() -> set[str]:
Expand Down
94 changes: 57 additions & 37 deletions fastapi_sqla/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from typing import Annotated

import structlog
from fastapi import Depends, Request
from fastapi import Depends, Request, Response
from fastapi.concurrency import contextmanager_in_threadpool
from fastapi.responses import PlainTextResponse
from sqlalchemy import engine_from_config, text
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm.session import Session as SqlaSession
from sqlalchemy.orm.session import sessionmaker
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from fastapi_sqla import aws_aurora_support, aws_rds_iam_support

Expand Down Expand Up @@ -110,9 +111,7 @@ def open_session(key: str = _DEFAULT_SESSION_KEY) -> Generator[SqlaSession, None
session.close()


async def add_session_to_request(
request: Request, call_next, key: str = _DEFAULT_SESSION_KEY
):
class SessionMiddleware:
"""Middleware which injects a new sqla session into every request.

Handles creation of session, as well as commit, rollback, and closing of session.
Expand All @@ -130,39 +129,60 @@ async def add_session_to_request(
def get_users(session: fastapi_sqla.Session):
return session.execute(...) # use your session here
"""
async with contextmanager_in_threadpool(open_session(key)) as session:
setattr(request.state, f"{_REQUEST_SESSION_KEY}_{key}", session)

response = await call_next(request)

is_dirty = bool(session.dirty or session.deleted or session.new)

loop = asyncio.get_running_loop()

# try to commit after response, so that we can return a proper 500 response
# and not raise a true internal server error
if response.status_code < 400:
try:
await loop.run_in_executor(None, session.commit)
except Exception:
logger.exception("commit failed, returning http error")
response = PlainTextResponse(
content="Internal Server Error", status_code=500
)

if response.status_code >= 400:
# If ever a route handler returns an http exception, we do not want the
# session opened by current context manager to commit anything in db.
if is_dirty:
# optimistically only log if there were uncommitted changes
logger.warning(
"http error, rolling back possibly uncommitted changes",
status_code=response.status_code,
)
# since this is no-op if session is not dirty, we can always call it
await loop.run_in_executor(None, session.rollback)

return response

def __init__(self, app: ASGIApp, key: str = _DEFAULT_SESSION_KEY) -> None:
self.app = app
self.key = key

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)

async with contextmanager_in_threadpool(open_session(self.key)) as session:
request = Request(scope=scope, receive=receive, send=send)
setattr(request.state, f"{_REQUEST_SESSION_KEY}_{self.key}", session)

async def send_wrapper(message: Message) -> None:
if message["type"] != "http.response.start":
return await send(message)

response: Response | None = None
status_code = message["status"]
is_dirty = bool(session.dirty or session.deleted or session.new)

loop = asyncio.get_running_loop()

# try to commit after response, so that we can return a proper 500
# and not raise a true internal server error
if status_code < 400:
try:
await loop.run_in_executor(None, session.commit)
except Exception:
logger.exception("commit failed, returning http error")
status_code = 500
response = PlainTextResponse(
content="Internal Server Error", status_code=status_code
)

if status_code >= 400:
# If ever a route handler returns an http exception,
# we do not want the current session to commit anything in db.
if is_dirty:
# optimistically only log if there were uncommitted changes
logger.warning(
"http error, rolling back possibly uncommitted changes",
status_code=status_code,
)
# since this is no-op if the session is not dirty,
# we can always call it
await loop.run_in_executor(None, session.rollback)

if response:
return await response(scope, receive, send)

return await send(message)

await self.app(scope, receive, send_wrapper)


class SessionDependency:
Expand Down
51 changes: 13 additions & 38 deletions tests/test_base_setup_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,15 @@ def test_setup_middlewares_multiple_engines(db_url):
):
setup_middlewares(app)

assert app.middleware.call_count == 2
assert all(call.args[0] == "http" for call in app.middleware.call_args_list)

assert app.middleware.return_value.call_count == 2
assert any(
call
for call in app.middleware.return_value.call_args_list
if call.args[0].func == sqla.add_session_to_request
and call.args[0].keywords == {"key": _DEFAULT_SESSION_KEY}
)
assert any(
call
for call in app.middleware.return_value.call_args_list
if call.args[0].func == sqla.add_session_to_request
and call.args[0].keywords == {"key": read_only_key}
assert app.add_middleware.call_count == 2
assert all(
call.args[0] == sqla.SessionMiddleware
for call in app.add_middleware.call_args_list
)

app.add_middleware.assert_any_call(sqla.SessionMiddleware, key=_DEFAULT_SESSION_KEY)
app.add_middleware.assert_any_call(sqla.SessionMiddleware, key=read_only_key)


@mark.sqlalchemy("1.4")
@mark.require_asyncpg
Expand All @@ -49,21 +41,10 @@ def test_setup_middlewares_with_sync_and_async_sqlalchemy_url(async_session_key)
app = Mock()
setup_middlewares(app)

assert app.middleware.call_count == 2
assert all(call.args[0] == "http" for call in app.middleware.call_args_list)

assert app.middleware.return_value.call_count == 2
assert any(
call
for call in app.middleware.return_value.call_args_list
if call.args[0].func == sqla.add_session_to_request
and call.args[0].keywords == {"key": _DEFAULT_SESSION_KEY}
)
assert any(
call
for call in app.middleware.return_value.call_args_list
if call.args[0].func == async_sqla.add_session_to_request
and call.args[0].keywords == {"key": async_session_key}
assert app.add_middleware.call_count == 2
app.add_middleware.assert_any_call(sqla.SessionMiddleware, key=_DEFAULT_SESSION_KEY)
app.add_middleware.assert_any_call(
async_sqla.AsyncSessionMiddleware, key=async_session_key
)


Expand All @@ -79,12 +60,6 @@ def test_setup_middlewares_with_async_default_sqlalchemy_url(async_sqlalchemy_ur
):
setup_middlewares(app)

app.middleware.assert_called_once_with("http")
app.middleware.return_value.assert_called_once()
assert (
app.middleware.return_value.call_args.args[0].func
== async_sqla.add_session_to_request
app.add_middleware.assert_called_once_with(
async_sqla.AsyncSessionMiddleware, key=_DEFAULT_SESSION_KEY
)
assert app.middleware.return_value.call_args.args[0].keywords == {
"key": _DEFAULT_SESSION_KEY
}
Loading