Skip to content

Commit

Permalink
Merge pull request #271 from espindola/fix-future-bug
Browse files Browse the repository at this point in the history
Use an explicit priority check
  • Loading branch information
blackbeam authored Dec 17, 2023
2 parents 514d6db + be693e0 commit 6dc09a7
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 29 deletions.
4 changes: 1 addition & 3 deletions src/conn/pool/futures/get_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,8 @@ impl Future for GetConn {
loop {
match self.inner {
GetConnInner::New => {
let queued = self.queue_id.is_some();
let queue_id = *self.queue_id.get_or_insert_with(QueueId::next);
let next =
ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queued, queue_id))?;
let next = ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queue_id))?;
match next {
GetConnInner::Connecting(conn_fut) => {
self.inner = GetConnInner::Connecting(conn_fut);
Expand Down
115 changes: 89 additions & 26 deletions src/conn/pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ impl Waitlist {
self.queue.remove(&tmp);
}

fn is_empty(&self) -> bool {
self.queue.is_empty()
fn peek_id(&mut self) -> Option<QueueId> {
self.queue.peek().map(|(qw, _)| qw.queue_id)
}
}

Expand Down Expand Up @@ -303,16 +303,14 @@ impl Pool {
fn poll_new_conn(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
queued: bool,
queue_id: QueueId,
) -> Poll<Result<GetConnInner>> {
self.poll_new_conn_inner(cx, queued, queue_id)
self.poll_new_conn_inner(cx, queue_id)
}

fn poll_new_conn_inner(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
queued: bool,
queue_id: QueueId,
) -> Poll<Result<GetConnInner>> {
let mut exchange = self.inner.exchange.lock().unwrap();
Expand All @@ -326,8 +324,15 @@ impl Pool {

exchange.spawn_futures_if_needed(&self.inner);

// Check if others are waiting and we're not queued.
if !exchange.waiting.is_empty() && !queued {
// Check if we are higher priority than anything current
let highest = if let Some(cur) = exchange.waiting.peek_id() {
queue_id > cur
} else {
true
};

// If we are not, just queue
if !highest {
exchange.waiting.push(cx.waker().clone(), queue_id);
return Poll::Pending;
}
Expand Down Expand Up @@ -392,14 +397,14 @@ impl Drop for Conn {
#[cfg(test)]
mod test {
use futures_util::{
future::{join_all, select, select_all, try_join_all},
try_join, FutureExt,
future::{join_all, select, select_all, try_join_all, Either},
poll, try_join, FutureExt,
};
use tokio::time::{sleep, timeout};

use std::{
cmp::Reverse,
task::{RawWaker, RawWakerVTable, Waker},
task::{Poll, RawWaker, RawWakerVTable, Waker},
time::Duration,
};

Expand All @@ -423,6 +428,12 @@ mod test {
};
}

fn pool_with_one_connection() -> Pool {
let pool_opts = PoolOpts::new().with_constraints(PoolConstraints::new(1, 1).unwrap());
let opts = get_opts().pool_opts(pool_opts.clone());
Pool::new(opts)
}

#[tokio::test]
async fn should_opt_out_of_connection_reset() -> super::Result<()> {
let pool_opts = PoolOpts::new().with_constraints(PoolConstraints::new(1, 1).unwrap());
Expand Down Expand Up @@ -571,10 +582,7 @@ mod test {

#[tokio::test]
async fn should_reuse_connections() -> super::Result<()> {
let constraints = PoolConstraints::new(1, 1).unwrap();
let opts = get_opts().pool_opts(PoolOpts::default().with_constraints(constraints));

let pool = Pool::new(opts);
let pool = pool_with_one_connection();
let mut conn = pool.get_conn().await?;

let server_version = conn.server_version();
Expand Down Expand Up @@ -613,10 +621,7 @@ mod test {

#[tokio::test]
async fn should_start_transaction() -> super::Result<()> {
let constraints = PoolConstraints::new(1, 1).unwrap();
let opts = get_opts().pool_opts(PoolOpts::default().with_constraints(constraints));

let pool = Pool::new(opts);
let pool = pool_with_one_connection();

"CREATE TABLE IF NOT EXISTS mysql.tmp(id int)"
.ignore(&pool)
Expand Down Expand Up @@ -909,10 +914,7 @@ mod test {

#[tokio::test]
async fn should_ignore_non_fatal_errors_while_returning_to_a_pool() -> super::Result<()> {
let pool_constraints = PoolConstraints::new(1, 1).unwrap();
let pool_opts = PoolOpts::default().with_constraints(pool_constraints);

let pool = Pool::new(get_opts().pool_opts(pool_opts));
let pool = pool_with_one_connection();
let id = pool.get_conn().await?.id();

// non-fatal errors are ignored
Expand All @@ -927,10 +929,7 @@ mod test {

#[tokio::test]
async fn should_remove_waker_of_cancelled_task() {
let pool_constraints = PoolConstraints::new(1, 1).unwrap();
let pool_opts = PoolOpts::default().with_constraints(pool_constraints);

let pool = Pool::new(get_opts().pool_opts(pool_opts));
let pool = pool_with_one_connection();
let only_conn = pool.get_conn().await.unwrap();

let join_handle = tokio::spawn(timeout(Duration::from_secs(1), pool.get_conn()));
Expand Down Expand Up @@ -1059,6 +1058,70 @@ mod test {
Ok(())
}

#[tokio::test]
async fn check_priorities() -> super::Result<()> {
let pool = pool_with_one_connection();

let queue_len = || {
let exchange = pool.inner.exchange.lock().unwrap();
exchange.waiting.queue.len()
};

// Get a connection, so we know the next futures will be
// queued.
let conn = pool.get_conn().await.unwrap();

let get_pending = || async {
let fut = async {
pool.get_conn().await.unwrap();
}
.shared();
let p = poll!(fut.clone());
assert!(matches!(p, Poll::Pending));
fut
};

let fut1 = get_pending().await;
let fut2 = get_pending().await;

// Both futures are queued
assert_eq!(queue_len(), 2);

drop(conn); // This will pop fut1 from the queue, making it [2]
while queue_len() != 1 {
tokio::time::sleep(Duration::from_millis(100)).await;
}

// We called wake on fut1, and even with the select fut1 will
// resolve first
let Either::Right((_, fut2)) = select(fut2, fut1).await else {
panic!("wrong future");
};

// We dropped the connection of fut1, but very likely hasn't
// made it through the recycler yet.
assert_eq!(queue_len(), 1);

let p = poll!(fut2.clone());
assert!(matches!(p, Poll::Pending));
assert_eq!(queue_len(), 1); // The queue still has fut2

// The connection will pass by the recycler and unblock fut2
// and pop it from the queue.
fut2.await;
assert_eq!(queue_len(), 0);

// The recycler is probably not done, so a new future will be
// pending.
let fut3 = get_pending().await;
assert_eq!(queue_len(), 1);

// It is OK to await it.
fut3.await;

Ok(())
}

#[cfg(feature = "nightly")]
mod bench {
use futures_util::future::{FutureExt, TryFutureExt};
Expand Down

0 comments on commit 6dc09a7

Please sign in to comment.