diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 7e5bc4746690..8c8568aa94a0 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -147,7 +147,7 @@ impl crate::r2d2::R2D2Connection for MysqlConnection { self.transaction_state .status .transaction_depth() - .map(|d| d.is_none()) + .map(|d| d.is_some()) .unwrap_or(true) } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index c1a6cf3c8b44..d55724b69cd4 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -180,7 +180,7 @@ impl crate::r2d2::R2D2Connection for PgConnection { self.transaction_state .status .transaction_depth() - .map(|d| d.is_none()) + .map(|d| d.is_some()) .unwrap_or(true) } } diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index ddac7c96e818..62f5efb9eb10 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -327,6 +327,28 @@ where } } +#[derive(QueryId)] +pub(crate) struct CheckConnectionQuery; + +impl QueryFragment for CheckConnectionQuery +where + DB: Backend, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> QueryResult<()> { + pass.push_sql("SELECT 1"); + Ok(()) + } +} + +impl Query for CheckConnectionQuery { + type SqlType = crate::sql_types::Integer; +} + +impl RunQueryDsl for CheckConnectionQuery {} + #[cfg(test)] mod tests { use std::sync::mpsc; @@ -394,26 +416,110 @@ mod tests { let query = select("foo".into_sql::()); assert_eq!("foo", query.get_result::(&mut conn).unwrap()); } -} -#[derive(QueryId)] -pub(crate) struct CheckConnectionQuery; + #[test] + fn check_pool_does_actually_hold_connections() { + use std::sync::atomic::{AtomicU32, Ordering}; + + #[derive(Debug)] + struct TestEventHandler { + acquire_count: Arc, + release_count: Arc, + checkin_count: Arc, + checkout_count: Arc, + } -impl QueryFragment for CheckConnectionQuery -where - DB: Backend, -{ - fn walk_ast<'b>( - &'b self, - mut pass: crate::query_builder::AstPass<'_, 'b, DB>, - ) -> QueryResult<()> { - pass.push_sql("SELECT 1"); - Ok(()) - } -} + impl r2d2::HandleEvent for TestEventHandler { + fn handle_acquire(&self, _event: r2d2::event::AcquireEvent) { + self.acquire_count.fetch_add(1, Ordering::Relaxed); + } + fn handle_release(&self, _event: r2d2::event::ReleaseEvent) { + self.release_count.fetch_add(1, Ordering::Relaxed); + } + fn handle_checkout(&self, _event: r2d2::event::CheckoutEvent) { + self.checkout_count.fetch_add(1, Ordering::Relaxed); + } + fn handle_checkin(&self, _event: r2d2::event::CheckinEvent) { + self.checkin_count.fetch_add(1, Ordering::Relaxed); + } + } -impl Query for CheckConnectionQuery { - type SqlType = crate::sql_types::Integer; -} + let acquire_count = Arc::new(AtomicU32::new(0)); + let release_count = Arc::new(AtomicU32::new(0)); + let checkin_count = Arc::new(AtomicU32::new(0)); + let checkout_count = Arc::new(AtomicU32::new(0)); -impl RunQueryDsl for CheckConnectionQuery {} + let handler = Box::new(TestEventHandler { + acquire_count: acquire_count.clone(), + release_count: release_count.clone(), + checkin_count: checkin_count.clone(), + checkout_count: checkout_count.clone(), + }); + + let manager = ConnectionManager::::new(database_url()); + let pool = Pool::builder() + .max_size(1) + .test_on_check_out(true) + .event_handler(handler) + .build(manager) + .unwrap(); + + assert_eq!(acquire_count.load(Ordering::Relaxed), 1); + assert_eq!(release_count.load(Ordering::Relaxed), 0); + assert_eq!(checkin_count.load(Ordering::Relaxed), 0); + assert_eq!(checkout_count.load(Ordering::Relaxed), 0); + + // check that we reuse connections with the pool + { + let conn = pool.get().unwrap(); + + assert_eq!(acquire_count.load(Ordering::Relaxed), 1); + assert_eq!(release_count.load(Ordering::Relaxed), 0); + assert_eq!(checkin_count.load(Ordering::Relaxed), 0); + assert_eq!(checkout_count.load(Ordering::Relaxed), 1); + std::mem::drop(conn); + } + + assert_eq!(acquire_count.load(Ordering::Relaxed), 1); + assert_eq!(release_count.load(Ordering::Relaxed), 0); + assert_eq!(checkin_count.load(Ordering::Relaxed), 1); + assert_eq!(checkout_count.load(Ordering::Relaxed), 1); + + // check that we remove a connection with open transactions from the pool + { + let mut conn = pool.get().unwrap(); + + assert_eq!(acquire_count.load(Ordering::Relaxed), 1); + assert_eq!(release_count.load(Ordering::Relaxed), 0); + assert_eq!(checkin_count.load(Ordering::Relaxed), 1); + assert_eq!(checkout_count.load(Ordering::Relaxed), 2); + + ::TransactionManager::begin_transaction(&mut *conn) + .unwrap(); + } + + assert_eq!(acquire_count.load(Ordering::Relaxed), 1); + assert_eq!(release_count.load(Ordering::Relaxed), 1); + assert_eq!(checkin_count.load(Ordering::Relaxed), 2); + assert_eq!(checkout_count.load(Ordering::Relaxed), 2); + + // check that we remove a connection from the pool that was + // open during panicing + #[allow(unreachable_code, unused_variables)] + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let conn = pool.get(); + assert_eq!(acquire_count.load(Ordering::Relaxed), 2); + assert_eq!(release_count.load(Ordering::Relaxed), 1); + assert_eq!(checkin_count.load(Ordering::Relaxed), 2); + assert_eq!(checkout_count.load(Ordering::Relaxed), 3); + panic!(); + std::mem::drop(conn); + })) + .unwrap_err(); + + assert_eq!(acquire_count.load(Ordering::Relaxed), 2); + assert_eq!(release_count.load(Ordering::Relaxed), 2); + assert_eq!(checkin_count.load(Ordering::Relaxed), 3); + assert_eq!(checkout_count.load(Ordering::Relaxed), 3); + } +} diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 3d85958519ba..e85ac7883c89 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -144,7 +144,7 @@ impl crate::r2d2::R2D2Connection for crate::sqlite::SqliteConnection { self.transaction_state .status .transaction_depth() - .map(|d| d.is_none()) + .map(|d| d.is_some()) .unwrap_or(true) } }