Skip to content

Commit

Permalink
Merge pull request #29 from taskiq-python/feature/graph-dep
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Sep 29, 2024
2 parents e05a6d4 + 02acbd8 commit 6caa86e
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ with graph.sync_ctx() as ctx:
The ParamInfo has the information about name and parameters signature. It's useful if you want to create a dependency that changes based on parameter name, or signature.


Also ParamInfo contains the initial graph that was used.

## Exception propagation

By default if error happens within the context, we send this error to the dependency,
Expand Down
13 changes: 10 additions & 3 deletions taskiq_dependencies/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class BaseResolveContext:
def __init__(
self,
graph: "DependencyGraph",
main_graph: "DependencyGraph",
initial_cache: Optional[Dict[Any, Any]] = None,
exception_propagation: bool = True,
) -> None:
self.graph = graph
# Main graph that contains all the subgraphs.
self.main_graph = main_graph
self.opened_dependencies: List[Any] = []
self.sub_contexts: "List[Any]" = []
self.initial_cache = initial_cache or {}
Expand Down Expand Up @@ -89,7 +92,11 @@ def traverse_deps( # noqa: C901
# If the user want to get ParamInfo,
# we get declaration of the current dependency.
if subdep.dependency == ParamInfo:
kwargs[subdep.param_name] = ParamInfo(dep.param_name, dep.signature)
kwargs[subdep.param_name] = ParamInfo(
dep.param_name,
self.main_graph,
dep.signature,
)
continue
if subdep.use_cache:
# If this dependency can be calculated, using cache,
Expand Down Expand Up @@ -197,7 +204,7 @@ def resolver(self, executed_func: Any, initial_cache: Dict[Any, Any]) -> Any:
:return: dict with resolved kwargs.
"""
if getattr(executed_func, "dep_graph", False):
ctx = SyncResolveContext(executed_func, initial_cache)
ctx = SyncResolveContext(executed_func, self.main_graph, initial_cache)
self.sub_contexts.append(ctx)
sub_result = ctx.resolve_kwargs()
elif inspect.isgenerator(executed_func):
Expand Down Expand Up @@ -325,7 +332,7 @@ async def resolver(
:return: dict with resolved kwargs.
"""
if getattr(executed_func, "dep_graph", False):
ctx = AsyncResolveContext(executed_func, initial_cache) # type: ignore
ctx = AsyncResolveContext(executed_func, self.main_graph, initial_cache) # type: ignore
self.sub_contexts.append(ctx)
sub_result = await ctx.resolve_kwargs()
elif inspect.isgenerator(executed_func):
Expand Down
9 changes: 9 additions & 0 deletions taskiq_dependencies/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from taskiq_dependencies.ctx import AsyncResolveContext, SyncResolveContext
from taskiq_dependencies.dependency import Dependency
from taskiq_dependencies.utils import ParamInfo

try:
from fastapi.params import Depends as FastapiDepends
Expand Down Expand Up @@ -63,6 +64,7 @@ def async_ctx(
if replaced_deps:
graph = DependencyGraph(self.target, replaced_deps)
return AsyncResolveContext(
graph,
graph,
initial_cache,
exception_propagation,
Expand All @@ -89,6 +91,7 @@ def sync_ctx(
if replaced_deps:
graph = DependencyGraph(self.target, replaced_deps)
return SyncResolveContext(
graph,
graph,
initial_cache,
exception_propagation,
Expand Down Expand Up @@ -122,8 +125,14 @@ def _build_graph(self) -> None: # noqa: C901
continue
if dep.dependency is None:
continue
# If we have replaced dependencies, we need to replace
# them in the current dependency.
if self.replaced_deps and dep.dependency in self.replaced_deps:
dep.dependency = self.replaced_deps[dep.dependency]
# We can say for sure that ParamInfo doesn't have any dependencies,
# so we skip it.
if dep.dependency == ParamInfo:
continue
# Get signature and type hints.
origin = getattr(dep.dependency, "__origin__", None)
if origin is None:
Expand Down
7 changes: 6 additions & 1 deletion taskiq_dependencies/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import inspect
import sys
from contextlib import _AsyncGeneratorContextManager, _GeneratorContextManager
from typing import Any, AsyncContextManager, ContextManager, Optional
from typing import TYPE_CHECKING, Any, AsyncContextManager, ContextManager, Optional

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

if TYPE_CHECKING:
from taskiq_dependencies.graph import DependencyGraph


class ParamInfo:
"""
Expand All @@ -23,9 +26,11 @@ class ParamInfo:
def __init__(
self,
name: str,
graph: "DependencyGraph",
signature: Optional[inspect.Parameter] = None,
) -> None:
self.name = name
self.graph = graph
self.definition = signature

def __repr__(self) -> str:
Expand Down
32 changes: 30 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,13 +334,15 @@ def dep(info: ParamInfo = Depends()) -> ParamInfo:
def target(my_test_param: ParamInfo = Depends(dep)) -> None:
return None

with DependencyGraph(target=target).sync_ctx() as g:
graph = DependencyGraph(target=target)
with graph.sync_ctx() as g:
kwargs = g.resolve_kwargs()

info: ParamInfo = kwargs["my_test_param"]
assert info.name == "my_test_param"
assert info.definition
assert info.definition.annotation == ParamInfo
assert info.graph == graph


def test_param_info_no_dependant() -> None:
Expand All @@ -349,12 +351,14 @@ def test_param_info_no_dependant() -> None:
def target(info: ParamInfo = Depends()) -> None:
return None

with DependencyGraph(target=target).sync_ctx() as g:
graph = DependencyGraph(target=target)
with graph.sync_ctx() as g:
kwargs = g.resolve_kwargs()

info: ParamInfo = kwargs["info"]
assert info.name == ""
assert info.definition is None
assert info.graph == graph


def test_class_based_dependencies() -> None:
Expand Down Expand Up @@ -863,3 +867,27 @@ def target(acm: TestACM = Depends(get_test_acm)) -> None:
kwargs = await ctx.resolve_kwargs()
assert kwargs["acm"] == test_acm
assert not test_acm.opened


def test_param_info_subgraph() -> None:
"""
Test subgraphs for ParamInfo.
Test that correct graph is stored in ParamInfo
even if evaluated from subgraphs.
"""

def inner_dep(info: ParamInfo = Depends()) -> ParamInfo:
return info

def target(info: ParamInfo = Depends(inner_dep, use_cache=False)) -> None:
return None

graph = DependencyGraph(target=target)
with graph.sync_ctx() as g:
kwargs = g.resolve_kwargs()

info: ParamInfo = kwargs["info"]
assert info.name == ""
assert info.definition is None
assert info.graph == graph

0 comments on commit 6caa86e

Please sign in to comment.