Skip to content

Commit

Permalink
Merge pull request #201 from IvanKirpichnikov/with-parents
Browse files Browse the repository at this point in the history
reworked WithParents
  • Loading branch information
Tishka17 authored Aug 8, 2024
2 parents dd09825 + d1fe88d commit 0c62051
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 111 deletions.
194 changes: 94 additions & 100 deletions src/dishka/entities/with_parents.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,142 +32,136 @@
)
TypeVarsMap: TypeAlias = dict[TypeHint, TypeHint]


def has_orig_bases(obj: TypeHint) -> bool:
return hasattr(obj, "__orig_bases__")

if HAS_PY_311:
def is_type_var_tuple(obj: TypeHint) -> bool:
return getattr(obj, "__typing_is_unpacked_typevartuple__", False)
else:
def is_type_var_tuple(obj: TypeHint) -> bool:
return False

def is_ignored_type(origin_obj: TypeHint) -> bool:
return origin_obj in IGNORE_TYPES


def get_filled_arguments(obj: TypeHint) -> list[TypeHint]:
filled_arguments = []
for arg in get_generic_args(obj):
if isinstance(arg, TypeVar):
continue
if is_type_var_tuple(arg):
continue
filled_arguments.append(arg)
return filled_arguments
def has_orig_bases(obj: TypeHint) -> bool:
return hasattr(obj, "__orig_bases__")


def create_type_vars_map(obj: TypeHint) -> TypeVarsMap:
origin_obj = strip_alias(obj)
if not get_type_vars(origin_obj):
return {}
def is_ignored_type(origin_type: TypeHint) -> bool:
return origin_type in IGNORE_TYPES

type_vars = list(get_type_vars(origin_obj))
filled_arguments = get_filled_arguments(obj)

if not filled_arguments or not type_vars:
def create_type_vars_map(obj):
origin_obj = strip_alias(obj)
type_vars = list(get_type_vars(origin_obj) or get_type_vars(obj))
if not type_vars:
return {}

type_vars_map = {}
arguments = list(get_generic_args(obj))
reversed_arguments = False
while True:
if len(type_vars) == 0:
break

type_var = type_vars[0]
if isinstance(type_var, TypeVar):
del type_vars[0]
type_vars_map[type_var] = filled_arguments.pop(0)
type_vars_map[type_var] = arguments.pop(0)
else:
if len(type_vars) == 1:
if reversed_arguments:
filled_arguments.reverse()
type_vars_map[type_var] = filled_arguments
arguments.reverse()
type_vars_map[type_var] = arguments
break
type_vars.reverse()
filled_arguments.reverse()
arguments.reverse()
reversed_arguments = not reversed_arguments

return type_vars_map


def create_type(
obj: TypeHint,
type_vars_map: TypeVarsMap,
) -> TypeHint:
origin_obj = strip_alias(obj)
type_vars = get_type_vars(origin_obj) or get_type_vars(obj)
if not type_vars:
return origin_obj

generic_args = []
for type_var in type_vars:
arg = type_vars_map[type_var]
if isinstance(arg, list):
generic_args.extend(arg)
else:
generic_args.append(arg)
return origin_obj[tuple(generic_args)]

class ParentsResolver:
def get_parents(self, child_type: TypeHint) -> list[TypeHint]:
if is_ignored_type(strip_alias(child_type)):
raise ValueError(
f"The starting class {child_type!r} is in ignored types",
)
if is_parametrized(child_type) or has_orig_bases(child_type):
return self._get_parents_for_generic(child_type)
return self._get_parents_for_mro(child_type)

def recursion_get_parents_for_generic_class(
obj: TypeHint,
parents: list[TypeHint],
type_vars_map: TypeVarsMap,
) -> None:
origin_obj = strip_alias(obj)
if not has_orig_bases(origin_obj):
parents.extend(get_parents_for_mro(origin_obj))
return

for obj_ in origin_obj.__orig_bases__:
origin_obj = strip_alias(obj_)
if is_ignored_type(origin_obj):
continue

type_vars_map.update(create_type_vars_map(obj_))
parents.append(create_type(obj_, type_vars_map))
recursion_get_parents_for_generic_class(
obj_,
parents,
type_vars_map.copy(),
)


def get_parents_for_mro(obj: TypeHint) -> list[TypeHint]:
return [
obj_ for obj_ in obj.mro()
if not is_ignored_type(strip_alias(obj_))
]


def get_parents(obj: TypeHint) -> list[TypeHint]:
if is_ignored_type(strip_alias(obj)):
raise ValueError(f"The starting class {obj!r} is in ignored types")

