diff --git a/examples/embedding/inprocess_terminal.py b/examples/embedding/inprocess_terminal.py index 79e11e03a..c951859e8 100644 --- a/examples/embedding/inprocess_terminal.py +++ b/examples/embedding/inprocess_terminal.py @@ -1,8 +1,7 @@ """An in-process terminal example.""" import os -import sys -import tornado +from anyio import run from jupyter_console.ptshell import ZMQTerminalInteractiveShell from ipykernel.inprocess.manager import InProcessKernelManager @@ -13,46 +12,15 @@ def print_process_id(): print("Process ID is:", os.getpid()) -def init_asyncio_patch(): - """set default asyncio policy to be compatible with tornado - Tornado 6 (at least) is not compatible with the default - asyncio implementation on Windows - Pick the older SelectorEventLoopPolicy on Windows - if the known-incompatible default policy is in use. - do this as early as possible to make it a low priority and overrideable - ref: https://github.com/tornadoweb/tornado/issues/2608 - FIXME: if/when tornado supports the defaults in asyncio, - remove and bump tornado requirement for py38 - """ - if ( - sys.platform.startswith("win") - and sys.version_info >= (3, 8) - and tornado.version_info < (6, 1) - ): - import asyncio - - try: - from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy - except ImportError: - pass - # not affected - else: - if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy: - # WindowsProactorEventLoopPolicy is not compatible with tornado 6 - # fallback to the pre-3.8 default of Selector - asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy()) - - -def main(): +async def main(): """The main function.""" print_process_id() # Create an in-process kernel # >>> print_process_id() # will print the same process ID as the main process - init_asyncio_patch() kernel_manager = InProcessKernelManager() - kernel_manager.start_kernel() + await kernel_manager.start_kernel() kernel = kernel_manager.kernel kernel.gui = "qt4" kernel.shell.push({"foo": 43, "print_process_id": print_process_id}) @@ -64,4 +32,4 @@ def main(): if __name__ == "__main__": - main() + run(main) diff --git a/ipykernel/inprocess/blocking.py b/ipykernel/inprocess/blocking.py index f09bb2316..3ec0c439f 100644 --- a/ipykernel/inprocess/blocking.py +++ b/ipykernel/inprocess/blocking.py @@ -80,10 +80,10 @@ class BlockingInProcessKernelClient(InProcessKernelClient): iopub_channel_class = Type(BlockingInProcessChannel) stdin_channel_class = Type(BlockingInProcessStdInChannel) - def wait_for_ready(self): + async def wait_for_ready(self): """Wait for kernel info reply on shell channel.""" while True: - self.kernel_info() + await self.kernel_info() try: msg = self.shell_channel.get_msg(block=True, timeout=1) except Empty: @@ -103,6 +103,5 @@ def wait_for_ready(self): while True: try: msg = self.iopub_channel.get_msg(block=True, timeout=0.2) - print(msg["msg_type"]) except Empty: break diff --git a/ipykernel/inprocess/client.py b/ipykernel/inprocess/client.py index ea964ecde..54e7d45aa 100644 --- a/ipykernel/inprocess/client.py +++ b/ipykernel/inprocess/client.py @@ -11,11 +11,9 @@ # Imports # ----------------------------------------------------------------------------- -import asyncio from jupyter_client.client import KernelClient from jupyter_client.clientabc import KernelClientABC -from jupyter_core.utils import run_sync # IPython imports from traitlets import Instance, Type, default @@ -101,7 +99,7 @@ def hb_channel(self): # Methods for sending specific messages # ------------------------------------- - def execute( + async def execute( self, code, silent=False, store_history=True, user_expressions=None, allow_stdin=None ): """Execute code on the client.""" @@ -115,19 +113,19 @@ def execute( allow_stdin=allow_stdin, ) msg = self.session.msg("execute_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def complete(self, code, cursor_pos=None): + async def complete(self, code, cursor_pos=None): """Get code completion.""" if cursor_pos is None: cursor_pos = len(code) content = dict(code=code, cursor_pos=cursor_pos) msg = self.session.msg("complete_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def inspect(self, code, cursor_pos=None, detail_level=0): + async def inspect(self, code, cursor_pos=None, detail_level=0): """Get code inspection.""" if cursor_pos is None: cursor_pos = len(code) @@ -137,14 +135,14 @@ def inspect(self, code, cursor_pos=None, detail_level=0): detail_level=detail_level, ) msg = self.session.msg("inspect_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def history(self, raw=True, output=False, hist_access_type="range", **kwds): + async def history(self, raw=True, output=False, hist_access_type="range", **kwds): """Get code history.""" content = dict(raw=raw, output=output, hist_access_type=hist_access_type, **kwds) msg = self.session.msg("history_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] def shutdown(self, restart=False): @@ -153,17 +151,17 @@ def shutdown(self, restart=False): msg = "Cannot shutdown in-process kernel" raise NotImplementedError(msg) - def kernel_info(self): + async def kernel_info(self): """Request kernel info.""" msg = self.session.msg("kernel_info_request") - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def comm_info(self, target_name=None): + async def comm_info(self, target_name=None): """Request a dictionary of valid comms and their targets.""" content = {} if target_name is None else dict(target_name=target_name) msg = self.session.msg("comm_info_request", content) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] def input(self, string): @@ -173,29 +171,21 @@ def input(self, string): raise RuntimeError(msg) self.kernel.raw_input_str = string - def is_complete(self, code): + async def is_complete(self, code): """Handle an is_complete request.""" msg = self.session.msg("is_complete_request", {"code": code}) - self._dispatch_to_kernel(msg) + await self._dispatch_to_kernel(msg) return msg["header"]["msg_id"] - def _dispatch_to_kernel(self, msg): + async def _dispatch_to_kernel(self, msg): """Send a message to the kernel and handle a reply.""" kernel = self.kernel if kernel is None: - msg = "Cannot send request. No kernel exists." - raise RuntimeError(msg) + error_message = "Cannot send request. No kernel exists." + raise RuntimeError(error_message) - stream = kernel.shell_stream - self.session.send(stream, msg) - msg_parts = stream.recv_multipart() - if run_sync is not None: - dispatch_shell = run_sync(kernel.dispatch_shell) - dispatch_shell(msg_parts) - else: - loop = asyncio.get_event_loop() - loop.run_until_complete(kernel.dispatch_shell(msg_parts)) - idents, reply_msg = self.session.recv(stream, copy=False) + kernel.shell_socket.put(msg) + reply_msg = await kernel.shell_socket.get() self.shell_channel.call_handlers_later(reply_msg) def get_shell_msg(self, block=True, timeout=None): diff --git a/ipykernel/inprocess/ipkernel.py b/ipykernel/inprocess/ipkernel.py index df34303b4..b31091c2a 100644 --- a/ipykernel/inprocess/ipkernel.py +++ b/ipykernel/inprocess/ipkernel.py @@ -7,6 +7,8 @@ import sys from contextlib import contextmanager +from anyio import TASK_STATUS_IGNORED +from anyio.abc import TaskStatus from IPython.core.interactiveshell import InteractiveShellABC from traitlets import Any, Enum, Instance, List, Type, default @@ -48,10 +50,10 @@ class InProcessKernel(IPythonKernel): # ------------------------------------------------------------------------- shell_class = Type(allow_none=True) - _underlying_iopub_socket = Instance(DummySocket, ()) + _underlying_iopub_socket = Instance(DummySocket, (False,)) iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment] - shell_stream = Instance(DummySocket, ()) + shell_socket = Instance(DummySocket, (True,)) @default("iopub_thread") def _default_iopub_thread(self): @@ -65,13 +67,13 @@ def _default_iopub_thread(self): def _default_iopub_socket(self): return self.iopub_thread.background_socket - stdin_socket = Instance(DummySocket, ()) # type:ignore[assignment] + stdin_socket = Instance(DummySocket, (False,)) # type:ignore[assignment] def __init__(self, **traits): """Initialize the kernel.""" super().__init__(**traits) - self._underlying_iopub_socket.observe(self._io_dispatch, names=["message_sent"]) + self._io_dispatch() self.shell.kernel = self async def execute_request(self, stream, ident, parent): @@ -79,9 +81,13 @@ async def execute_request(self, stream, ident, parent): with self._redirected_io(): await super().execute_request(stream, ident, parent) - def start(self): + async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: """Override registration of dispatchers for streams.""" self.shell.exit_now = False + await super().start(task_status=task_status) + + def stop(self): + super().stop() def _abort_queues(self): """The in-process kernel doesn't abort requests.""" @@ -127,12 +133,15 @@ def _redirected_io(self): # ------ Trait change handlers -------------------------------------------- - def _io_dispatch(self, change): + def _io_dispatch(self): """Called when a message is sent to the IO socket.""" assert self.iopub_socket.io_thread is not None - ident, msg = self.session.recv(self.iopub_socket.io_thread.socket, copy=False) - for frontend in self.frontends: - frontend.iopub_channel.call_handlers(msg) + + def callback(msg): + for frontend in self.frontends: + frontend.iopub_channel.call_handlers(msg) + + self.iopub_thread.socket.on_recv = callback # ------ Trait initializers ----------------------------------------------- @@ -142,7 +151,7 @@ def _default_log(self): @default("session") def _default_session(self): - from jupyter_client.session import Session + from .session import Session return Session(parent=self, key=INPROCESS_KEY) diff --git a/ipykernel/inprocess/manager.py b/ipykernel/inprocess/manager.py index 3388dbf61..2ae4b13aa 100644 --- a/ipykernel/inprocess/manager.py +++ b/ipykernel/inprocess/manager.py @@ -3,12 +3,14 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from anyio import TASK_STATUS_IGNORED +from anyio.abc import TaskStatus from jupyter_client.manager import KernelManager from jupyter_client.managerabc import KernelManagerABC -from jupyter_client.session import Session from traitlets import DottedObjectName, Instance, default from .constants import INPROCESS_KEY +from .session import Session class InProcessKernelManager(KernelManager): @@ -41,27 +43,31 @@ def _default_session(self): # Kernel management methods # -------------------------------------------------------------------------- - def start_kernel(self, **kwds): + async def start_kernel(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED, **kwds) -> None: """Start the kernel.""" from ipykernel.inprocess.ipkernel import InProcessKernel self.kernel = InProcessKernel(parent=self, session=self.session) + await self.kernel.start(task_status=task_status) def shutdown_kernel(self): """Shutdown the kernel.""" self.kernel.iopub_thread.stop() self._kill_kernel() - def restart_kernel(self, now=False, **kwds): + async def restart_kernel( + self, now=False, *, task_status: TaskStatus = TASK_STATUS_IGNORED, **kwds + ) -> None: """Restart the kernel.""" self.shutdown_kernel() - self.start_kernel(**kwds) + await self.start_kernel(task_status=task_status, **kwds) @property def has_kernel(self): return self.kernel is not None def _kill_kernel(self): + self.kernel.stop() self.kernel = None def interrupt_kernel(self): diff --git a/ipykernel/inprocess/session.py b/ipykernel/inprocess/session.py new file mode 100644 index 000000000..0eaed2c60 --- /dev/null +++ b/ipykernel/inprocess/session.py @@ -0,0 +1,41 @@ +from jupyter_client.session import Session as _Session + + +class Session(_Session): + async def recv(self, socket, copy=True): + return await socket.recv_multipart() + + def send( + self, + socket, + msg_or_type, + content=None, + parent=None, + ident=None, + buffers=None, + track=False, + header=None, + metadata=None, + ): + if isinstance(msg_or_type, str): + msg = self.msg( + msg_or_type, + content=content, + parent=parent, + header=header, + metadata=metadata, + ) + else: + # We got a Message or message dict, not a msg_type so don't + # build a new Message. + msg = msg_or_type + buffers = buffers or msg.get("buffers", []) + + socket.send_multipart(msg) + return msg + + def feed_identities(self, msg, copy=True): + return "", msg + + def deserialize(self, msg, content=True, copy=True): + return msg diff --git a/ipykernel/inprocess/socket.py b/ipykernel/inprocess/socket.py index 7e48789e9..3e79297c0 100644 --- a/ipykernel/inprocess/socket.py +++ b/ipykernel/inprocess/socket.py @@ -3,10 +3,11 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. -from queue import Queue +from math import inf import zmq -from traitlets import HasTraits, Instance, Int +from anyio import create_memory_object_stream +from traitlets import HasTraits, Instance # ----------------------------------------------------------------------------- # Dummy socket class @@ -14,29 +15,53 @@ class DummySocket(HasTraits): - """A dummy socket implementing (part of) the zmq.Socket interface.""" + """A dummy socket implementing (part of) the zmq.asyncio.Socket interface.""" - queue = Instance(Queue, ()) - message_sent = Int(0) # Should be an Event - context = Instance(zmq.Context) + context = Instance(zmq.asyncio.Context) def _context_default(self): - return zmq.Context() + return zmq.asyncio.Context() # ------------------------------------------------------------------------- # Socket interface # ------------------------------------------------------------------------- - def recv_multipart(self, flags=0, copy=True, track=False): + def __init__(self, is_shell, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_shell = is_shell + self.on_recv = None + if is_shell: + self.in_send_stream, self.in_receive_stream = create_memory_object_stream( + max_buffer_size=inf + ) + self.out_send_stream, self.out_receive_stream = create_memory_object_stream( + max_buffer_size=inf + ) + + def put(self, msg): + self.in_send_stream.send_nowait(msg) + + async def get(self): + msg = await self.out_receive_stream.receive() + return msg + + async def recv_multipart(self, flags=0, copy=True, track=False): """Recv a multipart message.""" - return self.queue.get_nowait() + msg = await self.in_receive_stream.receive() + return msg def send_multipart(self, msg_parts, flags=0, copy=True, track=False): """Send a multipart message.""" - msg_parts = list(map(zmq.Message, msg_parts)) - self.queue.put_nowait(msg_parts) - self.message_sent += 1 + if self.is_shell: + self.out_send_stream.send_nowait(msg_parts) + if self.on_recv is not None: + self.on_recv(msg_parts) def flush(self, timeout=1.0): """no-op to comply with stream API""" pass + + async def poll(self, timeout=0): + assert timeout == 0 + statistics = self.in_receive_stream.statistics() + return statistics.current_buffer_used != 0 diff --git a/ipykernel/inprocess/tests/test_kernel.py b/ipykernel/inprocess/tests/test_kernel.py index 0c806e0cc..0d305ffeb 100644 --- a/ipykernel/inprocess/tests/test_kernel.py +++ b/ipykernel/inprocess/tests/test_kernel.py @@ -6,6 +6,7 @@ from io import StringIO import pytest +from anyio import create_task_group from IPython.utils.io import capture_output # type:ignore[attr-defined] from jupyter_client.session import Session @@ -37,57 +38,58 @@ def patch_cell_id(): Session.msg = orig_msg # type:ignore -@pytest.fixture() -def kc(): - km = InProcessKernelManager() - km.start_kernel() - kc = km.client() - kc.start_channels() - kc.wait_for_ready() - yield kc +@pytest.fixture +def anyio_backend(): + return "asyncio" -@pytest.mark.skip("FIXME") -def test_with_cell_id(kc): +@pytest.fixture() +async def kc(anyio_backend): + async with create_task_group() as tg: + km = InProcessKernelManager() + await tg.start(km.start_kernel) + kc = km.client() + kc.start_channels() + await kc.wait_for_ready() + yield kc + km.shutdown_kernel() + + +async def test_with_cell_id(kc): with patch_cell_id(): - kc.execute("1+1") + await kc.execute("1+1") -@pytest.mark.skip("FIXME") -def test_pylab(kc): +async def test_pylab(kc): """Does %pylab work in the in-process kernel?""" _ = pytest.importorskip("matplotlib", reason="This test requires matplotlib") - kc.execute("%pylab") + await kc.execute("%pylab") out, err = assemble_output(kc.get_iopub_msg) assert "matplotlib" in out -@pytest.mark.skip("FIXME") -def test_raw_input(kc): +async def test_raw_input(kc): """Does the in-process kernel handle raw_input correctly?""" io = StringIO("foobar\n") sys_stdin = sys.stdin sys.stdin = io try: - kc.execute("x = input()") + await kc.execute("x = input()") finally: sys.stdin = sys_stdin assert kc.kernel.shell.user_ns.get("x") == "foobar" @pytest.mark.skipif("__pypy__" in sys.builtin_module_names, reason="fails on pypy") -@pytest.mark.skip("FIXME") -def test_stdout(kc): +async def test_stdout(kc): """Does the in-process kernel correctly capture IO?""" - kernel = InProcessKernel() + kernel = kc.kernel with capture_output() as io: kernel.shell.run_cell('print("foo")') assert io.stdout == "foo\n" - kc = BlockingInProcessKernelClient(kernel=kernel, session=kernel.session) - kernel.frontends.append(kc) - kc.execute('print("bar")') + await kc.execute('print("bar")') out, err = assemble_output(kc.get_iopub_msg) assert out == "bar\n" @@ -105,11 +107,10 @@ def test_capfd(kc): kernel.frontends.append(kc) kc.execute("import os") kc.execute('os.system("echo capfd")') - out, err = assemble_output(kc.iopub_channel) + out, err = assemble_output(kc.get_iopub_msg) assert out == "capfd\n" -@pytest.mark.skip("FIXME") def test_getpass_stream(kc): """Tests that kernel getpass accept the stream parameter""" kernel = InProcessKernel() @@ -119,7 +120,6 @@ def test_getpass_stream(kc): kernel.getpass(stream="non empty") -@pytest.mark.skip("FIXME") async def test_do_execute(kc): kernel = InProcessKernel() await kernel.do_execute("a=1", True) diff --git a/ipykernel/inprocess/tests/test_kernelmanager.py b/ipykernel/inprocess/tests/test_kernelmanager.py index e000299f6..a76d5d45f 100644 --- a/ipykernel/inprocess/tests/test_kernelmanager.py +++ b/ipykernel/inprocess/tests/test_kernelmanager.py @@ -4,6 +4,7 @@ import unittest import pytest +from anyio import create_task_group from ipykernel.inprocess.manager import InProcessKernelManager @@ -12,92 +13,100 @@ # ----------------------------------------------------------------------------- -@pytest.mark.skip("FIXME") -class InProcessKernelManagerTestCase(unittest.TestCase): - def setUp(self): - self.km = InProcessKernelManager() +@pytest.fixture +def anyio_backend(): + return "asyncio" - def tearDown(self): - if self.km.has_kernel: - self.km.shutdown_kernel() - def test_interface(self): +@pytest.fixture() +async def km_kc(anyio_backend): + async with create_task_group() as tg: + km = InProcessKernelManager() + await tg.start(km.start_kernel) + kc = km.client() + kc.start_channels() + await kc.wait_for_ready() + yield km, kc + km.shutdown_kernel() + + +@pytest.fixture() +async def km(anyio_backend): + km = InProcessKernelManager() + yield km + if km.has_kernel: + km.shutdown_kernel() + + +class TestInProcessKernelManager: + async def test_interface(self, km): """Does the in-process kernel manager implement the basic KM interface?""" - km = self.km - assert not km.has_kernel + async with create_task_group() as tg: + assert not km.has_kernel - km.start_kernel() - assert km.has_kernel - assert km.kernel is not None + await tg.start(km.start_kernel) + assert km.has_kernel + assert km.kernel is not None - kc = km.client() - assert not kc.channels_running + kc = km.client() + assert not kc.channels_running - kc.start_channels() - assert kc.channels_running + kc.start_channels() + assert kc.channels_running - old_kernel = km.kernel - km.restart_kernel() - self.assertIsNotNone(km.kernel) - assert km.kernel != old_kernel + old_kernel = km.kernel + await tg.start(km.restart_kernel) + assert km.kernel is not None + assert km.kernel != old_kernel - km.shutdown_kernel() - assert not km.has_kernel + km.shutdown_kernel() + assert not km.has_kernel + + with pytest.raises(NotImplementedError): + km.interrupt_kernel() - self.assertRaises(NotImplementedError, km.interrupt_kernel) - self.assertRaises(NotImplementedError, km.signal_kernel, 9) + with pytest.raises(NotImplementedError): + km.signal_kernel(9) - kc.stop_channels() - assert not kc.channels_running + kc.stop_channels() + assert not kc.channels_running - def test_execute(self): + async def test_execute(self, km_kc): """Does executing code in an in-process kernel work?""" - km = self.km - km.start_kernel() - kc = km.client() - kc.start_channels() - kc.wait_for_ready() - kc.execute("foo = 1") + km, kc = km_kc + + await kc.execute("foo = 1") assert km.kernel.shell.user_ns["foo"] == 1 - def test_complete(self): + async def test_complete(self, km_kc): """Does requesting completion from an in-process kernel work?""" - km = self.km - km.start_kernel() - kc = km.client() - kc.start_channels() - kc.wait_for_ready() + km, kc = km_kc + km.kernel.shell.push({"my_bar": 0, "my_baz": 1}) - kc.complete("my_ba", 5) + await kc.complete("my_ba", 5) msg = kc.get_shell_msg() assert msg["header"]["msg_type"] == "complete_reply" - self.assertEqual(sorted(msg["content"]["matches"]), ["my_bar", "my_baz"]) + assert sorted(msg["content"]["matches"]) == ["my_bar", "my_baz"] - def test_inspect(self): + async def test_inspect(self, km_kc): """Does requesting object information from an in-process kernel work?""" - km = self.km - km.start_kernel() - kc = km.client() - kc.start_channels() - kc.wait_for_ready() + km, kc = km_kc + km.kernel.shell.user_ns["foo"] = 1 - kc.inspect("foo") + await kc.inspect("foo") msg = kc.get_shell_msg() assert msg["header"]["msg_type"] == "inspect_reply" content = msg["content"] assert content["found"] text = content["data"]["text/plain"] - self.assertIn("int", text) + assert "int" in text - def test_history(self): + async def test_history(self, km_kc): """Does requesting history from an in-process kernel work?""" - km = self.km - km.start_kernel() - kc = km.client() - kc.start_channels() - kc.wait_for_ready() - kc.execute("1") - kc.history(hist_access_type="tail", n=1) + km, kc = km_kc + + await kc.execute("1") + await kc.history(hist_access_type="tail", n=1) msg = kc.shell_channel.get_msgs()[-1] assert msg["header"]["msg_type"] == "history_reply" history = msg["content"]["history"] diff --git a/ipykernel/ipkernel.py b/ipykernel/ipkernel.py index 62f87b419..f9809109c 100644 --- a/ipykernel/ipkernel.py +++ b/ipykernel/ipkernel.py @@ -6,10 +6,12 @@ import sys import threading import typing as t +from dataclasses import dataclass import comm import zmq.asyncio -from anyio import create_task_group, to_thread +from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread +from anyio.abc import TaskStatus from IPython.core import release from IPython.utils.tokenutil import line_at_cursor, token_at_cursor from traitlets import Any, Bool, HasTraits, Instance, List, Type, observe, observe_compat @@ -194,9 +196,6 @@ def __init__(self, **kwargs): } async def process_debugpy(self): - if not _is_debugpy_available: - return - async with create_task_group() as tg: tg.start_soon(self.receive_debugpy_messages) tg.start_soon(self.poll_stopped_queue) @@ -204,11 +203,19 @@ async def process_debugpy(self): tg.cancel_scope.cancel() async def receive_debugpy_messages(self): + if not _is_debugpy_available: + return + while True: await self.receive_debugpy_message() async def receive_debugpy_message(self, msg=None): - msg = msg or await self.debugpy_socket.recv_multipart() + if not _is_debugpy_available: + return + + if msg is None: + assert self.debugpy_socket is not None + msg = await self.debugpy_socket.recv_multipart() # The first frame is the socket id, we can drop it frame = msg[1].decode("utf-8") self.log.debug("Debugpy received: %s", frame) @@ -223,16 +230,15 @@ async def poll_stopped_queue(self): while True: await self.debugger.handle_stopped_event() - async def start(self): + async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: """Start the kernel.""" self.shell.exit_now = False - control_tasks = [] if self.debugpy_socket is None: self.log.warning("debugpy_socket undefined, debugging will not be enabled") else: self.debugpy_stop = threading.Event() - control_tasks.append(self.process_debugpy) - await super().start(control_tasks=control_tasks) + self.control_tasks.append(self.process_debugpy) + await super().start(task_status=task_status) def stop(self): super().stop() @@ -363,26 +369,28 @@ async def run_cell(*args, **kwargs): ) coro = run_cell(code, **kwargs) - shell_result = None + @dataclass + class Execution: + interrupt: bool = False + result: t.Any = None - async def run(): - nonlocal interrupt, shell_result - shell_result = await coro - if not interrupt: + async def run(execution: Execution) -> None: + execution.result = await coro + if not execution.interrupt: self.shell_interrupt.put(False) res = None try: async with create_task_group() as tg: - interrupt = False + execution = Execution() self.shell_is_awaiting = True - tg.start_soon(run) - interrupt = await to_thread.run_sync(self.shell_interrupt.get) + tg.start_soon(run, execution) + execution.interrupt = await to_thread.run_sync(self.shell_interrupt.get) self.shell_is_awaiting = False - if interrupt: + if execution.interrupt: tg.cancel_scope.cancel() - res = shell_result + res = execution.result finally: shell.events.trigger("post_execute") if not silent: diff --git a/ipykernel/kernelbase.py b/ipykernel/kernelbase.py index 2cd07d56d..6c438c66c 100644 --- a/ipykernel/kernelbase.py +++ b/ipykernel/kernelbase.py @@ -16,7 +16,6 @@ import uuid import warnings from datetime import datetime -from functools import partial from signal import SIGINT, SIGTERM, Signals if sys.platform != "win32": @@ -34,7 +33,8 @@ import psutil import zmq -from anyio import create_task_group, sleep, to_thread +from anyio import TASK_STATUS_IGNORED, create_task_group, sleep, to_thread +from anyio.abc import TaskStatus from IPython.core.error import StdinNotImplementedError from jupyter_client.session import Session from traitlets.config.configurable import SingletonConfigurable @@ -67,6 +67,8 @@ def _accepts_cell_id(meth): class Kernel(SingletonConfigurable): """The base kernel class.""" + _aborted_time: float + # --------------------------------------------------------------------------- # Kernel interface # --------------------------------------------------------------------------- @@ -87,6 +89,7 @@ class Kernel(SingletonConfigurable): _is_test = Bool(False) control_socket = Instance(zmq.asyncio.Socket, allow_none=True) + control_tasks = List() debug_shell_socket = Any() @@ -375,11 +378,10 @@ async def process_shell_message(self, msg=None): sys.stderr.flush() self._publish_status("idle", "shell") - async def control_main(self, control_tasks): + async def control_main(self): async with create_task_group() as tg: - if control_tasks: - for task in control_tasks: - tg.start_soon(task) + for task in self.control_tasks: + tg.start_soon(task) tg.start_soon(self.process_control) await to_thread.run_sync(self.control_stop.wait) tg.cancel_scope.cancel() @@ -426,18 +428,18 @@ async def process_control_message(self, msg=None): sys.stderr.flush() self._publish_status("idle", "control") - async def start(self, control_tasks=None): + async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None: """Process messages on shell and control channels""" async with create_task_group() as tg: self.control_stop = threading.Event() if not self._is_test and self.control_socket is not None: if self.control_thread: - self.control_thread.set_task(partial(self.control_main, control_tasks)) + self.control_thread.set_task(self.control_main) self.control_thread.start() else: - tg.start_soon(self.control_main, control_tasks) + tg.start_soon(self.control_main) - self.shell_interrupt = queue.Queue() + self.shell_interrupt: queue.Queue[bool] = queue.Queue() self.shell_is_awaiting = False self.shell_is_blocking = False self.shell_stop = threading.Event() @@ -447,6 +449,8 @@ async def start(self, control_tasks=None): # publish idle status self._publish_status("starting", "shell") + task_status.started() + def stop(self): self.shell_stop.set() self.control_stop.set() @@ -949,8 +953,6 @@ async def abort_request(self, socket, ident, parent): # pragma: no cover msg_ids = parent["content"].get("msg_ids", None) if isinstance(msg_ids, str): msg_ids = [msg_ids] - if not msg_ids: - await self._abort_queues() for mid in msg_ids: self.aborted.add(str(mid)) diff --git a/ipykernel/tests/conftest.py b/ipykernel/tests/conftest.py index 1cd0560a2..2fcd83c0f 100644 --- a/ipykernel/tests/conftest.py +++ b/ipykernel/tests/conftest.py @@ -2,13 +2,14 @@ import logging import os from math import inf -from typing import no_type_check +from typing import Any, Callable, no_type_check from unittest.mock import MagicMock import pytest import zmq import zmq.asyncio from anyio import create_memory_object_stream, create_task_group +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from jupyter_client.session import Session from ipykernel.ipkernel import IPythonKernel @@ -24,7 +25,7 @@ @pytest.fixture def anyio_backend(): - return 'asyncio' + return "asyncio" pytestmark = pytest.mark.anyio @@ -52,7 +53,7 @@ def anyio_backend(): class TestSession(Session): """A session that copies sent messages to an internal stream, so that - they can be accessed later. + they can be accessed later. """ def __init__(self, sockets, *args, **kwargs): @@ -64,11 +65,16 @@ def __init__(self, sockets, *args, **kwargs): def send(self, socket, *args, **kwargs): msg = super().send(socket, *args, **kwargs) - self._streams[socket]["send"].send_nowait(msg) + send_stream: MemoryObjectSendStream[Any] = self._streams[socket]["send"] + send_stream.send_nowait(msg) return msg class KernelMixin: + shell_socket: zmq.asyncio.Socket + control_socket: zmq.asyncio.Socket + stop: Callable[[], None] + log = logging.getLogger() def _initialize(self): @@ -104,13 +110,19 @@ def destroy(self): async def test_shell_message(self, *args, **kwargs): msg_list = self._prep_msg(*args, **kwargs) await self.process_shell_message(msg_list) - return await self.session._streams[self.shell_socket]["receive"].receive() + receive_stream: MemoryObjectReceiveStream[Any] = self.session._streams[self.shell_socket][ + "receive" + ] + return await receive_stream.receive() @no_type_check async def test_control_message(self, *args, **kwargs): msg_list = self._prep_msg(*args, **kwargs) await self.process_control_message(msg_list) - return await self.session._streams[self.control_socket]["receive"].receive() + receive_stream: MemoryObjectReceiveStream[Any] = self.session._streams[self.control_socket][ + "receive" + ] + return await receive_stream.receive() def _on_send(self, msg, *args, **kwargs): self._reply = msg diff --git a/ipykernel/tests/test_async.py b/ipykernel/tests/test_async.py index 9276fbebb..069d581a3 100644 --- a/ipykernel/tests/test_async.py +++ b/ipykernel/tests/test_async.py @@ -1,6 +1,7 @@ """Test async/await integration""" import time + import pytest from .test_message_spec import validate_message diff --git a/ipykernel/tests/test_ipkernel_direct.py b/ipykernel/tests/test_ipkernel_direct.py index 1febd7e81..dd6b5aa16 100644 --- a/ipykernel/tests/test_ipkernel_direct.py +++ b/ipykernel/tests/test_ipkernel_direct.py @@ -1,5 +1,6 @@ """Test IPythonKernel directly""" +import asyncio import os import pytest @@ -136,8 +137,8 @@ async def test_direct_clear(ipkernel): @pytest.mark.skip("ipykernel._cancel_on_sigint doesn't exist anymore") async def test_cancel_on_sigint(ipkernel: IPythonKernel) -> None: future: asyncio.Future = asyncio.Future() - with ipkernel._cancel_on_sigint(future): - pass + # with ipkernel._cancel_on_sigint(future): + # pass future.set_result(None) @@ -160,10 +161,10 @@ async def fake_poll_control_queue(): ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore - ipkernel.start() - ipkernel.debugpy_stream = None - ipkernel.start() - await ipkernel.process_one(False) + await ipkernel.start() + ipkernel.debugpy_socket = None + await ipkernel.start() + # await ipkernel.process_one(False) await shell_future await control_future @@ -181,8 +182,8 @@ async def fake_poll_control_queue(): ipkernel.dispatch_queue = fake_dispatch_queue # type:ignore ipkernel.poll_control_queue = fake_poll_control_queue # type:ignore - ipkernel.debugpy_stream = None - ipkernel.start() + ipkernel.debugpy_socket = None + await ipkernel.start() await shell_future await control_future diff --git a/ipykernel/tests/test_kernel_direct.py b/ipykernel/tests/test_kernel_direct.py index 4ab6fa56a..c189c3364 100644 --- a/ipykernel/tests/test_kernel_direct.py +++ b/ipykernel/tests/test_kernel_direct.py @@ -5,6 +5,7 @@ import asyncio import os +import warnings import pytest diff --git a/ipykernel/tests/test_message_spec.py b/ipykernel/tests/test_message_spec.py index 56930d0e5..c160bfabf 100644 --- a/ipykernel/tests/test_message_spec.py +++ b/ipykernel/tests/test_message_spec.py @@ -363,7 +363,6 @@ def test_execute_stop_on_error(): KC.execute(code='print("Hello")') KC.execute(code='print("world")') reply = KC.get_shell_msg(timeout=TIMEOUT) - print(reply) reply = KC.get_shell_msg(timeout=TIMEOUT) assert reply["content"]["status"] == "aborted" # second message, too diff --git a/ipykernel/tests/utils.py b/ipykernel/tests/utils.py index b1b4119f0..b435e19dd 100644 --- a/ipykernel/tests/utils.py +++ b/ipykernel/tests/utils.py @@ -46,7 +46,8 @@ def flush_channels(kc=None): for get_msg in (kc.get_shell_msg, kc.get_iopub_msg): while True: try: - msg = get_msg(timeout=0.1) + msg = get_msg(timeout=0.2) + print(f"{msg=}") except Empty: break else: @@ -76,8 +77,10 @@ def execute(code="", kc=None, **kwargs): kc = KC msg_id = kc.execute(code=code, **kwargs) reply = get_reply(kc, msg_id, TIMEOUT) + print(f"{reply=}") validate_message(reply, "execute_reply", msg_id) busy = kc.get_iopub_msg(timeout=TIMEOUT) + print(f"{busy=}") validate_message(busy, "status", msg_id) assert busy["content"]["execution_state"] == "busy" diff --git a/pyproject.toml b/pyproject.toml index 66dd8d760..bbec772df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,8 @@ test = [ "flaky", "ipyparallel", "pre-commit", - "pytest-timeout" + "pytest-timeout", + "trio", ] cov = [ "coverage[toml]",