Skip to content

Commit daeb87b

Browse files
gridboxJohn Smith
and
John Smith
authored
Add sqlite commit and rollback hooks (#3500)
* fix: Derive clone for SqliteOperation * feat: Add sqlite commit and rollback hooks --------- Co-authored-by: John Smith <[email protected]>
1 parent 419877d commit daeb87b

File tree

3 files changed

+236
-4
lines changed

3 files changed

+236
-4
lines changed

sqlx-sqlite/src/connection/establish.rs

+2
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ impl EstablishParams {
296296
log_settings: self.log_settings.clone(),
297297
progress_handler_callback: None,
298298
update_hook_callback: None,
299+
commit_hook_callback: None,
300+
rollback_hook_callback: None,
299301
})
300302
}
301303
}

sqlx-sqlite/src/connection/mod.rs

+121-3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use futures_core::future::BoxFuture;
1111
use futures_intrusive::sync::MutexGuard;
1212
use futures_util::future;
1313
use libsqlite3_sys::{
14-
sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT,
15-
SQLITE_UPDATE,
14+
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
15+
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
1616
};
1717

1818
pub(crate) use handle::ConnectionHandle;
@@ -63,7 +63,7 @@ pub struct LockedSqliteHandle<'a> {
6363
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
6464
unsafe impl Send for Handler {}
6565

66-
#[derive(Debug, PartialEq, Eq)]
66+
#[derive(Debug, PartialEq, Eq, Clone)]
6767
pub enum SqliteOperation {
6868
Insert,
6969
Update,
@@ -91,6 +91,12 @@ pub struct UpdateHookResult<'a> {
9191
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
9292
unsafe impl Send for UpdateHookHandler {}
9393

94+
pub(crate) struct CommitHookHandler(NonNull<dyn FnMut() -> bool + Send + 'static>);
95+
unsafe impl Send for CommitHookHandler {}
96+
97+
pub(crate) struct RollbackHookHandler(NonNull<dyn FnMut() + Send + 'static>);
98+
unsafe impl Send for RollbackHookHandler {}
99+
94100
pub(crate) struct ConnectionState {
95101
pub(crate) handle: ConnectionHandle,
96102

@@ -106,6 +112,10 @@ pub(crate) struct ConnectionState {
106112
progress_handler_callback: Option<Handler>,
107113

108114
update_hook_callback: Option<UpdateHookHandler>,
115+
116+
commit_hook_callback: Option<CommitHookHandler>,
117+
118+
rollback_hook_callback: Option<RollbackHookHandler>,
109119
}
110120

111121
impl ConnectionState {
@@ -127,6 +137,24 @@ impl ConnectionState {
127137
}
128138
}
129139
}
140+
141+
pub(crate) fn remove_commit_hook(&mut self) {
142+
if let Some(mut handler) = self.commit_hook_callback.take() {
143+
unsafe {
144+
sqlite3_commit_hook(self.handle.as_ptr(), None, ptr::null_mut());
145+
let _ = { Box::from_raw(handler.0.as_mut()) };
146+
}
147+
}
148+
}
149+
150+
pub(crate) fn remove_rollback_hook(&mut self) {
151+
if let Some(mut handler) = self.rollback_hook_callback.take() {
152+
unsafe {
153+
sqlite3_rollback_hook(self.handle.as_ptr(), None, ptr::null_mut());
154+
let _ = { Box::from_raw(handler.0.as_mut()) };
155+
}
156+
}
157+
}
130158
}
131159

132160
pub(crate) struct Statements {
@@ -284,6 +312,31 @@ extern "C" fn update_hook<F>(
284312
}
285313
}
286314

315+
extern "C" fn commit_hook<F>(callback: *mut c_void) -> c_int
316+
where
317+
F: FnMut() -> bool,
318+
{
319+
unsafe {
320+
let r = catch_unwind(|| {
321+
let callback: *mut F = callback.cast::<F>();
322+
(*callback)()
323+
});
324+
c_int::from(!r.unwrap_or_default())
325+
}
326+
}
327+
328+
extern "C" fn rollback_hook<F>(callback: *mut c_void)
329+
where
330+
F: FnMut(),
331+
{
332+
unsafe {
333+
let _ = catch_unwind(|| {
334+
let callback: *mut F = callback.cast::<F>();
335+
(*callback)()
336+
});
337+
}
338+
}
339+
287340
impl LockedSqliteHandle<'_> {
288341
/// Returns the underlying sqlite3* connection handle.
289342
///
@@ -368,6 +421,61 @@ impl LockedSqliteHandle<'_> {
368421
}
369422
}
370423

