Skip to content

Commit

Permalink
feat(python/adbc_driver_manager): handle KeyboardInterrupt
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Feb 2, 2024
1 parent 5e21134 commit ec7525c
Show file tree
Hide file tree
Showing 11 changed files with 365 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/native-unix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ jobs:
- name: Test Python Driver Flight SQL
shell: bash -l {0}
run: |
docker compose up -d --wait flightsql-test
export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414
env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local"
- name: Build Python Driver PostgreSQL
shell: bash -l {0}
Expand Down
6 changes: 6 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ services:
dockerfile: ci/docker/flightsql-test.dockerfile
args:
GO: ${GO}
healthcheck:
test: ["CMD", "curl", "--http2-prior-knowledge", "-XPOST", "-H", "content-type: application/grpc"]
interval: 5s
timeout: 30s
retries: 3
start_period: 5m
ports:
- "41414:41414"
volumes:
Expand Down
48 changes: 48 additions & 0 deletions python/adbc_driver_flightsql/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import os
import re
import signal
import threading
import time

import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
Expand Down Expand Up @@ -45,6 +49,50 @@ def test_query_cancel(test_dbapi):
cur.fetchone()


def test_query_cancel_async(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("forever")

def _cancel():
time.sleep(2)
cur.adbc_cancel()

t = threading.Thread(target=_cancel, daemon=True)
t.start()

with pytest.raises(
test_dbapi.OperationalError,
match=re.escape("CANCELLED: [FlightSQL] context canceled"),
):
cur.fetchone()


def test_query_cancel_sigint(test_dbapi):
with test_dbapi.cursor() as cur:
for _ in range(3):
cur.execute("forever")

# XXX: this handles the case of DoGet taking forever, but we also will
# want to test GetFlightInfo taking forever

def _cancel():
time.sleep(2)
os.kill(os.getpid(), signal.SIGINT)

t = threading.Thread(target=_cancel, daemon=True)
t.start()
with pytest.raises(
test_dbapi.OperationalError,
match=re.escape("CANCELLED: [FlightSQL] context canceled"),
):
cur.fetchone()

# The cursor should still be usable
cur.execute("error_do_get")
with pytest.raises(test_dbapi.ProgrammingError):
cur.fetch_arrow_table()


def test_query_error_fetch(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("error_do_get")
Expand Down
2 changes: 2 additions & 0 deletions python/adbc_driver_manager/MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ include NOTICE.txt
include adbc_driver_manager/adbc.h
include adbc_driver_manager/adbc_driver_manager.cc
include adbc_driver_manager/adbc_driver_manager.h
include adbc_driver_manager/_blocking_impl.cc
include adbc_driver_manager/_blocking_impl.h
include adbc_driver_manager/_lib.pxd
include adbc_driver_manager/_lib.pyi
include adbc_driver_manager/_reader.pyi
Expand Down
182 changes: 182 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "_blocking_impl.h"

#include <fcntl.h>
#include <unistd.h>

#include <csignal>
#include <cstring>
#include <iostream>
#include <mutex>
#include <thread>

namespace pyadbc_driver_manager {

// This is somewhat derived from io_util.cc in arrow, but that implementation
// isn't easily used outside of Arrow's monolith.
namespace {
static std::once_flag kInitOnce;
// We may encounter errors below that we can't do anything about. Use this to
// print out an error, once.
static std::once_flag kWarnOnce;
// This thread reads from a pipe forever. Whenver it reads something, it
// calls the callback below.
static std::thread kCancelThread;

static std::mutex cancel_mutex;
// This callback is registered by the Python side; basically it will call
// cancel() on an ADBC object.
static void (*cancel_callback)(void*) = nullptr;
// Callback state (a pointer to the ADBC PyObject).
static void* cancel_callback_data = nullptr;
// A nonblocking self-pipe.
static int pipe[2];
// The old signal handler (most likely Python's).
struct sigaction old_sigint;
// Our signal handler (below).
struct sigaction our_sigint;

std::string MakePipe() {
int rc = 0;
#if defined(__linux__) && defined(__GLIBC__)
rc = pipe2(pipe, O_CLOEXEC);
#elif defined(_WIN32)
return "Unsupported platform";
#else
rc = ::pipe(pipe);
#endif

if (rc != 0) {
return std::strerror(errno);
}

// We may need to set FD_CLOEXEC on platforms without pipe2.
#if (!defined(__linux__) || !defined(__GLIBC__)) && !defined(_WIN32)
{
int flags = fcntl(pipe[0], F_GETFD, 0);
if (flags < 0) {
return std::strerror(errno);
}
rc = fcntl(pipe[0], F_SETFD, flags | FD_CLOEXEC);
if (rc < 0) {
return std::strerror(errno);
}

flags = fcntl(pipe[1], F_GETFD, 0);
if (flags < 0) {
return std::strerror(errno);
}
rc = fcntl(pipe[1], F_SETFD, flags | FD_CLOEXEC);
if (rc < 0) {
return std::strerror(errno);
}
}
#endif

// Make the write side nonblocking (the read side should stay blocking!)
#if defined(_WIN32)
return "Unsupported platform";
#else
{
int flags = fcntl(pipe[1], F_GETFL, 0);
if (flags < 0) {
return std::strerror(errno);
}
rc = fcntl(pipe[1], F_SETFL, flags | O_NONBLOCK);
if (rc < 0) {
return std::strerror(errno);
}
}
#endif

return "";
}

void InterruptThread() {
while (true) {
char buf = 0;
ssize_t bytes_read = read(pipe[0], &buf, 1);
if (bytes_read < 0) {
if (errno == EINTR) continue;

// XXX: we failed reading from the pipe
std::string message = std::strerror(errno);
std::call_once(kWarnOnce, [&]() {
std::cerr << "adbc_driver_manager (native code): error handling interrupt: "
<< message << std::endl;
});
} else if (bytes_read > 0) {
std::lock_guard<std::mutex> lock(cancel_mutex);
if (cancel_callback != nullptr) {
cancel_callback(cancel_callback_data);
}
cancel_callback = nullptr;
cancel_callback_data = nullptr;
}
}
}

void SigintHandler(int) { (void)write(pipe[1], "X", 1); }

} // namespace

std::string InitBlockingCallback() {
std::string error;
std::call_once(kInitOnce, [&]() {
error = MakePipe();
if (!error.empty()) {
return;
}

our_sigint.sa_handler = &SigintHandler;
our_sigint.sa_flags = 0;
sigemptyset(&our_sigint.sa_mask);

kCancelThread = std::thread(InterruptThread);
kCancelThread.detach();
// TODO: set name of thread
});
return error;
}

std::string SetBlockingCallback(void (*callback)(void*), void* data) {
std::lock_guard<std::mutex> lock(cancel_mutex);
cancel_callback = callback;
cancel_callback_data = data;

int rc = sigaction(SIGINT, &our_sigint, &old_sigint);
if (rc != 0) {
return std::strerror(errno);
}
return "";
}

std::string ClearBlockingCallback() {
std::lock_guard<std::mutex> lock(cancel_mutex);
cancel_callback = nullptr;
cancel_callback_data = nullptr;

int rc = sigaction(SIGINT, &old_sigint, nullptr);
if (rc != 0) {
return std::strerror(errno);
}
return "";
}

} // namespace pyadbc_driver_manager
38 changes: 38 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

/// Allow KeyboardInterrupt to function with ADBC in Python.
///
/// Call SetBlockingCallback to register a callback. This will temporarily
/// suppress the Python SIGINT handler. When SIGINT is received, this module
/// will handle it by calling the callback.

#include <string>

namespace pyadbc_driver_manager {

/// \brief Set up internal state to handle.
/// \return An error message (or empty string).
std::string InitBlockingCallback();
/// \brief Set the callback for when SIGINT is received.
/// \return An error message (or empty string).
std::string SetBlockingCallback(void (*callback)(void*), void* data);
/// \brief Clear the callback for when SIGINT is received.
/// \return An error message (or empty string).
std::string ClearBlockingCallback();

} // namespace pyadbc_driver_manager
11 changes: 11 additions & 0 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# NOTE: generated with mypy's stubgen, then hand-edited to fix things

import typing_extensions
from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union

from typing import overload
Expand Down Expand Up @@ -201,3 +202,13 @@ def _test_error(
vendor_code: Optional[int],
sqlstate: Optional[str],
) -> Error: ...

_P = typing_extensions.ParamSpec("_P")
_T = typing.TypeVar("_T")

def _blocking_call(
func: typing.Callable[_P, _T],
args: _P.args,
kwargs: _P.kwargs,
cancel: typing.Callable[[], None],
) -> _T: ...
Loading

0 comments on commit ec7525c

Please sign in to comment.