Skip to content

Commit

Permalink
Remove waker from wait list when GetConn is dropped
Browse files Browse the repository at this point in the history
* Replace the BinaryHeap with PriorityQueue, which supports
  removals.
* Refactor the wait list into a separate type to isolate the
  operations.
  • Loading branch information
cloneable committed Dec 2, 2022
1 parent a51930f commit dc78d2d
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 34 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ once_cell = "1.7.2"
pem = "1.0.1"
percent-encoding = "2.1.0"
pin-project = "1.0.2"
priority-queue = "1"
serde = "1"
serde_json = "1"
socket2 = "0.4.2"
Expand Down
6 changes: 2 additions & 4 deletions src/conn/pool/futures/disconnect_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use futures_core::ready;
use tokio::sync::mpsc::UnboundedSender;

use crate::{
conn::pool::{Inner, Pool, QueuedWaker, QUEUE_END_ID},
conn::pool::{Inner, Pool, QUEUE_END_ID},
error::Error,
Conn,
};
Expand Down Expand Up @@ -50,9 +50,7 @@ impl Future for DisconnectPool {
self.pool_inner.close.store(true, atomic::Ordering::Release);
let mut exchange = self.pool_inner.exchange.lock().unwrap();
exchange.spawn_futures_if_needed(&self.pool_inner);
exchange
.waiting
.push(QueuedWaker::new(QUEUE_END_ID, cx.waker().clone()));
exchange.waiting.push(cx.waker().clone(), QUEUE_END_ID);
drop(exchange);

if self.pool_inner.closed.load(atomic::Ordering::Acquire) {
Expand Down
5 changes: 5 additions & 0 deletions src/conn/pool/futures/get_conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ impl Drop for GetConn {
// We drop a connection before it can be resolved, a.k.a. cancelling it.
// Make sure we maintain the necessary invariants towards the pool.
if let Some(pool) = self.pool.take() {
// Remove the waker from the pool's waitlist in case this task was
// woken by another waker, like from tokio::time::timeout.
if let Some(queue_id) = self.queue_id {
pool.unqueue(queue_id);
}
if let GetConnInner::Connecting(..) = self.inner.take() {
pool.cancel_connection();
}
Expand Down
150 changes: 126 additions & 24 deletions src/conn/pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
// modified, or distributed except according to those terms.

use futures_util::FutureExt;
use priority_queue::PriorityQueue;
use tokio::sync::mpsc;

use std::{
cmp::{Ordering, Reverse},
collections::{BinaryHeap, VecDeque},
collections::VecDeque,
convert::TryFrom,
hash::{Hash, Hasher},
pin::Pin,
str::FromStr,
sync::{atomic, Arc, Mutex},
Expand Down Expand Up @@ -63,7 +65,7 @@ impl From<Conn> for IdlingConn {
/// This is fine as long as we never do expensive work while holding the lock!
#[derive(Debug)]
struct Exchange {
waiting: BinaryHeap<QueuedWaker>,
waiting: Waitlist,
available: VecDeque<IdlingConn>,
exist: usize,
// only used to spawn the recycler the first time we're in async context
Expand All @@ -88,9 +90,45 @@ impl Exchange {
}
}

#[derive(Default, Debug)]
struct Waitlist {
queue: PriorityQueue<QueuedWaker, QueueId>,
}

impl Waitlist {
fn push(&mut self, w: Waker, queue_id: QueueId) {
self.queue.push(
QueuedWaker {
queue_id,
waker: Some(w),
},
queue_id,
);
}

fn pop(&mut self) -> Option<Waker> {
match self.queue.pop() {
Some((qw, _)) => Some(qw.waker.unwrap()),
None => None,
}
}

fn remove(&mut self, id: QueueId) {
let tmp = QueuedWaker {
queue_id: id,
waker: None,
};
self.queue.remove(&tmp);
}

fn is_empty(&self) -> bool {
self.queue.is_empty()
}
}

const QUEUE_END_ID: QueueId = QueueId(Reverse(u64::MAX));

#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub(crate) struct QueueId(Reverse<u64>);

impl QueueId {
Expand All @@ -104,13 +142,7 @@ impl QueueId {
#[derive(Debug)]
struct QueuedWaker {
queue_id: QueueId,
waker: Waker,
}

impl QueuedWaker {
fn new(queue_id: QueueId, waker: Waker) -> Self {
QueuedWaker { queue_id, waker }
}
waker: Option<Waker>,
}

impl Eq for QueuedWaker {}
Expand All @@ -133,6 +165,12 @@ impl PartialOrd for QueuedWaker {
}
}

impl Hash for QueuedWaker {
fn hash<H: Hasher>(&self, state: &mut H) {
self.queue_id.hash(state)
}
}

/// Connection pool data.
#[derive(Debug)]
pub struct Inner {
Expand Down Expand Up @@ -177,7 +215,7 @@ impl Pool {
closed: false.into(),
exchange: Mutex::new(Exchange {
available: VecDeque::with_capacity(pool_opts.constraints().max()),
waiting: BinaryHeap::new(),
waiting: Waitlist::default(),
exist: 0,
recycler: Some((rx, pool_opts)),
}),
Expand Down Expand Up @@ -227,8 +265,8 @@ impl Pool {
let mut exchange = self.inner.exchange.lock().unwrap();
if exchange.available.len() < self.opts.pool_opts().active_bound() {
exchange.available.push_back(conn.into());
if let Some(qw) = exchange.waiting.pop() {
qw.waker.wake();
if let Some(w) = exchange.waiting.pop() {
w.wake();
}
return;
}
Expand Down Expand Up @@ -262,8 +300,8 @@ impl Pool {
let mut exchange = self.inner.exchange.lock().unwrap();
exchange.exist -= 1;
// we just enabled the creation of a new connection!
if let Some(qw) = exchange.waiting.pop() {
qw.waker.wake();
if let Some(w) = exchange.waiting.pop() {
w.wake();
}
}

Expand Down Expand Up @@ -296,9 +334,7 @@ impl Pool {

// Check if others are waiting and we're not queued.
if !exchange.waiting.is_empty() && !queued {
exchange
.waiting
.push(QueuedWaker::new(queue_id, cx.waker().clone()));
exchange.waiting.push(cx.waker().clone(), queue_id);
return Poll::Pending;
}

Expand Down Expand Up @@ -328,11 +364,14 @@ impl Pool {
}

// Polled, but no conn available? Back into the queue.
exchange
.waiting
.push(QueuedWaker::new(queue_id, cx.waker().clone()));
exchange.waiting.push(cx.waker().clone(), queue_id);
Poll::Pending
}

fn unqueue(&self, queue_id: QueueId) {
let mut exchange = self.inner.exchange.lock().unwrap();
exchange.waiting.remove(queue_id);
}
}

impl Drop for Conn {
Expand Down Expand Up @@ -363,12 +402,20 @@ mod test {
try_join, FutureExt,
};
use mysql_common::row::Row;
use tokio::time::sleep;
use tokio::time::{sleep, timeout};

use std::time::Duration;
use std::{
cmp::Reverse,
task::{RawWaker, RawWakerVTable, Waker},
time::Duration,
};

use crate::{
conn::pool::Pool, opts::PoolOpts, prelude::*, test_misc::get_opts, PoolConstraints, TxOpts,
conn::pool::{Pool, QueueId, Waitlist, QUEUE_END_ID},
opts::PoolOpts,
prelude::*,
test_misc::get_opts,
PoolConstraints, TxOpts,
};

macro_rules! conn_ex_field {
Expand Down Expand Up @@ -824,6 +871,27 @@ mod test {
Ok(())
}

#[tokio::test]
async fn should_remove_waker_of_cancelled_task() {
let pool_constraints = PoolConstraints::new(1, 1).unwrap();
let pool_opts = PoolOpts::default().with_constraints(pool_constraints);

let pool = Pool::new(get_opts().pool_opts(pool_opts));
let only_conn = pool.get_conn().await.unwrap();

let join_handle = tokio::spawn(timeout(Duration::from_secs(1), pool.get_conn()));

sleep(Duration::from_secs(2)).await;

match join_handle.await.unwrap() {
Err(_elapsed) => (),
_ => panic!("unexpected Ok()"),
}
drop(only_conn);

assert_eq!(0, pool.inner.exchange.lock().unwrap().waiting.queue.len());
}

#[tokio::test]
async fn should_work_if_pooled_connection_operation_is_cancelled() -> super::Result<()> {
let pool = Pool::new(get_opts());
Expand Down Expand Up @@ -868,6 +936,40 @@ mod test {
Ok(())
}

#[test]
fn waitlist_integrity() {
const DATA: *const () = &();
const NOOP_CLONE_FN: unsafe fn(*const ()) -> RawWaker = |_| RawWaker::new(DATA, &RW_VTABLE);
const NOOP_FN: unsafe fn(*const ()) = |_| {};
static RW_VTABLE: RawWakerVTable =
RawWakerVTable::new(NOOP_CLONE_FN, NOOP_FN, NOOP_FN, NOOP_FN);
let w = unsafe { Waker::from_raw(RawWaker::new(DATA, &RW_VTABLE)) };

let mut waitlist = Waitlist::default();
assert_eq!(0, waitlist.queue.len());

waitlist.push(w.clone(), QueueId(Reverse(4)));
waitlist.push(w.clone(), QueueId(Reverse(2)));
waitlist.push(w.clone(), QueueId(Reverse(8)));
waitlist.push(w.clone(), QUEUE_END_ID);
waitlist.push(w.clone(), QueueId(Reverse(10)));

waitlist.remove(QueueId(Reverse(8)));

assert_eq!(4, waitlist.queue.len());

let (_, id) = waitlist.queue.pop().unwrap();
assert_eq!(2, id.0 .0);
let (_, id) = waitlist.queue.pop().unwrap();
assert_eq!(4, id.0 .0);
let (_, id) = waitlist.queue.pop().unwrap();
assert_eq!(10, id.0 .0);
let (_, id) = waitlist.queue.pop().unwrap();
assert_eq!(QUEUE_END_ID, id);

assert_eq!(0, waitlist.queue.len());
}

#[cfg(feature = "nightly")]
mod bench {
use futures_util::future::{FutureExt, TryFutureExt};
Expand Down
12 changes: 6 additions & 6 deletions src/conn/pool/recycler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ impl Future for Recycler {
$self.discard.push($conn.close_conn().boxed());
} else {
exchange.available.push_back($conn.into());
if let Some(qw) = exchange.waiting.pop() {
qw.waker.wake();
if let Some(w) = exchange.waiting.pop() {
w.wake();
}
}
}
Expand Down Expand Up @@ -163,8 +163,8 @@ impl Future for Recycler {
let mut exchange = self.inner.exchange.lock().unwrap();
exchange.exist -= self.discarded;
for _ in 0..self.discarded {
if let Some(qw) = exchange.waiting.pop() {
qw.waker.wake();
if let Some(w) = exchange.waiting.pop() {
w.wake();
}
}
drop(exchange);
Expand Down Expand Up @@ -197,8 +197,8 @@ impl Future for Recycler {
if self.inner.closed.load(Ordering::Acquire) {
// `DisconnectPool` might still wait to be woken up.
let mut exchange = self.inner.exchange.lock().unwrap();
while let Some(qw) = exchange.waiting.pop() {
qw.waker.wake();
while let Some(w) = exchange.waiting.pop() {
w.wake();
}
// we're about to exit, so there better be no outstanding connections
assert_eq!(exchange.exist, 0);
Expand Down

0 comments on commit dc78d2d

Please sign in to comment.