if is_parametrized(obj):
type_vars_map = create_type_vars_map(obj)
parents = [
create_type(
obj=obj,
type_vars_map=type_vars_map,
),
]
recursion_get_parents_for_generic_class(
obj=obj,
parents=parents,
type_vars_map=type_vars_map,
)
elif has_orig_bases(obj):
parents = [obj]
recursion_get_parents_for_generic_class(
obj=obj,
def _get_parents_for_generic(
self, child_type: TypeHint,
) -> list[TypeHint]:
parents = []
self._recursion_get_parents(
child_type=child_type,
parents=parents,
type_vars_map={},
)
else:
parents = get_parents_for_mro(obj)
return parents
return parents

def _recursion_get_parents(
self,
child_type: TypeHint,
parents: list[TypeHint],
type_vars_map: TypeVarsMap,
) -> None:
origin_child_type = strip_alias(child_type)
parametrized = is_parametrized(child_type)
orig_bases = has_orig_bases(origin_child_type)
if not orig_bases and not parametrized:
parents.extend(
self._get_parents_for_mro(origin_child_type),
)
return

new_type_vars_map = create_type_vars_map(child_type)
new_type_vars_map.update(type_vars_map)
parents.append(
self._create_type(
obj=child_type,
type_vars_map=new_type_vars_map,
),
)
if not orig_bases:
return
for parent_type in origin_child_type.__orig_bases__:
origin_parent_type = strip_alias(parent_type)
if is_ignored_type(origin_parent_type):
continue

self._recursion_get_parents(
child_type=parent_type,
parents=parents,
type_vars_map=new_type_vars_map,
)

def _get_parents_for_mro(
self, child_type: TypeHint,
) -> list[TypeHint]:
return [
parent_type for parent_type in child_type.mro()
if not is_ignored_type(strip_alias(parent_type))
]

def _create_type(
self,
obj: TypeHint,
type_vars_map: TypeVarsMap,
) -> TypeHint:
origin_obj = strip_alias(obj)
type_vars = get_type_vars(origin_obj) or get_type_vars(obj)
if not type_vars:
return obj

generic_args = []
for type_var in type_vars:
arg = type_vars_map[type_var]
if isinstance(arg, list):
generic_args.extend(arg)
else:
generic_args.append(arg)
return origin_obj[tuple(generic_args)]


if TYPE_CHECKING:
Expand All @@ -177,7 +171,7 @@ class WithParents:
def __class_getitem__(
cls, item: TypeHint,
) -> TypeHint | ProvideMultiple:
parents = get_parents(item)
parents = ParentsResolver().get_parents(item)
if len(parents) > 1:
return ProvideMultiple(parents)
return parents[0]
44 changes: 33 additions & 11 deletions tests/unit/container/test_with_parents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from abc import ABC
from typing import Annotated, Any, Generic, Protocol, TypeVar
from collections.abc import Sequence
from typing import (
Annotated,
Any,
Generic,
Protocol,
TypeVar,
)

import pytest

Expand All @@ -10,11 +17,11 @@
Scope,
make_container,
)
from dishka._adaptix.common import TypeHint
from dishka._adaptix.feature_requirement import HAS_PY_311
from dishka.entities.with_parents import (
ParentsResolver,
WithParents,
get_filled_arguments,
get_parents,
is_type_var_tuple,
)
from dishka.exceptions import NoFactoryError
Expand Down Expand Up @@ -188,15 +195,30 @@ class A2(A1[T], Generic[T]): ...
is container.get(float)
)


def test_using_ignoring_type() -> None:
with pytest.raises(ValueError): # noqa: PT011
get_parents(object)
ParentsResolver().get_parents(object)

@pytest.mark.skipif(
not HAS_PY_311,
reason="test for python >= 3.11",
)
def test_ignore_get_filled_arguments() -> None:
class Test(Generic[T, Unpack[Ts]]): ...

assert not get_filled_arguments(Test[T, Unpack[Ts]])
def test_ignoring_parent() -> None:
class A(Generic[T]): ...
assert ParentsResolver().get_parents(A[int]) == [A[int]]


class TupleGeneric(tuple[T], Generic[T]): ... # noqa: SLOT001
class SequenceInt(Sequence[int]): ...
class ListAny(list[Any]): ...
class JsonMapping(dict[str, str | int]): ...

@pytest.mark.parametrize(
("structure", "result"),
[
(TupleGeneric[str], [TupleGeneric[str], tuple[str]]),
(SequenceInt, [SequenceInt, Sequence[int]]),
(ListAny, [ListAny, list[Any]]),
(JsonMapping, [JsonMapping, dict[str, str | int]]),
],
)
def test_structures(structure: TypeHint, result: list[TypeHint]) -> None:
assert ParentsResolver().get_parents(structure) == result

0 comments on commit 0c62051

Please sign in to comment.