From ec7525cf40e192ac1a9b596e1aaff4a25e6c0689 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 23 Jan 2024 16:47:47 -0500
Subject: [PATCH] feat(python/adbc_driver_manager): handle KeyboardInterrupt
Fixes #1484.
---
.github/workflows/native-unix.yml | 2 +
docker-compose.yml | 6 +
.../tests/test_errors.py | 48 +++++
python/adbc_driver_manager/MANIFEST.in | 2 +
.../adbc_driver_manager/_blocking_impl.cc | 182 ++++++++++++++++++
.../adbc_driver_manager/_blocking_impl.h | 38 ++++
.../adbc_driver_manager/_lib.pyi | 11 ++
.../adbc_driver_manager/_lib.pyx | 62 ++++++
.../adbc_driver_manager/dbapi.py | 17 +-
python/adbc_driver_manager/pyproject.toml | 1 +
python/adbc_driver_manager/setup.py | 1 +
11 files changed, 365 insertions(+), 5 deletions(-)
create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml
index 142d5f4be7..a24b580b23 100644
--- a/.github/workflows/native-unix.yml
+++ b/.github/workflows/native-unix.yml
@@ -477,6 +477,8 @@ jobs:
- name: Test Python Driver Flight SQL
shell: bash -l {0}
run: |
+ docker compose up -d --wait flightsql-test
+ export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414
env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
- name: Build Python Driver PostgreSQL
shell: bash -l {0}
diff --git a/docker-compose.yml b/docker-compose.yml
index 89394e598f..789d5d450d 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -150,6 +150,12 @@ services:
dockerfile: ci/docker/flightsql-test.dockerfile
args:
GO: ${GO}
+ healthcheck:
+ test: ["CMD", "curl", "--http2-prior-knowledge", "-XPOST", "-H", "content-type: application/grpc"]
+ interval: 5s
+ timeout: 30s
+ retries: 3
+ start_period: 5m
ports:
- "41414:41414"
volumes:
diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py
index ed44b6a3fa..843a4a82b6 100644
--- a/python/adbc_driver_flightsql/tests/test_errors.py
+++ b/python/adbc_driver_flightsql/tests/test_errors.py
@@ -15,7 +15,11 @@
# specific language governing permissions and limitations
# under the License.
+import os
import re
+import signal
+import threading
+import time
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
@@ -45,6 +49,50 @@ def test_query_cancel(test_dbapi):
cur.fetchone()
+def test_query_cancel_async(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("forever")
+
+ def _cancel():
+ time.sleep(2)
+ cur.adbc_cancel()
+
+ t = threading.Thread(target=_cancel, daemon=True)
+ t.start()
+
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("CANCELLED: [FlightSQL] context canceled"),
+ ):
+ cur.fetchone()
+
+
+def test_query_cancel_sigint(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ for _ in range(3):
+ cur.execute("forever")
+
+ # XXX: this handles the case of DoGet taking forever, but we also will
+ # want to test GetFlightInfo taking forever
+
+ def _cancel():
+ time.sleep(2)
+ os.kill(os.getpid(), signal.SIGINT)
+
+ t = threading.Thread(target=_cancel, daemon=True)
+ t.start()
+ with pytest.raises(
+ test_dbapi.OperationalError,
+ match=re.escape("CANCELLED: [FlightSQL] context canceled"),
+ ):
+ cur.fetchone()
+
+ # The cursor should still be usable
+ cur.execute("error_do_get")
+ with pytest.raises(test_dbapi.ProgrammingError):
+ cur.fetch_arrow_table()
+
+
def test_query_error_fetch(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("error_do_get")
diff --git a/python/adbc_driver_manager/MANIFEST.in b/python/adbc_driver_manager/MANIFEST.in
index 306c31144f..298ff3a9ca 100644
--- a/python/adbc_driver_manager/MANIFEST.in
+++ b/python/adbc_driver_manager/MANIFEST.in
@@ -22,6 +22,8 @@ include NOTICE.txt
include adbc_driver_manager/adbc.h
include adbc_driver_manager/adbc_driver_manager.cc
include adbc_driver_manager/adbc_driver_manager.h
+include adbc_driver_manager/_blocking_impl.cc
+include adbc_driver_manager/_blocking_impl.h
include adbc_driver_manager/_lib.pxd
include adbc_driver_manager/_lib.pyi
include adbc_driver_manager/_reader.pyi
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
new file mode 100644
index 0000000000..ca29e89613
--- /dev/null
+++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
@@ -0,0 +1,182 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "_blocking_impl.h"
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace pyadbc_driver_manager {
+
+// This is somewhat derived from io_util.cc in arrow, but that implementation
+// isn't easily used outside of Arrow's monolith.
+namespace {
+static std::once_flag kInitOnce;
+// We may encounter errors below that we can't do anything about. Use this to
+// print out an error, once.
+static std::once_flag kWarnOnce;
+// This thread reads from a pipe forever. Whenever it reads something, it
+// calls the callback below.
+static std::thread kCancelThread;
+
+static std::mutex cancel_mutex;
+// This callback is registered by the Python side; basically it will call
+// cancel() on an ADBC object.
+static void (*cancel_callback)(void*) = nullptr;
+// Callback state (a pointer to the ADBC PyObject).
+static void* cancel_callback_data = nullptr;
+// A nonblocking self-pipe.
+static int pipe[2];
+// The old signal handler (most likely Python's).
+struct sigaction old_sigint;
+// Our signal handler (below).
+struct sigaction our_sigint;
+
+std::string MakePipe() {
+ int rc = 0;
+#if defined(__linux__) && defined(__GLIBC__)
+ rc = pipe2(pipe, O_CLOEXEC);
+#elif defined(_WIN32)
+ return "Unsupported platform";
+#else
+ rc = ::pipe(pipe);
+#endif
+
+ if (rc != 0) {
+ return std::strerror(errno);
+ }
+
+ // We may need to set FD_CLOEXEC on platforms without pipe2.
+#if (!defined(__linux__) || !defined(__GLIBC__)) && !defined(_WIN32)
+ {
+ int flags = fcntl(pipe[0], F_GETFD, 0);
+ if (flags < 0) {
+ return std::strerror(errno);
+ }
+ rc = fcntl(pipe[0], F_SETFD, flags | FD_CLOEXEC);
+ if (rc < 0) {
+ return std::strerror(errno);
+ }
+
+ flags = fcntl(pipe[1], F_GETFD, 0);
+ if (flags < 0) {
+ return std::strerror(errno);
+ }
+ rc = fcntl(pipe[1], F_SETFD, flags | FD_CLOEXEC);
+ if (rc < 0) {
+ return std::strerror(errno);
+ }
+ }
+#endif
+
+ // Make the write side nonblocking (the read side should stay blocking!)
+#if defined(_WIN32)
+ return "Unsupported platform";
+#else
+ {
+ int flags = fcntl(pipe[1], F_GETFL, 0);
+ if (flags < 0) {
+ return std::strerror(errno);
+ }
+ rc = fcntl(pipe[1], F_SETFL, flags | O_NONBLOCK);
+ if (rc < 0) {
+ return std::strerror(errno);
+ }
+ }
+#endif
+
+ return "";
+}
+
+void InterruptThread() {
+ while (true) {
+ char buf = 0;
+ ssize_t bytes_read = read(pipe[0], &buf, 1);
+ if (bytes_read < 0) {
+ if (errno == EINTR) continue;
+
+ // XXX: we failed reading from the pipe
+ std::string message = std::strerror(errno);
+ std::call_once(kWarnOnce, [&]() {
+ std::cerr << "adbc_driver_manager (native code): error handling interrupt: "
+ << message << std::endl;
+ });
+ } else if (bytes_read > 0) {
+ std::lock_guard lock(cancel_mutex);
+ if (cancel_callback != nullptr) {
+ cancel_callback(cancel_callback_data);
+ }
+ cancel_callback = nullptr;
+ cancel_callback_data = nullptr;
+ }
+ }
+}
+
+void SigintHandler(int) { (void)write(pipe[1], "X", 1); }
+
+} // namespace
+
+std::string InitBlockingCallback() {
+ std::string error;
+ std::call_once(kInitOnce, [&]() {
+ error = MakePipe();
+ if (!error.empty()) {
+ return;
+ }
+
+ our_sigint.sa_handler = &SigintHandler;
+ our_sigint.sa_flags = 0;
+ sigemptyset(&our_sigint.sa_mask);
+
+ kCancelThread = std::thread(InterruptThread);
+ kCancelThread.detach();
+ // TODO: set name of thread
+ });
+ return error;
+}
+
+std::string SetBlockingCallback(void (*callback)(void*), void* data) {
+ std::lock_guard lock(cancel_mutex);
+ cancel_callback = callback;
+ cancel_callback_data = data;
+
+ int rc = sigaction(SIGINT, &our_sigint, &old_sigint);
+ if (rc != 0) {
+ return std::strerror(errno);
+ }
+ return "";
+}
+
+std::string ClearBlockingCallback() {
+ std::lock_guard lock(cancel_mutex);
+ cancel_callback = nullptr;
+ cancel_callback_data = nullptr;
+
+ int rc = sigaction(SIGINT, &old_sigint, nullptr);
+ if (rc != 0) {
+ return std::strerror(errno);
+ }
+ return "";
+}
+
+} // namespace pyadbc_driver_manager
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
new file mode 100644
index 0000000000..ac76252f3e
--- /dev/null
+++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Allow KeyboardInterrupt to function with ADBC in Python.
+///
+/// Call SetBlockingCallback to register a callback. This will temporarily
+/// suppress the Python SIGINT handler. When SIGINT is received, this module
+/// will handle it by calling the callback.
+
+#include
+
+namespace pyadbc_driver_manager {
+
+/// \brief Set up internal state to handle SIGINT.
+/// \return An error message (or empty string).
+std::string InitBlockingCallback();
+/// \brief Set the callback for when SIGINT is received.
+/// \return An error message (or empty string).
+std::string SetBlockingCallback(void (*callback)(void*), void* data);
+/// \brief Clear the callback for when SIGINT is received.
+/// \return An error message (or empty string).
+std::string ClearBlockingCallback();
+
+} // namespace pyadbc_driver_manager
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 7afada9ecc..2ef4672029 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -17,6 +17,7 @@
# NOTE: generated with mypy's stubgen, then hand-edited to fix things
+import typing_extensions
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union
from typing import overload
@@ -201,3 +202,13 @@ def _test_error(
vendor_code: Optional[int],
sqlstate: Optional[str],
) -> Error: ...
+
+_P = typing_extensions.ParamSpec("_P")
+_T = typing.TypeVar("_T")
+
+def _blocking_call(
+ func: typing.Callable[_P, _T],
+ args: _P.args,
+ kwargs: _P.kwargs,
+ cancel: typing.Callable[[], None],
+) -> _T: ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 91139100bb..45129f870f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -20,8 +20,10 @@
"""Low-level ADBC API."""
import enum
+import functools
import threading
import typing
+import warnings
from typing import List, Tuple
cimport cpython
@@ -33,6 +35,7 @@ from cpython.pycapsule cimport (
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
+from libcpp.string cimport string as c_string
from libcpp.vector cimport vector as c_vector
if typing.TYPE_CHECKING:
@@ -1481,3 +1484,62 @@ cdef class AdbcStatement(_AdbcHandle):
cdef const CAdbcError* PyAdbcErrorFromArrayStream(
CArrowArrayStream* stream, CAdbcStatusCode* status):
return AdbcErrorFromArrayStream(stream, status)
+
+
+cdef extern from "_blocking_impl.h" nogil:
+ ctypedef void (*BlockingCallback)(void*) noexcept nogil
+ c_string CInitBlockingCallback"pyadbc_driver_manager::InitBlockingCallback"()
+ c_string CSetBlockingCallback"pyadbc_driver_manager::SetBlockingCallback"(BlockingCallback, void* data)
+ c_string CClearBlockingCallback"pyadbc_driver_manager::ClearBlockingCallback"()
+
+
+@functools.cache
+def _init_blocking_call():
+ error = bytes(CInitBlockingCallback()).decode("utf-8")
+ if error:
+ warnings.warn(
+ f"Failed to initialize KeyboardInterrupt support: {error}",
+ RuntimeWarning,
+ )
+
+
+def _blocking_call(func, args, kwargs, cancel):
+ """
+ Run functions that are expected to block with a native SIGINT handler.
+
+ Parameters
+ ----------
+ """
+ if threading.current_thread() is not threading.main_thread():
+ return func(*args, **kwargs)
+
+ _init_blocking_call()
+
+ # Set the callback for the background thread and save the signal handler
+ error = bytes(
+ CSetBlockingCallback(&_handle_blocking_call, cancel)
+ ).decode("utf-8")
+ if error:
+ warnings.warn(
+ f"Failed to set SIGINT handler: {error}",
+ RuntimeWarning,
+ )
+
+ try:
+ return func(*args, **kwargs)
+ finally:
+ # Restore the signal handler
+ error = bytes(CClearBlockingCallback()).decode("utf-8")
+ if error:
+ warnings.warn(
+ f"Failed to restore SIGINT handler: {error}",
+ RuntimeWarning,
+ )
+
+
+cdef void _handle_blocking_call(void* c_cancel) noexcept nogil:
+ # TODO: if this throws, we could save and restore the traceback later
+ # TODO: we could record that this was hit and raise a KeyboardInterrupt above
+ with gil:
+ cancel =