diff --git a/examples/aiohttp_app.py b/examples/aiohttp_app.py index a25b9d4b..1b3641b7 100644 --- a/examples/aiohttp_app.py +++ b/examples/aiohttp_app.py @@ -24,7 +24,6 @@ class GatewayProvider(Provider): GatewayDepends = Annotated[Gateway, Depends()] -app = Application() router = RouteTableDef() @@ -34,7 +33,7 @@ async def endpoint(request: str, gateway: GatewayDepends) -> Response: data = await gateway.get() return Response(text=f'gateway data: {data}') - +app = Application() app.add_routes(router) -setup_dishka(GatewayProvider(), app=app) +setup_dishka(providers=[GatewayProvider()], app=app) run_app(app) diff --git a/src/dishka/integrations/aiohttp.py b/src/dishka/integrations/aiohttp.py index 17fe83dc..d863e6fe 100644 --- a/src/dishka/integrations/aiohttp.py +++ b/src/dishka/integrations/aiohttp.py @@ -1,4 +1,8 @@ -from typing import Callable, Final +__all__ = [ + "Depends", "inject", "setup_dishka", +] + +from typing import Callable, Final, Sequence from aiohttp import web from aiohttp.typedefs import Handler @@ -7,17 +11,17 @@ from aiohttp.web_response import StreamResponse from dishka import Provider, make_async_container -from dishka.async_container import AsyncContextWrapper -from dishka.integrations.base import wrap_injection +from dishka.async_container import AsyncContainer, AsyncContextWrapper +from dishka.integrations.base import Depends, wrap_injection -CONTAINER_KEY: Final = 'dishka_container' +CONTAINER_KEY: Final = web.AppKey('dishka_container', AsyncContainer) def inject(func: Callable) -> Callable: return wrap_injection( func=func, remove_depends=True, - container_getter=lambda p, _: p[0].app[CONTAINER_KEY], + container_getter=lambda p, _: p[0][CONTAINER_KEY], is_async=True, ) @@ -26,16 +30,16 @@ def inject(func: Callable) -> Callable: async def container_middleware( request: Request, handler: Handler, ) -> StreamResponse: - container = request.app['__container__'] - async with container() as container_: - request.app[CONTAINER_KEY] = container_ + container = request.app[CONTAINER_KEY] + async with container(context={Request: request}) as request_container: + request[CONTAINER_KEY] = request_container res = await handler(request) return res def startup(wrapper_container: AsyncContextWrapper): async def wrapper(app: Application) -> None: - app['__container__'] = await wrapper_container.__aenter__() + app[CONTAINER_KEY] = await wrapper_container.__aenter__() return wrapper @@ -45,8 +49,8 @@ async def wrapper(app: Application) -> None: return wrapper -def setup_dishka(*provides: Provider, app: Application) -> None: - wrapper_container = make_async_container(*provides) +def setup_dishka(providers: Sequence[Provider], app: Application) -> None: + wrapper_container = make_async_container(*providers) app.middlewares.append(container_middleware) app.on_startup.append(startup(wrapper_container)) app.on_shutdown.append(shutdown(wrapper_container)) diff --git a/tests/integrations/aiohttp/__init__.py b/tests/integrations/aiohttp/__init__.py new file mode 100644 index 00000000..0e1409fd --- /dev/null +++ b/tests/integrations/aiohttp/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aiohttp") diff --git a/tests/integrations/aiohttp/test_aiohttp.py b/tests/integrations/aiohttp/test_aiohttp.py new file mode 100644 index 00000000..c40b656d --- /dev/null +++ b/tests/integrations/aiohttp/test_aiohttp.py @@ -0,0 +1,86 @@ +from contextlib import asynccontextmanager +from typing import Annotated +from unittest.mock import Mock + +import pytest +from aiohttp.test_utils import TestServer, TestClient +from aiohttp.web_app import Application +from aiohttp.web_response import Response +from aiohttp.web_routedef import RouteTableDef + +from dishka.integrations.aiohttp import ( + Depends, + setup_dishka, + inject, +) +from ..common import ( + APP_DEP_VALUE, + REQUEST_DEP_VALUE, + AppDep, + AppProvider, + RequestDep, +) + + +@asynccontextmanager +async def dishka_app(view, provider) -> TestClient: + app = Application() + + router = RouteTableDef() + router.get("/")(inject(view)) + + app.add_routes(router) + setup_dishka(providers=[provider], app=app) + client = TestClient(TestServer(app)) + await client.start_server() + yield client + await client.close() + + +async def get_with_app( + _, + a: Annotated[AppDep, Depends()], + mock: Annotated[Mock, Depends()], +) -> Response: + mock(a) + return Response(text=f'passed') + + +@pytest.mark.asyncio +async def test_app_dependency(app_provider: AppProvider): + async with dishka_app(get_with_app, app_provider) as client: + await client.get("/") + app_provider.mock.assert_called_with(APP_DEP_VALUE) + app_provider.app_released.assert_not_called() + app_provider.app_released.assert_called() + + +async def get_with_request( + _, + a: Annotated[RequestDep, Depends()], + mock: Annotated[Mock, Depends()], +) -> Response: + mock(a) + return Response(text=f'passed') + + + +@pytest.mark.asyncio +async def test_request_dependency(app_provider: AppProvider): + async with dishka_app(get_with_request, app_provider) as client: + await client.get("/") + app_provider.mock.assert_called_with(REQUEST_DEP_VALUE) + app_provider.request_released.assert_called_once() + + +@pytest.mark.asyncio +async def test_request_dependency2(app_provider: AppProvider): + async with dishka_app(get_with_request, app_provider) as client: + await client.get("/") + app_provider.mock.assert_called_with(REQUEST_DEP_VALUE) + app_provider.mock.reset_mock() + app_provider.request_released.assert_called_once() + app_provider.request_released.reset_mock() + await client.get("/") + app_provider.mock.assert_called_with(REQUEST_DEP_VALUE) + app_provider.request_released.assert_called_once() diff --git a/tox.ini b/tox.ini index 6bb95903..3d6c92be 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ env_list = unit, real_world_example, fastapi-{0096,0109}, + aiohttp-393, flask-302, litestar-230, aiogram-330, @@ -16,6 +17,8 @@ addopts = --cov=dishka --cov-append --cov-report term-missing -v [testenv] deps = + aiohttp-393: -r requirements/aiohttp-393.txt + aiohttp-latest: -r requirements/aiohttp-latest.txt fastapi-latest: -r requirements/fastapi-latest.txt fastapi-0096: -r requirements/fastapi-0096.txt fastapi-0109: -r requirements/fastapi-0109.txt @@ -31,6 +34,7 @@ deps = starlette-0270: -r requirements/starlette-0270.txt commands = + aiohttp: pytest tests/integrations/aiohttp fastapi: pytest tests/integrations/fastapi aiogram: pytest tests/integrations/aiogram telebot: pytest tests/integrations/telebot