Skip to content

Commit

Permalink
dynamic deps and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Jan 24, 2024
1 parent a13d5d9 commit 81abd73
Show file tree
Hide file tree
Showing 15 changed files with 502 additions and 152 deletions.
File renamed without changes.
166 changes: 166 additions & 0 deletions examples/benchmarks/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import logging
from contextlib import asynccontextmanager
from inspect import Parameter
from typing import Annotated, Callable, Iterable, NewType, get_type_hints

import uvicorn
from fastapi import APIRouter
from fastapi import Depends as FastapiDepends
from fastapi import FastAPI, Request

from dishka import Provider, Scope, make_async_container, provide
from dishka.inject import Depends, wrap_injection


# framework level
def inject(func):
hints = get_type_hints(func)
requests_param = next(
(name for name, hint in hints.items() if hint is Request),
None,
)
if requests_param:
getter = lambda kwargs: kwargs[requests_param].state.container
additional_params = []
else:
getter = lambda kwargs: kwargs["___r___"].state.container
additional_params = [Parameter(
name="___r___",
annotation=Request,
kind=Parameter.KEYWORD_ONLY,
)]

return wrap_injection(
func=func,
remove_depends=True,
container_getter=getter,
additional_params=additional_params,
is_async=True,
)


def container_middleware():
async def add_request_container(request: Request, call_next):
async with request.app.state.container(
{Request: request}
) as subcontainer:
request.state.container = subcontainer
return await call_next(request)

return add_request_container


class Stub:
def __init__(self, dependency: Callable, **kwargs):
self._dependency = dependency
self._kwargs = kwargs

def __call__(self):
raise NotImplementedError

def __eq__(self, other) -> bool:
if isinstance(other, Stub):
return (
self._dependency == other._dependency
and self._kwargs == other._kwargs
)
else:
if not self._kwargs:
return self._dependency == other
return False

def __hash__(self):
if not self._kwargs:
return hash(self._dependency)
serial = (
self._dependency,
*self._kwargs.items(),
)
return hash(serial)


# app dependency logic

Host = NewType("Host", str)


class B:
def __init__(self, x: int):
pass


class C:
def __init__(self, x: int):
pass


class A:
def __init__(self, b: B, c: C):
pass


MyInt = NewType("MyInt", int)


class MyProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_a(self, b: B, c: C) -> A:
return A(b, c)

@provide(scope=Scope.REQUEST)
async def get_b(self) -> Iterable[B]:
yield B(1)

@provide(scope=Scope.REQUEST)
async def get_c(self) -> Iterable[C]:
yield C(1)


# app
router = APIRouter()


@router.get("/")
@inject
async def index(
*,
value: Annotated[A, Depends()],
value2: Annotated[A, Depends()],
) -> str:
return f"{value} {value is value2}"


@router.get("/f")
async def index(
*,
value: Annotated[A, FastapiDepends(Stub(A))],
value2: Annotated[A, FastapiDepends(Stub(A))],
) -> str:
return f"{value} {value is value2}"


def new_a(b: B = FastapiDepends(Stub(B)), c: C = FastapiDepends(Stub(C))):
return A(b, c)


@asynccontextmanager
async def lifespan(app: FastAPI):
async with make_async_container(MyProvider(), with_lock=True) as container:
app.state.container = container
yield


def create_app() -> FastAPI:
logging.basicConfig(level=logging.WARNING)

app = FastAPI(lifespan=lifespan)
app.middleware("http")(container_middleware())
app.dependency_overrides[A] = new_a
app.dependency_overrides[B] = lambda: B(1)
app.dependency_overrides[C] = lambda: C(1)
app.include_router(router)
return app


if __name__ == "__main__":
uvicorn.run(create_app(), host="0.0.0.0", port=8000)
File renamed without changes.
File renamed without changes.
113 changes: 35 additions & 78 deletions examples/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import logging
from abc import abstractmethod
from contextlib import asynccontextmanager
from inspect import Parameter
from typing import Annotated, Callable, Iterable, NewType, get_type_hints
from typing import (
Annotated, get_type_hints, Protocol, Any, get_origin,
get_args,
)

