Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dynamic deps and more tests #11

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
166 changes: 166 additions & 0 deletions examples/benchmarks/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import logging
from contextlib import asynccontextmanager
from inspect import Parameter
from typing import Annotated, Callable, Iterable, NewType, get_type_hints

import uvicorn
from fastapi import APIRouter
from fastapi import Depends as FastapiDepends
from fastapi import FastAPI, Request

from dishka import Provider, Scope, make_async_container, provide
from dishka.inject import Depends, wrap_injection


# framework level
def inject(func):
hints = get_type_hints(func)
requests_param = next(
(name for name, hint in hints.items() if hint is Request),
None,
)
if requests_param:
getter = lambda kwargs: kwargs[requests_param].state.container
additional_params = []
else:
getter = lambda kwargs: kwargs["___r___"].state.container
additional_params = [Parameter(
name="___r___",
annotation=Request,
kind=Parameter.KEYWORD_ONLY,
)]

return wrap_injection(
func=func,
remove_depends=True,
container_getter=getter,
additional_params=additional_params,
is_async=True,
)


def container_middleware():
async def add_request_container(request: Request, call_next):
async with request.app.state.container(
{Request: request}
) as subcontainer:
request.state.container = subcontainer
return await call_next(request)

return add_request_container


class Stub:
def __init__(self, dependency: Callable, **kwargs):
self._dependency = dependency
self._kwargs = kwargs

def __call__(self):
raise NotImplementedError

def __eq__(self, other) -> bool:
if isinstance(other, Stub):
return (
self._dependency == other._dependency
and self._kwargs == other._kwargs
)
else:
if not self._kwargs:
return self._dependency == other
return False

def __hash__(self):
if not self._kwargs:
return hash(self._dependency)
serial = (
self._dependency,
*self._kwargs.items(),
)
return hash(serial)


# app dependency logic

Host = NewType("Host", str)


class B:
def __init__(self, x: int):
pass


class C:
def __init__(self, x: int):
pass


class A:
def __init__(self, b: B, c: C):
pass


MyInt = NewType("MyInt", int)


class MyProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_a(self, b: B, c: C) -> A:
return A(b, c)

@provide(scope=Scope.REQUEST)
async def get_b(self) -> Iterable[B]:
yield B(1)

@provide(scope=Scope.REQUEST)
async def get_c(self) -> Iterable[C]:
yield C(1)


# app
router = APIRouter()


@router.get("/")
@inject
async def index(
*,
value: Annotated[A, Depends()],
value2: Annotated[A, Depends()],
) -> str:
return f"{value} {value is value2}"


@router.get("/f")
async def index(
*,
value: Annotated[A, FastapiDepends(Stub(A))],
value2: Annotated[A, FastapiDepends(Stub(A))],
) -> str:
return f"{value} {value is value2}"


def new_a(b: B = FastapiDepends(Stub(B)), c: C = FastapiDepends(Stub(C))):
return A(b, c)


@asynccontextmanager
async def lifespan(app: FastAPI):
async with make_async_container(MyProvider(), with_lock=True) as container:
app.state.container = container
yield


def create_app() -> FastAPI:
logging.basicConfig(level=logging.WARNING)

app = FastAPI(lifespan=lifespan)
app.middleware("http")(container_middleware())
app.dependency_overrides[A] = new_a
app.dependency_overrides[B] = lambda: B(1)
app.dependency_overrides[C] = lambda: C(1)
app.include_router(router)
return app


if __name__ == "__main__":
uvicorn.run(create_app(), host="0.0.0.0", port=8000)
File renamed without changes.
File renamed without changes.
113 changes: 35 additions & 78 deletions examples/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import logging
from abc import abstractmethod
from contextlib import asynccontextmanager
from inspect import Parameter
from typing import Annotated, Callable, Iterable, NewType, get_type_hints
from typing import (
Annotated, get_type_hints, Protocol, Any, get_origin,
get_args,
)

