Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for copying contextvars to thread workers #389

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import array
import asyncio
import concurrent.futures
import contextvars
import math
import socket
import sys
Expand All @@ -22,6 +23,8 @@
Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, cast)
from weakref import WeakKeyDictionary

from sniffio import current_async_library_cvar

from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable
from .._core._eventloop import claim_worker_thread, threadlocals
Expand Down Expand Up @@ -800,17 +803,21 @@ async def run_sync_in_worker_thread(
expired_worker = idle_workers.popleft()
expired_worker.root_task.remove_done_callback(expired_worker.stop)
expired_worker.stop()

worker.queue.put_nowait((func, args, future))
context = contextvars.copy_context()
context.run(current_async_library_cvar.set, None)
contextvars_aware_func = partial(context.run, func)
worker.queue.put_nowait((contextvars_aware_func, args, future))
return await future


def run_sync_from_thread(func: Callable[..., T_Retval], *args: object,
loop: Optional[asyncio.AbstractEventLoop] = None) -> T_Retval:
context = contextvars.copy_context()
context.run(current_async_library_cvar.set, "asyncio")
@wraps(func)
def wrapper() -> None:
try:
f.set_result(func(*args))
f.set_result(context.run(func, *args))
except BaseException as exc:
f.set_exception(exc)
if not isinstance(exc, Exception):
Expand All @@ -825,8 +832,9 @@ def wrapper() -> None:
def run_async_from_thread(
func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object
) -> T_Retval:
context = contextvars.copy_context()
f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
func(*args), threadlocals.loop)
context.run(func, *args), threadlocals.loop)
return f.result()


Expand Down
16 changes: 14 additions & 2 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import array
import contextvars
import functools
import math
import socket
from concurrent.futures import Future
Expand All @@ -11,6 +13,8 @@
Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Deque, Dict, Generic, List,
Mapping, NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)

from sniffio import current_async_library_cvar

import trio.from_thread
from outcome import Error, Outcome, Value
from trio.socket import SocketType as TrioSocketType
Expand Down Expand Up @@ -163,14 +167,22 @@ async def start(self, func: Callable[..., Coroutine],
async def run_sync_in_worker_thread(
func: Callable[..., T_Retval], *args: object, cancellable: bool = False,
limiter: Optional[trio.CapacityLimiter] = None) -> T_Retval:
context = contextvars.copy_context()
context.run(current_async_library_cvar.set, None)

def wrapper() -> T_Retval:
with claim_worker_thread('trio'):
return func(*args)
return context.run(func, *args)

return await run_sync(wrapper, cancellable=cancellable, limiter=limiter)

run_async_from_thread = trio.from_thread.run
run_sync_from_thread = trio.from_thread.run_sync


def run_sync_from_thread(fn, *args, trio_token=None):
context = contextvars.copy_context()
context.run(current_async_library_cvar.set, "trio")
return trio.from_thread.run_sync(context.run, fn, *args, trio_token=trio_token)


class BlockingPortal(abc.BlockingPortal):
Expand Down
Loading