Skip to content

Commit

Permalink
Resurrect on_release support from #89 by spawning
Browse files Browse the repository at this point in the history
  • Loading branch information
djc committed Dec 25, 2020
1 parent eadfb32 commit a80460e
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 13 deletions.
19 changes: 12 additions & 7 deletions bb8/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,18 @@ pub trait CustomizeConnection<C: Send + 'static, E: 'static>:
/// Called with connections immediately after they are returned from
/// `ManageConnection::connect`.
///
/// The default implementation simply returns `Ok(())`.
///
/// # Errors
/// The default implementation simply returns `Ok(())`. Any errors will be forwarded to the
/// configured error sink.
async fn on_acquire(&self, _connection: &mut C) -> Result<(), E> {
Ok(())
}

/// Called with connections before they're returned to the connection pool.
///
/// If this method returns an error, the connection will be discarded.
/// The default implementation simply returns `Ok(())`. Any errors will be forwarded to the
/// configured error sink.
#[allow(unused_variables)]
async fn on_acquire(&self, connection: &mut C) -> Result<(), E> {
async fn on_release(&'_ self, _connection: &'_ mut C) -> Result<(), E> {
Ok(())
}
}
Expand All @@ -304,8 +309,8 @@ where
}
}

pub(crate) fn drop_invalid(mut self) {
let _ = self.conn.take();
pub(crate) fn extract(mut self) -> Conn<M::Connection> {
self.conn.take().unwrap()
}
}

Expand Down
24 changes: 23 additions & 1 deletion bb8/src/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ where
match self.inner.manager.is_valid(&mut conn).await {
Ok(()) => return Ok(conn),
Err(_) => {
conn.drop_invalid();
self.on_release_connection(conn.extract());
// Once we've extracted the connection, the `Drop` impl for `PooledConnection`
// will call `put_back(None)`, so we don't need to do anything else here.
continue;
}
}
Expand Down Expand Up @@ -133,6 +135,7 @@ where
if !self.inner.manager.has_broken(&mut conn.conn) {
Some(conn)
} else {
self.on_release_connection(conn);
None
}
});
Expand All @@ -147,6 +150,25 @@ where
}
}

fn on_release_connection(&self, mut conn: Conn<M::Connection>) {
if self.inner.statics.connection_customizer.is_none() {
return;
}

let pool = self.inner.clone();
spawn(async move {
let customizer = match pool.statics.connection_customizer.as_ref() {
Some(customizer) => customizer,
None => return,
};

let future = customizer.on_release(&mut conn.conn);
if let Err(e) = future.await {
pool.statics.error_sink.sink(e);
}
});
}

/// Returns information about the current state of the pool.
pub(crate) fn state(&self) -> State {
self.inner.internals.lock().state()
Expand Down
6 changes: 3 additions & 3 deletions bb8/src/internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ where
}

pub(crate) struct InternalsGuard<M: ManageConnection> {
conn: Option<Conn<M::Connection>>,
pool: Arc<SharedPool<M>>,
pub(crate) conn: Option<Conn<M::Connection>>,
pub(crate) pool: Arc<SharedPool<M>>,
}