424+
/// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback
425+
/// returns `false`, then the operation is turned into a ROLLBACK.
426+
///
427+
/// Only a single commit hook may be defined at one time per database connection; setting a new commit hook
428+
/// overrides the old one.
429+
///
430+
/// The commit hook callback must not do anything that will modify the database connection that invoked
431+
/// the commit hook. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections
432+
/// in this context.
433+
///
434+
/// See https://www.sqlite.org/c3ref/commit_hook.html
435+
pub fn set_commit_hook<F>(&mut self, callback: F)
436+
where
437+
F: FnMut() -> bool + Send + 'static,
438+
{
439+
unsafe {
440+
let callback_boxed = Box::new(callback);
441+
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
442+
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
443+
let handler = callback.as_ptr() as *mut _;
444+
self.guard.remove_commit_hook();
445+
self.guard.commit_hook_callback = Some(CommitHookHandler(callback));
446+
447+
sqlite3_commit_hook(
448+
self.as_raw_handle().as_mut(),
449+
Some(commit_hook::<F>),
450+
handler,
451+
);
452+
}
453+
}
454+
455+
/// Sets a rollback hook that is invoked whenever a transaction rollback occurs. The rollback callback is not
456+
/// invoked if a transaction is automatically rolled back because the database connection is closed.
457+
///
458+
/// See https://www.sqlite.org/c3ref/commit_hook.html
459+
pub fn set_rollback_hook<F>(&mut self, callback: F)
460+
where
461+
F: FnMut() + Send + 'static,
462+
{
463+
unsafe {
464+
let callback_boxed = Box::new(callback);
465+
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
466+
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
467+
let handler = callback.as_ptr() as *mut _;
468+
self.guard.remove_rollback_hook();
469+
self.guard.rollback_hook_callback = Some(RollbackHookHandler(callback));
470+
471+
sqlite3_rollback_hook(
472+
self.as_raw_handle().as_mut(),
473+
Some(rollback_hook::<F>),
474+
handler,
475+
);
476+
}
477+
}
478+
371479
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
372480
pub fn remove_progress_handler(&mut self) {
373481
self.guard.remove_progress_handler();
@@ -376,6 +484,14 @@ impl LockedSqliteHandle<'_> {
376484
pub fn remove_update_hook(&mut self) {
377485
self.guard.remove_update_hook();
378486
}
487+
488+
pub fn remove_commit_hook(&mut self) {
489+
self.guard.remove_commit_hook();
490+
}
491+
492+
pub fn remove_rollback_hook(&mut self) {
493+
self.guard.remove_rollback_hook();
494+
}
379495
}
380496

381497
impl Drop for ConnectionState {
@@ -384,6 +500,8 @@ impl Drop for ConnectionState {
384500
self.statements.clear();
385501
self.remove_progress_handler();
386502
self.remove_update_hook();
503+
self.remove_commit_hook();
504+
self.remove_rollback_hook();
387505
}
388506
}
389507

tests/sqlite/sqlite.rs

+113-1
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,7 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> {
806806
assert_eq!(result.operation, SqliteOperation::Insert);
807807
assert_eq!(result.database, "main");
808808
assert_eq!(result.table, "tweet");
809-
assert_eq!(result.rowid, 3);
809+
assert_eq!(result.rowid, 2);
810810
});
811811

812812
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
@@ -848,3 +848,115 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul
848848
assert_eq!(1, Arc::strong_count(&ref_counted_object));
849849
Ok(())
850850
}
851+
852+
#[sqlx_macros::test]
853+
async fn test_query_with_commit_hook() -> anyhow::Result<()> {
854+
let mut conn = new::<Sqlite>().await?;
855+
856+
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
857+
let state = format!("test");
858+
conn.lock_handle().await?.set_commit_hook(move || {
859+
assert_eq!(state, "test");
860+
false
861+
});
862+
863+
let mut tx = conn.begin().await?;
864+
sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )")
865+
.execute(&mut *tx)
866+
.await?;
867+
match tx.commit().await {
868+
Err(sqlx::Error::Database(err)) => {
869+
assert_eq!(err.message(), String::from("constraint failed"))
870+
}
871+
_ => panic!("expected an error"),
872+
}
873+
874+
Ok(())
875+
}
876+
877+
#[sqlx_macros::test]
878+
async fn test_multiple_set_commit_hook_calls_drop_old_handler() -> anyhow::Result<()> {
879+
let ref_counted_object = Arc::new(0);
880+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
881+
882+
{
883+
let mut conn = new::<Sqlite>().await?;
884+
885+
let o = ref_counted_object.clone();
886+
conn.lock_handle().await?.set_commit_hook(move || {
887+
println!("{o:?}");
888+
true
889+
});
890+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
891+
892+
let o = ref_counted_object.clone();
893+
conn.lock_handle().await?.set_commit_hook(move || {
894+
println!("{o:?}");
895+
true
896+
});
897+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
898+
899+
let o = ref_counted_object.clone();
900+
conn.lock_handle().await?.set_commit_hook(move || {
901+
println!("{o:?}");
902+
true
903+
});
904+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
905+
906+
conn.lock_handle().await?.remove_commit_hook();
907+
}
908+
909+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
910+
Ok(())
911+
}
912+
913+
#[sqlx_macros::test]
914+
async fn test_query_with_rollback_hook() -> anyhow::Result<()> {
915+
let mut conn = new::<Sqlite>().await?;
916+
917+
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
918+
let state = format!("test");
919+
conn.lock_handle().await?.set_rollback_hook(move || {
920+
assert_eq!(state, "test");
921+
});
922+
923+
let mut tx = conn.begin().await?;
924+
sqlx::query("INSERT INTO tweet ( id, text ) VALUES (5, 'Hello, World' )")
925+
.execute(&mut *tx)
926+
.await?;
927+
tx.rollback().await?;
928+
Ok(())
929+
}
930+
931+
#[sqlx_macros::test]
932+
async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Result<()> {
933+
let ref_counted_object = Arc::new(0);
934+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
935+
936+
{
937+
let mut conn = new::<Sqlite>().await?;
938+
939+
let o = ref_counted_object.clone();
940+
conn.lock_handle().await?.set_rollback_hook(move || {
941+
println!("{o:?}");
942+
});
943+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
944+
945+
let o = ref_counted_object.clone();
946+
conn.lock_handle().await?.set_rollback_hook(move || {
947+
println!("{o:?}");
948+
});
949+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
950+
951+
let o = ref_counted_object.clone();
952+
conn.lock_handle().await?.set_rollback_hook(move || {
953+
println!("{o:?}");
954+
});
955+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
956+
957+
conn.lock_handle().await?.remove_rollback_hook();
958+
}
959+
960+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
961+
Ok(())
962+
}

0 commit comments

Comments
 (0)