From b5ef1d726a41d8f019b4f3cdecd7358dafdc87cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 15 Nov 2021 12:06:12 -0500 Subject: [PATCH 1/3] =?UTF-8?q?=E2=9C=A8=20Add=20support=20for=20copying?= =?UTF-8?q?=20contextvars=20to=20and=20from=20threads=20in=20asyncio?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anyio/_backends/_asyncio.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 834dd911..2dee8deb 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1,6 +1,7 @@ import array import asyncio import concurrent.futures +import contextvars import math import socket import sys @@ -22,6 +23,8 @@ Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union, cast) from weakref import WeakKeyDictionary +from sniffio import current_async_library_cvar + from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable from .._core._eventloop import claim_worker_thread, threadlocals @@ -800,17 +803,21 @@ async def run_sync_in_worker_thread( expired_worker = idle_workers.popleft() expired_worker.root_task.remove_done_callback(expired_worker.stop) expired_worker.stop() - - worker.queue.put_nowait((func, args, future)) + context = contextvars.copy_context() + context.run(current_async_library_cvar.set, None) + contextvars_aware_func = partial(context.run, func) + worker.queue.put_nowait((contextvars_aware_func, args, future)) return await future def run_sync_from_thread(func: Callable[..., T_Retval], *args: object, loop: Optional[asyncio.AbstractEventLoop] = None) -> T_Retval: + context = contextvars.copy_context() + context.run(current_async_library_cvar.set, "asyncio") @wraps(func) def wrapper() -> None: try: - f.set_result(func(*args)) + f.set_result(context.run(func, *args)) except BaseException as exc: f.set_exception(exc) if not isinstance(exc, Exception): @@ -825,8 +832,9 @@ def wrapper() -> None: def run_async_from_thread( func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object ) -> T_Retval: + context = contextvars.copy_context() f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( - func(*args), threadlocals.loop) + context.run(func, *args), threadlocals.loop) return f.result() From 1762e1b90cec68a3772696b830559253a2241019 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 15 Nov 2021 12:20:17 -0500 Subject: [PATCH 2/3] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20copying=20c?= =?UTF-8?q?ontextvars=20in/from=20threads?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_to_thread.py | 308 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 308 insertions(+) diff --git a/tests/test_to_thread.py b/tests/test_to_thread.py index b17d230d..b4d5f372 100644 --- a/tests/test_to_thread.py +++ b/tests/test_to_thread.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import sys import threading import time @@ -6,6 +7,8 @@ from functools import partial from typing import Any, List, NoReturn, Optional +from sniffio import current_async_library_cvar, current_async_library + import pytest import anyio.to_thread @@ -235,3 +238,308 @@ async def taskfunc2() -> None: task1 = asyncio_event_loop.create_task(taskfunc1()) task2 = asyncio_event_loop.create_task(taskfunc2()) asyncio_event_loop.run_until_complete(asyncio.gather(task1, task2)) + + +anyio_test_contextvar = contextvars.ContextVar("anyio_test_contextvar") + + +async def test_to_thread_run_sync_contextvars(): + thread = threading.current_thread() + anyio_test_contextvar.set("main") + + def f(): + value = anyio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return (value, sniffio_cvar_value, threading.current_thread()) + + value, sniffio_cvar_value, child_thread = await to_thread.run_sync(f) + assert value == "main" + assert sniffio_cvar_value == None + assert child_thread != thread + + def g(): + parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("worker") + inner_value = anyio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return ( + parent_value, + inner_value, + sniffio_cvar_value, + threading.current_thread(), + ) + + ( + parent_value, + inner_value, + sniffio_cvar_value, + child_thread, + ) = await to_thread.run_sync(g) + current_value = anyio_test_contextvar.get() + assert parent_value == "main" + assert inner_value == "worker" + assert ( + current_value == "main" + ), "The contextvar value set on the worker would not propagate back to the main thread" + assert sniffio_cvar_value is None + + +async def test_from_thread_run_sync_contextvars(): + anyio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("worker") + thread_current_value = anyio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + def back_in_main(): + back_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("back_in_main") + back_current_value = anyio_test_contextvar.get() + sniffio_cvar_back_value = current_async_library_cvar.get() + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = from_thread.run_sync(back_in_main) + thread_after_value = anyio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) + + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread.run_sync(thread_fn) + current_value = anyio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_value == sniffio_cvar_back_value + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + + +async def test_from_thread_run_contextvars(): + anyio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("worker") + thread_current_value = anyio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + async def async_back_in_main(): + back_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("back_in_main") + back_current_value = anyio_test_contextvar.get() + sniffio_cvar_back_value = current_async_library() + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = from_thread.run(async_back_in_main) + thread_after_value = anyio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) + + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread.run_sync(thread_fn) + current_value = anyio_test_contextvar.get() + current_sniffio_cvar_value = current_async_library_cvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + assert sniffio_cvar_back_value == current_sniffio_cvar_value + + +def test_asyncio_to_thread_run_sync_contextvars(asyncio_event_loop: asyncio.AbstractEventLoop): + async def task(): + thread = threading.current_thread() + anyio_test_contextvar.set("main") + + def f(): + value = anyio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return (value, sniffio_cvar_value, threading.current_thread()) + + value, sniffio_cvar_value, child_thread = await to_thread.run_sync(f) + assert value == "main" + assert sniffio_cvar_value == None + assert child_thread != thread + + def g(): + parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("worker") + inner_value = anyio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return ( + parent_value, + inner_value, + sniffio_cvar_value, + threading.current_thread(), + ) + + ( + parent_value, + inner_value, + sniffio_cvar_value, + child_thread, + ) = await to_thread.run_sync(g) + current_value = anyio_test_contextvar.get() + assert parent_value == "main" + assert inner_value == "worker" + assert ( + current_value == "main" + ), "The contextvar value set on the worker would not propagate back to the main thread" + assert sniffio_cvar_value is None + asyncio_event_loop.run_until_complete(task()) + + +def test_asyncio_from_thread_run_sync_contextvars(asyncio_event_loop: asyncio.AbstractEventLoop): + async def task(): + anyio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("worker") + thread_current_value = anyio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + def back_in_main(): + back_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("back_in_main") + back_current_value = anyio_test_contextvar.get() + sniffio_cvar_back_value = current_async_library_cvar.get() + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = from_thread.run_sync(back_in_main) + thread_after_value = anyio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) + + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread.run_sync(thread_fn) + current_value = anyio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_value == sniffio_cvar_back_value + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + asyncio_event_loop.run_until_complete(task()) + + +def test_asyncio_from_thread_run_contextvars(asyncio_event_loop: asyncio.AbstractEventLoop): + async def task(): + anyio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("worker") + thread_current_value = anyio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + async def async_back_in_main(): + back_parent_value = anyio_test_contextvar.get() + anyio_test_contextvar.set("back_in_main") + back_current_value = anyio_test_contextvar.get() + # sniffio_cvar_back_value = current_async_library_cvar.get() + sniffio_cvar_back_value = current_async_library() + # raise RuntimeError(sniffio_cvar_back_value) + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = from_thread.run(async_back_in_main) + thread_after_value = anyio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) + + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread.run_sync(thread_fn) + current_value = anyio_test_contextvar.get() + current_sniffio_cvar_value = current_async_library_cvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + assert sniffio_cvar_back_value == current_sniffio_cvar_value + asyncio_event_loop.run_until_complete(task()) From 7364235a52844a1aa89159413b2fb23abc16f829 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 15 Nov 2021 12:29:25 -0500 Subject: [PATCH 3/3] =?UTF-8?q?=E2=9C=A8=20Add=20partial=20support=20for?= =?UTF-8?q?=20copying=20contextvars=20context=20to/from=20Trio?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anyio/_backends/_trio.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index aa4e0e28..322ea612 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -1,4 +1,6 @@ import array +import contextvars +import functools import math import socket from concurrent.futures import Future @@ -11,6 +13,8 @@ Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Deque, Dict, Generic, List, Mapping, NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) +from sniffio import current_async_library_cvar + import trio.from_thread from outcome import Error, Outcome, Value from trio.socket import SocketType as TrioSocketType @@ -163,14 +167,22 @@ async def start(self, func: Callable[..., Coroutine], async def run_sync_in_worker_thread( func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[trio.CapacityLimiter] = None) -> T_Retval: + context = contextvars.copy_context() + context.run(current_async_library_cvar.set, None) + def wrapper() -> T_Retval: with claim_worker_thread('trio'): - return func(*args) + return context.run(func, *args) return await run_sync(wrapper, cancellable=cancellable, limiter=limiter) run_async_from_thread = trio.from_thread.run -run_sync_from_thread = trio.from_thread.run_sync + + +def run_sync_from_thread(fn, *args, trio_token=None): + context = contextvars.copy_context() + context.run(current_async_library_cvar.set, "trio") + return trio.from_thread.run_sync(context.run, fn, *args, trio_token=trio_token) class BlockingPortal(abc.BlockingPortal):