Skip to content

Commit b48ad2c

Browse files
nbaztecAandreba
authored andcommitted
add progress handler support to sqlite (launchbadge#2256)
* rebase main * fmt * use NonNull to fix UB * apply code suggestions * add test for multiple handler drops * remove nightly features for test
1 parent 0897000 commit b48ad2c

File tree

3 files changed

+148
-1
lines changed

3 files changed

+148
-1
lines changed

sqlx-sqlite/src/connection/establish.rs

+1
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ impl EstablishParams {
282282
statements: Statements::new(self.statement_cache_capacity),
283283
transaction_depth: 0,
284284
log_settings: self.log_settings.clone(),
285+
progress_handler_callback: None,
285286
})
286287
}
287288
}

sqlx-sqlite/src/connection/mod.rs

+78-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use futures_core::future::BoxFuture;
22
use futures_intrusive::sync::MutexGuard;
33
use futures_util::future;
4-
use libsqlite3_sys::sqlite3;
4+
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
55
use sqlx_core::common::StatementCache;
66
use sqlx_core::error::Error;
77
use sqlx_core::transaction::Transaction;
88
use std::cmp::Ordering;
99
use std::fmt::{self, Debug, Formatter};
10+
use std::os::raw::{c_int, c_void};
11+
use std::panic::catch_unwind;
1012
use std::ptr::NonNull;
1113

1214
use crate::connection::establish::EstablishParams;
@@ -51,6 +53,10 @@ pub struct LockedSqliteHandle<'a> {
5153
pub(crate) guard: MutexGuard<'a, ConnectionState>,
5254
}
5355

56+
/// Represents a callback handler that will be shared with the underlying sqlite3 connection.
57+
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
58+
unsafe impl Send for Handler {}
59+
5460
pub(crate) struct ConnectionState {
5561
pub(crate) handle: ConnectionHandle,
5662

@@ -60,6 +66,22 @@ pub(crate) struct ConnectionState {
6066
pub(crate) statements: Statements,
6167

6268
log_settings: LogSettings,
69+
70+
/// Stores the progress handler set on the current connection. If the handler returns `false`,
71+
/// the query is interrupted.
72+
progress_handler_callback: Option<Handler>,
73+
}
74+
75+
impl ConnectionState {
76+
/// Drops the `progress_handler_callback` if it exists.
77+
pub(crate) fn remove_progress_handler(&mut self) {
78+
if let Some(mut handler) = self.progress_handler_callback.take() {
79+
unsafe {
80+
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
81+
let _ = { Box::from_raw(handler.0.as_mut()) };
82+
}
83+
}
84+
}
6385
}
6486

6587
pub(crate) struct Statements {
@@ -177,6 +199,21 @@ impl Connection for SqliteConnection {
177199
}
178200
}
179201

202+
/// Implements a C binding to a progress callback. The function returns `0` if the
203+
/// user-provided callback returns `true`, and `1` otherwise to signal an interrupt.
204+
extern "C" fn progress_callback<F>(callback: *mut c_void) -> c_int
205+
where
206+
F: FnMut() -> bool,
207+
{
208+
unsafe {
209+
let r = catch_unwind(|| {
210+
let callback: *mut F = callback.cast::<F>();
211+
(*callback)()
212+
});
213+
c_int::from(!r.unwrap_or_default())
214+
}
215+
}
216+
180217
impl LockedSqliteHandle<'_> {
181218
/// Returns the underlying sqlite3* connection handle.
182219
///
@@ -206,12 +243,52 @@ impl LockedSqliteHandle<'_> {
206243
) -> Result<(), Error> {
207244
collation::create_collation(&mut self.guard.handle, name, compare)
208245
}
246+
247+
/// Sets a progress handler that is invoked periodically during long running calls. If the progress callback
248+
/// returns `false`, then the operation is interrupted.
249+
///
250+
/// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html)
251+
/// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the
252+
/// progress handler is disabled.
253+
///
254+
/// Only a single progress handler may be defined at one time per database connection; setting a new progress
255+
/// handler cancels the old one.
256+
///
257+
/// The progress handler callback must not do anything that will modify the database connection that invoked
258+
/// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections
259+
/// in this context.
260+
pub fn set_progress_handler<F>(&mut self, num_ops: i32, mut callback: F)
261+
where
262+
F: FnMut() -> bool + Send + 'static,
263+
{
264+
unsafe {
265+
let callback_boxed = Box::new(callback);
266+
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
267+
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
268+
let handler = callback.as_ptr() as *mut _;
269+
self.guard.remove_progress_handler();
270+
self.guard.progress_handler_callback = Some(Handler(callback));
271+
272+
sqlite3_progress_handler(
273+
self.as_raw_handle().as_mut(),
274+
num_ops,
275+
Some(progress_callback::<F>),
276+
handler,
277+
);
278+
}
279+
}
280+
281+
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
282+
pub fn remove_progress_handler(&mut self) {
283+
self.guard.remove_progress_handler();
284+
}
209285
}
210286