impl<M: ManageConnection> InternalsGuard<M> {
fn new(conn: Conn<M::Connection>, pool: Arc<SharedPool<M>>) -> Self {
pub(crate) fn new(conn: Conn<M::Connection>, pool: Arc<SharedPool<M>>) -> Self {
Self {
conn: Some(conn),
pool,
Expand Down
122 changes: 120 additions & 2 deletions bb8/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::iter::FromIterator;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Mutex;
use std::sync::{Arc, Mutex};
use std::task::Poll;
use std::time::Duration;
use std::{error, fmt, mem};
Expand All @@ -14,7 +14,7 @@ use async_trait::async_trait;
use futures_channel::oneshot;
use futures_util::future::{err, lazy, ok, pending, ready, try_join_all, FutureExt};
use futures_util::stream::{FuturesUnordered, TryStreamExt};
use tokio::time::timeout;
use tokio::time::{sleep, timeout};

#[derive(Debug, PartialEq, Eq)]
pub struct Error;
Expand Down Expand Up @@ -786,3 +786,121 @@ async fn test_customize_connection_acquire() {
let connection_1_or_2 = pool.get().await.unwrap();
assert!(connection_1_or_2.custom_field == 1 || connection_1_or_2.custom_field == 2);
}

#[tokio::test]
async fn test_customize_connection_release() {
#[derive(Debug)]
struct CountingCustomizer {
num_conn_released: Arc<AtomicUsize>,
}

impl CountingCustomizer {
fn new(num_conn_released: Arc<AtomicUsize>) -> Self {
Self { num_conn_released }
}
}

#[async_trait]
impl<E: 'static> CustomizeConnection<FakeConnection, E> for CountingCustomizer {
async fn on_release(&self, _connection: &mut FakeConnection) -> Result<(), E> {
self.num_conn_released.fetch_add(1, Ordering::SeqCst);
Ok(())
}
}

#[derive(Debug)]
struct BreakableManager<C> {
_c: PhantomData<C>,
valid: Arc<AtomicBool>,
broken: Arc<AtomicBool>,
};

impl<C> BreakableManager<C> {
fn new(valid: Arc<AtomicBool>, broken: Arc<AtomicBool>) -> Self {
Self {
valid,
broken,
_c: PhantomData,
}
}
}

#[async_trait]
impl<C> ManageConnection for BreakableManager<C>
where
C: Default + Send + Sync + 'static,
{
type Connection = C;
type Error = Error;

async fn connect(&self) -> Result<Self::Connection, Self::Error> {
Ok(Default::default())
}

async fn is_valid(
&self,
_conn: &mut PooledConnection<'_, Self>,
) -> Result<(), Self::Error> {
if self.valid.load(Ordering::SeqCst) {
Ok(())
} else {
Err(Error)
}
}

fn has_broken(&self, _: &mut Self::Connection) -> bool {
self.broken.load(Ordering::SeqCst)
}
}

let valid = Arc::new(AtomicBool::new(true));
let broken = Arc::new(AtomicBool::new(false));
let manager = BreakableManager::<FakeConnection>::new(valid.clone(), broken.clone());

let num_conn_released = Arc::new(AtomicUsize::new(0));
let customizer = CountingCustomizer::new(num_conn_released.clone());

let pool = Pool::builder()
.max_size(2)
.connection_customizer(Box::new(customizer))
.build(manager)
.await
.unwrap();

// Connections go in and out of the pool without being released
{
{
let _connection_1 = pool.get().await.unwrap();
let _connection_2 = pool.get().await.unwrap();
assert_eq!(num_conn_released.load(Ordering::SeqCst), 0);
}
{
let _connection_1 = pool.get().await.unwrap();
let _connection_2 = pool.get().await.unwrap();
assert_eq!(num_conn_released.load(Ordering::SeqCst), 0);
}
}

// Invalid connections get released
{
valid.store(false, Ordering::SeqCst);
let _connection_1 = pool.get().await.unwrap();
assert_eq!(num_conn_released.load(Ordering::SeqCst), 2);
let _connection_2 = pool.get().await.unwrap();
assert_eq!(num_conn_released.load(Ordering::SeqCst), 2);
valid.store(true, Ordering::SeqCst);
}

// Broken connections get released
{
num_conn_released.store(0, Ordering::SeqCst);
broken.store(true, Ordering::SeqCst);
{
let _connection_1 = pool.get().await.unwrap();
let _connection_2 = pool.get().await.unwrap();
assert_eq!(num_conn_released.load(Ordering::SeqCst), 0);
}
sleep(Duration::from_millis(100)).await;
assert_eq!(num_conn_released.load(Ordering::SeqCst), 2);
}
}

0 comments on commit a80460e

Please sign in to comment.