import uvicorn
from fastapi import APIRouter
from fastapi import Depends as FastapiDepends
from fastapi import FastAPI, Request

from dishka import Provider, Scope, make_async_container, provide
from dishka.inject import Depends, wrap_injection
from dishka import (
Depends, wrap_injection, Provider, Scope, make_async_container, provide,
)


# framework level
Expand Down Expand Up @@ -50,102 +54,58 @@ async def add_request_container(request: Request, call_next):
return add_request_container


class Stub:
def __init__(self, dependency: Callable, **kwargs):
self._dependency = dependency
self._kwargs = kwargs

def __call__(self):
# app core
class DbGateway(Protocol):
@abstractmethod
def get(self) -> str:
raise NotImplementedError

def __eq__(self, other) -> bool:
if isinstance(other, Stub):
return (
self._dependency == other._dependency
and self._kwargs == other._kwargs
)
else:
if not self._kwargs:
return self._dependency == other
return False

def __hash__(self):
if not self._kwargs:
return hash(self._dependency)
serial = (
self._dependency,
*self._kwargs.items(),
)
return hash(serial)


# app dependency logic

Host = NewType("Host", str)

class FakeDbGateway(DbGateway):
def get(self) -> str:
return "Hello"

class B:
def __init__(self, x: int):
pass

class Interactor:
def __init__(self, db: DbGateway):
self.db = db

class C:
def __init__(self, x: int):
pass
def __call__(self) -> str:
return self.db.get()


class A:
def __init__(self, b: B, c: C):
pass


MyInt = NewType("MyInt", int)


class MyProvider(Provider):
# app dependency logic
class AdaptersProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_a(self, b: B, c: C) -> A:
return A(b, c)
def get_db(self) -> DbGateway:
return FakeDbGateway()

@provide(scope=Scope.REQUEST)
async def get_b(self) -> Iterable[B]:
yield B(1)

@provide(scope=Scope.REQUEST)
async def get_c(self) -> Iterable[C]:
yield C(1)
class InteractorProvider(Provider):
i1 = provide(Interactor, scope=Scope.REQUEST)


# app
# presentation layer
router = APIRouter()


@router.get("/")
@inject
async def index(
*,
value: Annotated[A, Depends()],
value2: Annotated[A, Depends()],
interactor: Annotated[Interactor, Depends()],
) -> str:
return f"{value} {value is value2}"


@router.get("/f")
async def index(
*,
value: Annotated[A, FastapiDepends(Stub(A))],
value2: Annotated[A, FastapiDepends(Stub(A))],
) -> str:
return f"{value} {value is value2}"


def new_a(b: B = FastapiDepends(Stub(B)), c: C = FastapiDepends(Stub(C))):
return A(b, c)
result = interactor()
return result


# app configuration
@asynccontextmanager
async def lifespan(app: FastAPI):
async with make_async_container(MyProvider(), with_lock=True) as container:
async with make_async_container(
AdaptersProvider(), InteractorProvider(),
with_lock=True,
) as container:
app.state.container = container
yield

Expand All @@ -155,9 +115,6 @@ def create_app() -> FastAPI:

app = FastAPI(lifespan=lifespan)
app.middleware("http")(container_middleware())
app.dependency_overrides[A] = new_a
app.dependency_overrides[B] = lambda: B(1)
app.dependency_overrides[C] = lambda: C(1)
app.include_router(router)
return app

Expand Down
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
ruff
pytest
pytest-asyncio
pytest-repeat
2 changes: 1 addition & 1 deletion src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
):
self.registry = registry
self.child_registries = child_registries
self.context = {}
self.context = {type(self): self}
if context:
self.context.update(context)
self.parent_container = parent_container
Expand Down
2 changes: 1 addition & 1 deletion src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
):
self.registry = registry
self.child_registries = child_registries
self.context = {}
self.context = {type(self): self}
if context:
self.context.update(context)
self.parent_container = parent_container
Expand Down
Loading
Loading