From a80460ed549983af65dc6ad6aa57dac5acb926ed Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 25 Dec 2020 22:06:58 +0100 Subject: [PATCH] Resurrect on_release support from #89 by spawning --- bb8/src/api.rs | 19 ++++--- bb8/src/inner.rs | 24 ++++++++- bb8/src/internals.rs | 6 +-- bb8/tests/test.rs | 122 ++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 158 insertions(+), 13 deletions(-) diff --git a/bb8/src/api.rs b/bb8/src/api.rs index 9616d43..948874e 100644 --- a/bb8/src/api.rs +++ b/bb8/src/api.rs @@ -273,13 +273,18 @@ pub trait CustomizeConnection: /// Called with connections immediately after they are returned from /// `ManageConnection::connect`. /// - /// The default implementation simply returns `Ok(())`. - /// - /// # Errors + /// The default implementation simply returns `Ok(())`. Any errors will be forwarded to the + /// configured error sink. + async fn on_acquire(&self, _connection: &mut C) -> Result<(), E> { + Ok(()) + } + + /// Called with connections before they're returned to the connection pool. /// - /// If this method returns an error, the connection will be discarded. + /// The default implementation simply returns `Ok(())`. Any errors will be forwarded to the + /// configured error sink. #[allow(unused_variables)] - async fn on_acquire(&self, connection: &mut C) -> Result<(), E> { + async fn on_release(&'_ self, _connection: &'_ mut C) -> Result<(), E> { Ok(()) } } @@ -304,8 +309,8 @@ where } } - pub(crate) fn drop_invalid(mut self) { - let _ = self.conn.take(); + pub(crate) fn extract(mut self) -> Conn { + self.conn.take().unwrap() } } diff --git a/bb8/src/inner.rs b/bb8/src/inner.rs index 5435b7f..28eea65 100644 --- a/bb8/src/inner.rs +++ b/bb8/src/inner.rs @@ -102,7 +102,9 @@ where match self.inner.manager.is_valid(&mut conn).await { Ok(()) => return Ok(conn), Err(_) => { - conn.drop_invalid(); + self.on_release_connection(conn.extract()); + // Once we've extracted the connection, the `Drop` impl for `PooledConnection` + // will call `put_back(None)`, so we don't need to do anything else here. continue; } } @@ -133,6 +135,7 @@ where if !self.inner.manager.has_broken(&mut conn.conn) { Some(conn) } else { + self.on_release_connection(conn); None } }); @@ -147,6 +150,25 @@ where } } + fn on_release_connection(&self, mut conn: Conn) { + if self.inner.statics.connection_customizer.is_none() { + return; + } + + let pool = self.inner.clone(); + spawn(async move { + let customizer = match pool.statics.connection_customizer.as_ref() { + Some(customizer) => customizer, + None => return, + }; + + let future = customizer.on_release(&mut conn.conn); + if let Err(e) = future.await { + pool.statics.error_sink.sink(e); + } + }); + } + /// Returns information about the current state of the pool. pub(crate) fn state(&self) -> State { self.inner.internals.lock().state() diff --git a/bb8/src/internals.rs b/bb8/src/internals.rs index 5a8e122..7d4518d 100644 --- a/bb8/src/internals.rs +++ b/bb8/src/internals.rs @@ -168,12 +168,12 @@ where } pub(crate) struct InternalsGuard { - conn: Option>, - pool: Arc>, + pub(crate) conn: Option>, + pub(crate) pool: Arc>, } impl InternalsGuard { - fn new(conn: Conn, pool: Arc>) -> Self { + pub(crate) fn new(conn: Conn, pool: Arc>) -> Self { Self { conn: Some(conn), pool, diff --git a/bb8/tests/test.rs b/bb8/tests/test.rs index bbf1c95..d7d7767 100644 --- a/bb8/tests/test.rs +++ b/bb8/tests/test.rs @@ -5,7 +5,7 @@ use std::iter::FromIterator; use std::marker::PhantomData; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use std::task::Poll; use std::time::Duration; use std::{error, fmt, mem}; @@ -14,7 +14,7 @@ use async_trait::async_trait; use futures_channel::oneshot; use futures_util::future::{err, lazy, ok, pending, ready, try_join_all, FutureExt}; use futures_util::stream::{FuturesUnordered, TryStreamExt}; -use tokio::time::timeout; +use tokio::time::{sleep, timeout}; #[derive(Debug, PartialEq, Eq)] pub struct Error; @@ -786,3 +786,121 @@ async fn test_customize_connection_acquire() { let connection_1_or_2 = pool.get().await.unwrap(); assert!(connection_1_or_2.custom_field == 1 || connection_1_or_2.custom_field == 2); } + +#[tokio::test] +async fn test_customize_connection_release() { + #[derive(Debug)] + struct CountingCustomizer { + num_conn_released: Arc, + } + + impl CountingCustomizer { + fn new(num_conn_released: Arc) -> Self { + Self { num_conn_released } + } + } + + #[async_trait] + impl CustomizeConnection for CountingCustomizer { + async fn on_release(&self, _connection: &mut FakeConnection) -> Result<(), E> { + self.num_conn_released.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + } + + #[derive(Debug)] + struct BreakableManager { + _c: PhantomData, + valid: Arc, + broken: Arc, + }; + + impl BreakableManager { + fn new(valid: Arc, broken: Arc) -> Self { + Self { + valid, + broken, + _c: PhantomData, + } + } + } + + #[async_trait] + impl ManageConnection for BreakableManager + where + C: Default + Send + Sync + 'static, + { + type Connection = C; + type Error = Error; + + async fn connect(&self) -> Result { + Ok(Default::default()) + } + + async fn is_valid( + &self, + _conn: &mut PooledConnection<'_, Self>, + ) -> Result<(), Self::Error> { + if self.valid.load(Ordering::SeqCst) { + Ok(()) + } else { + Err(Error) + } + } + + fn has_broken(&self, _: &mut Self::Connection) -> bool { + self.broken.load(Ordering::SeqCst) + } + } + + let valid = Arc::new(AtomicBool::new(true)); + let broken = Arc::new(AtomicBool::new(false)); + let manager = BreakableManager::::new(valid.clone(), broken.clone()); + + let num_conn_released = Arc::new(AtomicUsize::new(0)); + let customizer = CountingCustomizer::new(num_conn_released.clone()); + + let pool = Pool::builder() + .max_size(2) + .connection_customizer(Box::new(customizer)) + .build(manager) + .await + .unwrap(); + + // Connections go in and out of the pool without being released + { + { + let _connection_1 = pool.get().await.unwrap(); + let _connection_2 = pool.get().await.unwrap(); + assert_eq!(num_conn_released.load(Ordering::SeqCst), 0); + } + { + let _connection_1 = pool.get().await.unwrap(); + let _connection_2 = pool.get().await.unwrap(); + assert_eq!(num_conn_released.load(Ordering::SeqCst), 0); + } + } + + // Invalid connections get released + { + valid.store(false, Ordering::SeqCst); + let _connection_1 = pool.get().await.unwrap(); + assert_eq!(num_conn_released.load(Ordering::SeqCst), 2); + let _connection_2 = pool.get().await.unwrap(); + assert_eq!(num_conn_released.load(Ordering::SeqCst), 2); + valid.store(true, Ordering::SeqCst); + } + + // Broken connections get released + { + num_conn_released.store(0, Ordering::SeqCst); + broken.store(true, Ordering::SeqCst); + { + let _connection_1 = pool.get().await.unwrap(); + let _connection_2 = pool.get().await.unwrap(); + assert_eq!(num_conn_released.load(Ordering::SeqCst), 0); + } + sleep(Duration::from_millis(100)).await; + assert_eq!(num_conn_released.load(Ordering::SeqCst), 2); + } +}