Skip to content

Commit

Permalink
Merge pull request #13 from reagento/feature/decorator
Browse files Browse the repository at this point in the history
Feature/decorator
  • Loading branch information
Tishka17 authored Jan 26, 2024
2 parents 5994e0b + f38f16b commit d9c3822
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 45 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,20 @@ with make_container(provider) as container:
```python
class MyProvider(Provider):
p = alias(source=A, provides=AProtocol)
```
```
it works the same way as
```python
class MyProvider(Provider):
@provide(scope=<Scope of A>)
def p(self, a: A) -> AProtocol:
return a
```

* Want to apply decorator pattern and do not want to alter existing provide method? Use `decorate`. It will construct object using earlie defined provider and then pass it to your decorator before returning from the container.
```python
class MyProvider(Provider):
@decorate
def decorate_a(self, a: A) -> A:
return ADecorator(a)
```
Decorator function can also have additional parameters.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
dependencies = []

[project.urls]
"Homepage" = "https://github.com/tishka17/dishka"
"Bug Tracker" = "https://github.com/tishka17/dishka/issues"
"Homepage" = "https://github.com/reagento/dishka"
"Bug Tracker" = "https://github.com/reagento/dishka/issues"


7 changes: 2 additions & 5 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, List, Optional, Type, TypeVar

from .provider import DependencyProvider, Provider, ProviderType
from .registry import Registry, make_registry
from .registry import Registry, make_registries
from .scope import BaseScope, Scope

T = TypeVar("T")
Expand Down Expand Up @@ -145,10 +145,7 @@ def make_async_container(
context: Optional[dict] = None,
with_lock: bool = False,
) -> AsyncContextWrapper:
registries = [
make_registry(*providers, scope=scope)
for scope in scopes
]
registries = make_registries(*providers, scopes=scopes)
return AsyncContextWrapper(
AsyncContainer(*registries, context=context, with_lock=with_lock),
)
8 changes: 3 additions & 5 deletions src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, List, Optional, Type, TypeVar

from .provider import DependencyProvider, Provider, ProviderType
from .registry import Registry, make_registry
from .registry import Registry, make_registries
from .scope import BaseScope, Scope

T = TypeVar("T")
Expand Down Expand Up @@ -119,6 +119,7 @@ def close(self):


class ContextWrapper:
__slots__ = ("container",)
def __init__(self, container: Container):
self.container = container

Expand All @@ -135,10 +136,7 @@ def make_container(
context: Optional[dict] = None,
with_lock: bool = False,
) -> ContextWrapper:
registries = [
make_registry(*providers, scope=scope)
for scope in scopes
]
registries = make_registries(*providers, scopes=scopes)
return ContextWrapper(
Container(*registries, context=context, with_lock=with_lock),
)
75 changes: 59 additions & 16 deletions src/dishka/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import (
Any,
Callable,
List,
Optional,
Sequence,
Type,
Expand Down Expand Up @@ -71,16 +72,6 @@ def __get__(self, instance, owner):
is_to_bound=False,
)

def aliased(self, target: Type):
return DependencyProvider(
dependencies=[self.provides],
source=_identity,
provides=target,
scope=self.scope,
type=self.type,
is_to_bound=self.is_to_bound,
)


def make_dependency_provider(
provides: Any,
Expand Down Expand Up @@ -133,6 +124,19 @@ def __init__(self, source, provides):
self.source = source
self.provides = provides

def as_provider(self, scope: BaseScope) -> DependencyProvider:
return DependencyProvider(
scope=scope,
source=_identity,
provides=self.provides,
is_to_bound=False,
dependencies=[self.source],
type=ProviderType.FACTORY,
)

def __get__(self, instance, owner):
return self


def alias(
*,
Expand All @@ -145,6 +149,45 @@ def alias(
)


class Decorator:
__slots__ = ("provides", "provider")

def __init__(self, provider: DependencyProvider):
self.provider = provider
self.provides = provider.provides

def as_provider(
self, scope: BaseScope, new_dependency: Any,
) -> DependencyProvider:
return DependencyProvider(
scope=scope,
source=self.provider.source,
provides=self.provider.provides,
is_to_bound=self.provider.is_to_bound,
dependencies=[
new_dependency if dep is self.provides else dep
for dep in self.provider.dependencies
],
type=self.provider.type,
)

def __get__(self, instance, owner):
return Decorator(self.provider.__get__(instance, owner))


def decorate(
source: Union[None, Callable, Type] = None,
provides: Any = None,
):
if source is not None:
return Decorator(make_dependency_provider(provides, None, source))

def scoped(func):
return Decorator(make_dependency_provider(provides, None, func))

return scoped


def provide(
source: Union[None, Callable, Type] = None,
*,
Expand All @@ -160,12 +203,12 @@ def scoped(func):
return scoped


DependencyProviderVariant = Alias | DependencyProvider | Decorator


class Provider:
def __init__(self):
self.dependency_providers = {}
self.aliases = []
self.dependency_providers: List[DependencyProviderVariant] = []
for name, attr in vars(type(self)).items():
if isinstance(attr, DependencyProvider):
self.dependency_providers[attr.provides] = getattr(self, name)
elif isinstance(attr, Alias):
self.aliases.append(attr)
if isinstance(attr, DependencyProviderVariant):
self.dependency_providers.append(getattr(self, name))
50 changes: 36 additions & 14 deletions src/dishka/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any
from typing import Any, List, NewType, Type

from .provider import DependencyProvider, Provider
from .provider import Alias, Decorator, DependencyProvider, Provider
from .scope import BaseScope


Expand All @@ -14,22 +14,44 @@ def __init__(self, scope: BaseScope):
def add_provider(self, provider: DependencyProvider):
self._providers[provider.provides] = provider

def get_provider(self, dependency: Any):
def get_provider(self, dependency: Any) -> DependencyProvider:
return self._providers.get(dependency)


def make_registry(*providers: Provider, scope: BaseScope) -> Registry:
registry = Registry(scope)
def make_registries(
*providers: Provider, scopes: Type[BaseScope],
) -> List[Registry]:
dep_scopes = {}
for provider in providers:
for dependency_provider in provider.dependency_providers.values():
if dependency_provider.scope is scope:
registry.add_provider(dependency_provider)
for dep_provider in provider.dependency_providers:
if hasattr(dep_provider, "scope"):
dep_scopes[dep_provider.provides] = dep_provider.scope

registries = {scope: Registry(scope) for scope in scopes}

for provider in providers:
for alias in provider.aliases:
dependency_provider = registry.get_provider(alias.source)
if dependency_provider:
registry.add_provider(
dependency_provider.aliased(alias.provides),
for dep_provider in provider.dependency_providers:
if isinstance(dep_provider, DependencyProvider):
scope = dep_provider.scope
elif isinstance(dep_provider, Alias):
scope = dep_scopes[dep_provider.source]
dep_scopes[dep_provider.provides] = scope
dep_provider = dep_provider.as_provider(scope)
elif isinstance(dep_provider, Decorator):
scope = dep_scopes[dep_provider.provides]
registry = registries[scope]
undecorated_type = NewType(
f"Old_{dep_provider.provides.__name__}",
dep_provider.provides,
)
old_provider = registry.get_provider(dep_provider.provides)
old_provider.provides = undecorated_type
registry.add_provider(old_provider)
dep_provider = dep_provider.as_provider(
scope, undecorated_type,
)
return registry
else:
raise
registries[scope].add_provider(dep_provider)

return list(registries.values())
53 changes: 53 additions & 0 deletions tests/container/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from dishka import Provider, Scope, make_container, provide
from dishka.provider import alias, decorate


class A:
pass


class A1(A):
pass


class A2(A1):
pass


class ADecorator:
def __init__(self, a: A):
self.a = a


def test_simple():
class MyProvider(Provider):
a = provide(A, scope=Scope.APP)
ad = decorate(ADecorator, provides=A)

with make_container(MyProvider()) as container:
a = container.get(A)
assert isinstance(a, ADecorator)
assert isinstance(a.a, A)


def test_alias():
class MyProvider(Provider):
a2 = provide(A2, scope=Scope.APP)
a1 = alias(source=A2, provides=A1)
a = alias(source=A1, provides=A)

@decorate
def decorated(self, a: A1) -> A1:
return ADecorator(a)

with make_container(MyProvider()) as container:
a1 = container.get(A1)
assert isinstance(a1, ADecorator)
assert isinstance(a1.a, A2)

a2 = container.get(A2)
assert isinstance(a2, A2)
assert a2 is a1.a

a = container.get(A)
assert a is a1
3 changes: 1 addition & 2 deletions tests/test_privider.py → tests/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def foo(self, x: bool) -> str:
return f"{x}"

provider = MyProvider()
assert len(provider.dependency_providers) == 2
assert len(provider.aliases) == 1
assert len(provider.dependency_providers) == 3


@pytest.mark.parametrize(
Expand Down

0 comments on commit d9c3822

Please sign in to comment.