211287
impl Drop for ConnectionState {
212288
fn drop(&mut self) {
213289
// explicitly drop statements before the connection handle is dropped
214290
self.statements.clear();
291+
self.remove_progress_handler();
215292
}
216293
}
217294

tests/sqlite/sqlite.rs

+69
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use sqlx::{
77
SqliteConnection, SqlitePool, Statement, TypeInfo,
88
};
99
use sqlx_test::new;
10+
use std::sync::Arc;
1011

1112
#[sqlx_macros::test]
1213
async fn it_connects() -> anyhow::Result<()> {
@@ -725,3 +726,71 @@ async fn concurrent_read_and_write() {
725726
read.await;
726727
write.await;
727728
}
729+
730+
#[sqlx_macros::test]
731+
async fn test_query_with_progress_handler() -> anyhow::Result<()> {
732+
let mut conn = new::<Sqlite>().await?;
733+
734+
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
735+
let state = format!("test");
736+
conn.lock_handle().await?.set_progress_handler(1, move || {
737+
assert_eq!(state, "test");
738+
false
739+
});
740+
741+
match sqlx::query("SELECT 'hello' AS title")
742+
.fetch_all(&mut conn)
743+
.await
744+
{
745+
Err(sqlx::Error::Database(err)) => assert_eq!(err.message(), String::from("interrupted")),
746+
_ => panic!("expected an interrupt"),
747+
}
748+
749+
Ok(())
750+
}
751+
752+
#[sqlx_macros::test]
753+
async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::Result<()> {
754+
let ref_counted_object = Arc::new(0);
755+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
756+
757+
{
758+
let mut conn = new::<Sqlite>().await?;
759+
760+
let o = ref_counted_object.clone();
761+
conn.lock_handle().await?.set_progress_handler(1, move || {
762+
println!("{:?}", o);
763+
false
764+
});
765+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
766+
767+
let o = ref_counted_object.clone();
768+
conn.lock_handle().await?.set_progress_handler(1, move || {
769+
println!("{:?}", o);
770+
false
771+
});
772+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
773+
774+
let o = ref_counted_object.clone();
775+
conn.lock_handle().await?.set_progress_handler(1, move || {
776+
println!("{:?}", o);
777+
false
778+
});
779+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
780+
781+
match sqlx::query("SELECT 'hello' AS title")
782+
.fetch_all(&mut conn)
783+
.await
784+
{
785+
Err(sqlx::Error::Database(err)) => {
786+
assert_eq!(err.message(), String::from("interrupted"))
787+
}
788+
_ => panic!("expected an interrupt"),
789+
}
790+
791+
conn.lock_handle().await?.remove_progress_handler();
792+
}
793+
794+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
795+
Ok(())
796+
}

0 commit comments

Comments
 (0)