diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index 142d5f4be7..a24b580b23 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -477,6 +477,8 @@ jobs: - name: Test Python Driver Flight SQL shell: bash -l {0} run: | + docker compose up -d --wait flightsql-test + export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414 env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" - name: Build Python Driver PostgreSQL shell: bash -l {0} diff --git a/docker-compose.yml b/docker-compose.yml index 89394e598f..789d5d450d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -150,6 +150,12 @@ services: dockerfile: ci/docker/flightsql-test.dockerfile args: GO: ${GO} + healthcheck: + test: ["CMD", "curl", "--http2-prior-knowledge", "-XPOST", "-H", "content-type: application/grpc"] + interval: 5s + timeout: 30s + retries: 3 + start_period: 5m ports: - "41414:41414" volumes: diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py index ed44b6a3fa..843a4a82b6 100644 --- a/python/adbc_driver_flightsql/tests/test_errors.py +++ b/python/adbc_driver_flightsql/tests/test_errors.py @@ -15,7 +15,11 @@ # specific language governing permissions and limitations # under the License. +import os import re +import signal +import threading +import time import google.protobuf.any_pb2 as any_pb2 import google.protobuf.wrappers_pb2 as wrappers_pb2 @@ -45,6 +49,50 @@ def test_query_cancel(test_dbapi): cur.fetchone() +def test_query_cancel_async(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("forever") + + def _cancel(): + time.sleep(2) + cur.adbc_cancel() + + t = threading.Thread(target=_cancel, daemon=True) + t.start() + + with pytest.raises( + test_dbapi.OperationalError, + match=re.escape("CANCELLED: [FlightSQL] context canceled"), + ): + cur.fetchone() + + +def test_query_cancel_sigint(test_dbapi): + with test_dbapi.cursor() as cur: + for _ in range(3): + cur.execute("forever") + + # XXX: this handles the case of DoGet taking forever, but we also will + # want to test GetFlightInfo taking forever + + def _cancel(): + time.sleep(2) + os.kill(os.getpid(), signal.SIGINT) + + t = threading.Thread(target=_cancel, daemon=True) + t.start() + with pytest.raises( + test_dbapi.OperationalError, + match=re.escape("CANCELLED: [FlightSQL] context canceled"), + ): + cur.fetchone() + + # The cursor should still be usable + cur.execute("error_do_get") + with pytest.raises(test_dbapi.ProgrammingError): + cur.fetch_arrow_table() + + def test_query_error_fetch(test_dbapi): with test_dbapi.cursor() as cur: cur.execute("error_do_get") diff --git a/python/adbc_driver_manager/MANIFEST.in b/python/adbc_driver_manager/MANIFEST.in index 306c31144f..298ff3a9ca 100644 --- a/python/adbc_driver_manager/MANIFEST.in +++ b/python/adbc_driver_manager/MANIFEST.in @@ -22,6 +22,8 @@ include NOTICE.txt include adbc_driver_manager/adbc.h include adbc_driver_manager/adbc_driver_manager.cc include adbc_driver_manager/adbc_driver_manager.h +include adbc_driver_manager/_blocking_impl.cc +include adbc_driver_manager/_blocking_impl.h include adbc_driver_manager/_lib.pxd include adbc_driver_manager/_lib.pyi include adbc_driver_manager/_reader.pyi diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc new file mode 100644 index 0000000000..ca29e89613 --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "_blocking_impl.h" + +#include +#include + +#include +#include +#include +#include +#include + +namespace pyadbc_driver_manager { + +// This is somewhat derived from io_util.cc in arrow, but that implementation +// isn't easily used outside of Arrow's monolith. +namespace { +static std::once_flag kInitOnce; +// We may encounter errors below that we can't do anything about. Use this to +// print out an error, once. +static std::once_flag kWarnOnce; +// This thread reads from a pipe forever. Whenver it reads something, it +// calls the callback below. +static std::thread kCancelThread; + +static std::mutex cancel_mutex; +// This callback is registered by the Python side; basically it will call +// cancel() on an ADBC object. +static void (*cancel_callback)(void*) = nullptr; +// Callback state (a pointer to the ADBC PyObject). +static void* cancel_callback_data = nullptr; +// A nonblocking self-pipe. +static int pipe[2]; +// The old signal handler (most likely Python's). +struct sigaction old_sigint; +// Our signal handler (below). +struct sigaction our_sigint; + +std::string MakePipe() { + int rc = 0; +#if defined(__linux__) && defined(__GLIBC__) + rc = pipe2(pipe, O_CLOEXEC); +#elif defined(_WIN32) + return "Unsupported platform"; +#else + rc = ::pipe(pipe); +#endif + + if (rc != 0) { + return std::strerror(errno); + } + + // We may need to set FD_CLOEXEC on platforms without pipe2. +#if (!defined(__linux__) || !defined(__GLIBC__)) && !defined(_WIN32) + { + int flags = fcntl(pipe[0], F_GETFD, 0); + if (flags < 0) { + return std::strerror(errno); + } + rc = fcntl(pipe[0], F_SETFD, flags | FD_CLOEXEC); + if (rc < 0) { + return std::strerror(errno); + } + + flags = fcntl(pipe[1], F_GETFD, 0); + if (flags < 0) { + return std::strerror(errno); + } + rc = fcntl(pipe[1], F_SETFD, flags | FD_CLOEXEC); + if (rc < 0) { + return std::strerror(errno); + } + } +#endif + + // Make the write side nonblocking (the read side should stay blocking!) +#if defined(_WIN32) + return "Unsupported platform"; +#else + { + int flags = fcntl(pipe[1], F_GETFL, 0); + if (flags < 0) { + return std::strerror(errno); + } + rc = fcntl(pipe[1], F_SETFL, flags | O_NONBLOCK); + if (rc < 0) { + return std::strerror(errno); + } + } +#endif + + return ""; +} + +void InterruptThread() { + while (true) { + char buf = 0; + ssize_t bytes_read = read(pipe[0], &buf, 1); + if (bytes_read < 0) { + if (errno == EINTR) continue; + + // XXX: we failed reading from the pipe + std::string message = std::strerror(errno); + std::call_once(kWarnOnce, [&]() { + std::cerr << "adbc_driver_manager (native code): error handling interrupt: " + << message << std::endl; + }); + } else if (bytes_read > 0) { + std::lock_guard lock(cancel_mutex); + if (cancel_callback != nullptr) { + cancel_callback(cancel_callback_data); + } + cancel_callback = nullptr; + cancel_callback_data = nullptr; + } + } +} + +void SigintHandler(int) { (void)write(pipe[1], "X", 1); } + +} // namespace + +std::string InitBlockingCallback() { + std::string error; + std::call_once(kInitOnce, [&]() { + error = MakePipe(); + if (!error.empty()) { + return; + } + + our_sigint.sa_handler = &SigintHandler; + our_sigint.sa_flags = 0; + sigemptyset(&our_sigint.sa_mask); + + kCancelThread = std::thread(InterruptThread); + kCancelThread.detach(); + // TODO: set name of thread + }); + return error; +} + +std::string SetBlockingCallback(void (*callback)(void*), void* data) { + std::lock_guard lock(cancel_mutex); + cancel_callback = callback; + cancel_callback_data = data; + + int rc = sigaction(SIGINT, &our_sigint, &old_sigint); + if (rc != 0) { + return std::strerror(errno); + } + return ""; +} + +std::string ClearBlockingCallback() { + std::lock_guard lock(cancel_mutex); + cancel_callback = nullptr; + cancel_callback_data = nullptr; + + int rc = sigaction(SIGINT, &old_sigint, nullptr); + if (rc != 0) { + return std::strerror(errno); + } + return ""; +} + +} // namespace pyadbc_driver_manager diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h new file mode 100644 index 0000000000..ac76252f3e --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Allow KeyboardInterrupt to function with ADBC in Python. +/// +/// Call SetBlockingCallback to register a callback. This will temporarily +/// suppress the Python SIGINT handler. When SIGINT is received, this module +/// will handle it by calling the callback. + +#include + +namespace pyadbc_driver_manager { + +/// \brief Set up internal state to handle. +/// \return An error message (or empty string). +std::string InitBlockingCallback(); +/// \brief Set the callback for when SIGINT is received. +/// \return An error message (or empty string). +std::string SetBlockingCallback(void (*callback)(void*), void* data); +/// \brief Clear the callback for when SIGINT is received. +/// \return An error message (or empty string). +std::string ClearBlockingCallback(); + +} // namespace pyadbc_driver_manager diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi index 7afada9ecc..2ef4672029 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi @@ -17,6 +17,7 @@ # NOTE: generated with mypy's stubgen, then hand-edited to fix things +import typing_extensions from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union from typing import overload @@ -201,3 +202,13 @@ def _test_error( vendor_code: Optional[int], sqlstate: Optional[str], ) -> Error: ... + +_P = typing_extensions.ParamSpec("_P") +_T = typing.TypeVar("_T") + +def _blocking_call( + func: typing.Callable[_P, _T], + args: _P.args, + kwargs: _P.kwargs, + cancel: typing.Callable[[], None], +) -> _T: ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index 91139100bb..103d568010 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -20,8 +20,10 @@ """Low-level ADBC API.""" import enum +import functools import threading import typing +import warnings from typing import List, Tuple cimport cpython @@ -33,6 +35,7 @@ from cpython.pycapsule cimport ( from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t from libc.stdlib cimport malloc, free from libc.string cimport memcpy, memset +from libcpp.string cimport string as c_string from libcpp.vector cimport vector as c_vector if typing.TYPE_CHECKING: @@ -1481,3 +1484,60 @@ cdef class AdbcStatement(_AdbcHandle): cdef const CAdbcError* PyAdbcErrorFromArrayStream( CArrowArrayStream* stream, CAdbcStatusCode* status): return AdbcErrorFromArrayStream(stream, status) + + +cdef extern from "_blocking_impl.h" nogil: + ctypedef void (*BlockingCallback)(void*) noexcept nogil + c_string CInitBlockingCallback"pyadbc_driver_manager::InitBlockingCallback"() + c_string CSetBlockingCallback"pyadbc_driver_manager::SetBlockingCallback"(BlockingCallback, void* data) + c_string CClearBlockingCallback"pyadbc_driver_manager::ClearBlockingCallback"() + + +@functools.cache +def _init_blocking_call(): + error = bytes(CInitBlockingCallback()).decode("utf-8") + if error: + warnings.warn( + f"Failed to initialize KeyboardInterrupt support: {error}", + RuntimeWarning, + ) + + +def _blocking_call(func, args, kwargs, cancel): + """ + Run functions that are expected to block with a native SIGINT handler. + + Parameters + ---------- + """ + if threading.current_thread() is not threading.main_thread(): + return func(*args, **kwargs) + + _init_blocking_call() + + # Set the callback for the background thread and save the signal handler + error = CSetBlockingCallback(&_handle_blocking_call, cancel) + if error: + warnings.warn( + f"Failed to set SIGINT handler: {error}", + RuntimeWarning, + ) + + try: + return func(*args, **kwargs) + finally: + # Restore the signal handler + error = CClearBlockingCallback() + if error: + warnings.warn( + f"Failed to restore SIGINT handler: {error}", + RuntimeWarning, + ) + + +cdef void _handle_blocking_call(void* c_cancel) noexcept nogil: + # TODO: if this throws, we could save and restore the traceback later + # TODO: we could record that this was hit and raise a KeyboardInterrupt above + with gil: + cancel = c_cancel + cancel() diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 1e86144c12..01480a65e6 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -55,6 +55,7 @@ import adbc_driver_manager from . import _lib, _reader +from ._lib import _blocking_call if typing.TYPE_CHECKING: import pandas @@ -677,9 +678,12 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None: parameters, which will each be bound in turn). """ self._prepare_execute(operation, parameters) - handle, self._rowcount = self._stmt.execute_query() + + handle, self._rowcount = _blocking_call( + self._stmt.execute_query, [], {}, self._stmt.cancel + ) self._results = _RowIterator( - _reader.AdbcRecordBatchReader._import_from_c(handle.address) + self._stmt, _reader.AdbcRecordBatchReader._import_from_c(handle.address) ) def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None: @@ -991,7 +995,7 @@ def adbc_read_partition(self, partition: bytes) -> None: handle = self._conn._conn.read_partition(partition) self._rowcount = -1 self._results = _RowIterator( - pyarrow.RecordBatchReader._import_from_c(handle.address) + self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address) ) @property @@ -1095,7 +1099,8 @@ def fetch_record_batch(self) -> pyarrow.RecordBatchReader: class _RowIterator(_Closeable): """Track state needed to iterate over the result set.""" - def __init__(self, reader: pyarrow.RecordBatchReader) -> None: + def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None: + self._stmt = stmt self._reader = reader self._current_batch = None self._next_row = 0 @@ -1118,7 +1123,9 @@ def fetchone(self) -> Optional[tuple]: if self._current_batch is None or self._next_row >= len(self._current_batch): try: while True: - self._current_batch = self._reader.read_next_batch() + self._current_batch = _blocking_call( + self._reader.read_next_batch, [], {}, self._stmt.cancel + ) if self._current_batch.num_rows > 0: break self._next_row = 0 diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml index 0a03fa3ff9..d2db1f102f 100644 --- a/python/adbc_driver_manager/pyproject.toml +++ b/python/adbc_driver_manager/pyproject.toml @@ -23,6 +23,7 @@ license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.9" dynamic = ["version"] +dependencies = ["typing-extensions"] [project.optional-dependencies] dbapi = ["pandas", "pyarrow>=8.0.0"] diff --git a/python/adbc_driver_manager/setup.py b/python/adbc_driver_manager/setup.py index bbec1a01c6..ec0f83ba3c 100644 --- a/python/adbc_driver_manager/setup.py +++ b/python/adbc_driver_manager/setup.py @@ -93,6 +93,7 @@ def get_version(pkg_path): include_dirs=[str(source_root.joinpath("adbc_driver_manager").resolve())], language="c++", sources=[ + "adbc_driver_manager/_blocking_impl.cc", "adbc_driver_manager/_lib.pyx", "adbc_driver_manager/adbc_driver_manager.cc", ],