import uvicorn
from fastapi import APIRouter
from fastapi import Depends as FastapiDepends
from fastapi import FastAPI, Request

from dishka import Provider, Scope, make_async_container, provide
from dishka.inject import Depends, wrap_injection
from dishka import (
Depends, wrap_injection, Provider, Scope, make_async_container, provide,
)


# framework level
Expand Down Expand Up @@ -50,102 +54,58 @@ async def add_request_container(request: Request, call_next):
return add_request_container


class Stub:
def __init__(self, dependency: Callable, **kwargs):
self._dependency = dependency
self._kwargs = kwargs

def __call__(self):
# app core
class DbGateway(Protocol):
@abstractmethod
def get(self) -> str:
raise NotImplementedError

def __eq__(self, other) -> bool:
if isinstance(other, Stub):
return (
self._dependency == other._dependency
and self._kwargs == other._kwargs
)
else:
if not self._kwargs:
return self._dependency == other
return False

def __hash__(self):
if not self._kwargs:
return hash(self._dependency)
serial = (
self._dependency,
*self._kwargs.items(),
)
return hash(serial)


# app dependency logic

Host = NewType("Host", str)

class FakeDbGateway(DbGateway):
def get(self) -> str:
return "Hello"

class B:
def __init__(self, x: int):
pass

class Interactor:
def __init__(self, db: DbGateway):
self.db = db

class C:
def __init__(self, x: int):
pass
def __call__(self) -> str:
return self.db.get()


class A:
def __init__(self, b: B, c: C):
pass


MyInt = NewType("MyInt", int)


class MyProvider(Provider):
# app dependency logic
class AdaptersProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_a(self, b: B, c: C) -> A:
return A(b, c)
def get_db(self) -> DbGateway:
return FakeDbGateway()

@provide(scope=Scope.REQUEST)
async def get_b(self) -> Iterable[B]:
yield B(1)

@provide(scope=Scope.REQUEST)
async def get_c(self) -> Iterable[C]:
yield C(1)
class InteractorProvider(Provider):
i1 = provide(Interactor, scope=Scope.REQUEST)


# app
# presentation layer
router = APIRouter()


@router.get("/")
@inject
async def index(
*,
value: Annotated[A, Depends()],
value2: Annotated[A, Depends()],
interactor: Annotated[Interactor, Depends()],
) -> str:
return f"{value} {value is value2}"


@router.get("/f")
async def index(
*,
value: Annotated[A, FastapiDepends(Stub(A))],
value2: Annotated[A, FastapiDepends(Stub(A))],
) -> str:
return f"{value} {value is value2}"


def new_a(b: B = FastapiDepends(Stub(B)), c: C = FastapiDepends(Stub(C))):
return A(b, c)
result = interactor()
return result


# app configuration
@asynccontextmanager
async def lifespan(app: FastAPI):
async with make_async_container(MyProvider(), with_lock=True) as container:
async with make_async_container(
AdaptersProvider(), InteractorProvider(),
with_lock=True,
) as container:
app.state.container = container
yield

Expand All @@ -155,9 +115,6 @@ def create_app() -> FastAPI:

app = FastAPI(lifespan=lifespan)
app.middleware("http")(container_middleware())
app.dependency_overrides[A] = new_a
app.dependency_overrides[B] = lambda: B(1)
app.dependency_overrides[C] = lambda: C(1)
app.include_router(router)
return app

Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
ruff
pytest
pytest-asyncio
pytest-repeat
2 changes: 1 addition & 1 deletion src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
):
self.registry = registry
self.child_registries = child_registries
self.context = {}
self.context = {type(self): self}
if context:
self.context.update(context)
self.parent_container = parent_container
Expand Down
2 changes: 1 addition & 1 deletion src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
):
self.registry = registry
self.child_registries = child_registries
self.context = {}
self.context = {type(self): self}
if context:
self.context.update(context)
self.parent_container = parent_container
Expand Down
Loading

0 comments on commit 81abd73

Please sign in to comment.