diff --git a/asgiref/sync.py b/asgiref/sync.py index 3710a7f1..435aa814 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -7,7 +7,9 @@ import warnings import weakref from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, overload +from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, overload + +from typing_extensions import ParamSpec from .compatibility import current_task, get_running_loop from .current_thread_executor import CurrentThreadExecutor @@ -18,6 +20,9 @@ else: contextvars = None +P = ParamSpec("P") +R = TypeVar("R") + def _restore_context(context): # Check for changes in contextvars, and set them to the current @@ -118,7 +123,9 @@ class AsyncToSync: # Local, not a threadlocal, so that tasks can work out what their parent used. executors = Local() - def __init__(self, awaitable, force_new_loop=False): + def __init__( + self, awaitable: Callable[..., Awaitable[Any]], force_new_loop: bool = False + ): if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable): # Python does not have very reliable detection of async functions # (lots of false negatives) so this is just a warning. @@ -127,7 +134,7 @@ def __init__(self, awaitable, force_new_loop=False): ) self.awaitable = awaitable try: - self.__self__ = self.awaitable.__self__ + self.__self__ = self.awaitable.__self__ # type: ignore except AttributeError: pass if force_new_loop: @@ -507,8 +514,36 @@ def get_current_task(): return None -# Lowercase aliases (and decorator friendliness) -async_to_sync = AsyncToSync +# Lowercase aliases (and decorator/typing friendliness) +@overload +def async_to_sync( + func: None = None, + force_new_loop: bool = False, +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, R]]: + ... + + +@overload +def async_to_sync( + func: Callable[P, Awaitable[R]], + force_new_loop: bool = False, +) -> Callable[P, R]: + ... + + +def async_to_sync( + func: Optional[Callable[P, Awaitable[R]]] = None, + force_new_loop: bool = False, +) -> Union[Callable[P, R], Callable[[Callable[P, Awaitable[R]]], Callable[P, R]]]: + if func is None: + return lambda f: AsyncToSync( + f, + force_new_loop=force_new_loop, + ) + return AsyncToSync( + func, + force_new_loop=force_new_loop, + ) @overload @@ -516,24 +551,26 @@ def sync_to_async( func: None = None, thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> Callable[[Callable[..., Any]], SyncToAsync]: +) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]: ... @overload def sync_to_async( - func: Callable[..., Any], + func: Callable[P, R], thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> SyncToAsync: +) -> Callable[P, Awaitable[R]]: ... def sync_to_async( - func=None, - thread_sensitive=True, - executor=None, -): + func: Optional[Callable[P, R]] = None, + thread_sensitive: bool = True, + executor: Optional["ThreadPoolExecutor"] = None, +) -> Union[ + Callable[P, Awaitable[R]], Callable[[Callable[P, R]], Callable[P, Awaitable[R]]] +]: if func is None: return lambda f: SyncToAsync( f, diff --git a/setup.cfg b/setup.cfg index 01d0b50b..62bd85c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,7 +38,7 @@ zip_safe = false tests = pytest pytest-asyncio - mypy>=0.800 + mypy>=0.920 [tool:pytest] testpaths = tests