diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index 6c99af089bf..d8fc0dcf2b9 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -1,6 +1,6 @@ use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{SendError, TrySendError}; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -187,6 +187,46 @@ impl Receiver { poll_fn(|cx| self.chan.recv(cx)).await } + /// Try to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").await.unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result { + self.chan.try_recv() + } + /// Blocking receive to call outside of asynchronous contexts. /// /// This method returns `None` if the channel has been closed and there are diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 554d0228478..637ae1fd16f 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -2,6 +2,9 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; +use crate::park::thread::CachedParkThread; +use crate::park::Park; +use crate::sync::mpsc::error::TryRecvError; use crate::sync::mpsc::list; use crate::sync::notify::Notify; @@ -263,6 +266,51 @@ impl Rx { } }) } + + /// Try to receive the next value. + pub(crate) fn try_recv(&mut self) -> Result { + use super::list::TryPopResult; + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.try_pop(&self.inner.tx) { + TryPopResult::Ok(value) => { + self.inner.semaphore.add_permit(); + return Ok(value); + } + TryPopResult::Closed => return Err(TryRecvError::Disconnected), + TryPopResult::Empty => return Err(TryRecvError::Empty), + TryPopResult::Busy => {} // fall through + } + }; + } + + try_recv!(); + + // If a previous `poll_recv` call has set a waker, we wake it here. + // This allows us to put our own CachedParkThread waker in the + // AtomicWaker slot instead. + // + // This is not a spurious wakeup to `poll_recv` since we just got a + // Busy from `try_pop`, which only happens if there are messages in + // the queue. + self.inner.rx_waker.wake(); + + // Park the thread until the problematic send has completed. + let mut park = CachedParkThread::new(); + let waker = park.unpark().into_waker(); + loop { + self.inner.rx_waker.register_by_ref(&waker); + // It is possible that the problematic send has now completed, + // so we have to check for messages again. + try_recv!(); + park.park().expect("park failed"); + } + }) + } } impl Drop for Rx { diff --git a/tokio/src/sync/mpsc/error.rs b/tokio/src/sync/mpsc/error.rs index 0d25ad3869b..48ca379d9ca 100644 --- a/tokio/src/sync/mpsc/error.rs +++ b/tokio/src/sync/mpsc/error.rs @@ -51,6 +51,30 @@ impl From> for TrySendError { } } +// ===== TryRecvError ===== + +/// Error returned by `try_recv`. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum TryRecvError { + /// This **channel** is currently empty, but the **Sender**(s) have not yet + /// disconnected, so data may yet become available. + Empty, + /// The **channel**'s sending half has become disconnected, and there will + /// never be any more data received on it. + Disconnected, +} + +impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TryRecvError::Empty => "receiving on an empty channel".fmt(fmt), + TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt), + } + } +} + +impl Error for TryRecvError {} + // ===== RecvError ===== /// Error returned by `Receiver`. diff --git a/tokio/src/sync/mpsc/list.rs b/tokio/src/sync/mpsc/list.rs index 5dad2babfaf..53c34d2fa6f 100644 --- a/tokio/src/sync/mpsc/list.rs +++ b/tokio/src/sync/mpsc/list.rs @@ -13,23 +13,35 @@ pub(crate) struct Tx { /// Tail in the `Block` mpmc list. block_tail: AtomicPtr>, - /// Position to push the next message. This reference a block and offset + /// Position to push the next message. This references a block and offset /// into the block. tail_position: AtomicUsize, } /// List queue receive handle pub(crate) struct Rx { - /// Pointer to the block being processed + /// Pointer to the block being processed. head: NonNull>, - /// Next slot index to process + /// Next slot index to process. index: usize, - /// Pointer to the next block pending release + /// Pointer to the next block pending release. free_head: NonNull>, } +/// Return value of `Rx::try_pop`. +pub(crate) enum TryPopResult { + /// Successfully popped a value. + Ok(T), + /// The channel is empty. + Empty, + /// The channel is empty and closed. + Closed, + /// The channel is not empty, but the first value is being written. + Busy, +} + pub(crate) fn channel() -> (Tx, Rx) { // Create the initial block shared between the tx and rx halves. let initial_block = Box::new(Block::new(0)); @@ -218,7 +230,7 @@ impl fmt::Debug for Tx { } impl Rx { - /// Pops the next value off the queue + /// Pops the next value off the queue. pub(crate) fn pop(&mut self, tx: &Tx) -> Option> { // Advance `head`, if needed if !self.try_advancing_head() { @@ -240,6 +252,26 @@ impl Rx { } } + /// Pops the next value off the queue, detecting whether the block + /// is busy or empty on failure. + /// + /// This function exists because `Rx::pop` can return `None` even if the + /// channel's queue contains a message that has been completely written. + /// This can happen if the fully delivered message is behind another message + /// that is in the middle of being written to the block, since the channel + /// can't return the messages out of order. + pub(crate) fn try_pop(&mut self, tx: &Tx) -> TryPopResult { + let tail_position = tx.tail_position.load(Acquire); + let result = self.pop(tx); + + match result { + Some(block::Read::Value(t)) => TryPopResult::Ok(t), + Some(block::Read::Closed) => TryPopResult::Closed, + None if tail_position == self.index => TryPopResult::Empty, + None => TryPopResult::Busy, + } + } + /// Tries advancing the block pointer to the block referenced by `self.index`. /// /// Returns `true` if successful, `false` if there is no next block to load. diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index 23c80f60a35..82ba615f55f 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -1,6 +1,6 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::SendError; +use crate::sync::mpsc::error::{SendError, TryRecvError}; use std::fmt; use std::task::{Context, Poll}; @@ -129,6 +129,46 @@ impl UnboundedReceiver { poll_fn(|cx| self.poll_recv(cx)).await } + /// Try to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result { + self.chan.try_recv() + } + /// Blocking receive to call outside of asynchronous contexts. /// /// # Panics diff --git a/tokio/src/sync/tests/loom_mpsc.rs b/tokio/src/sync/tests/loom_mpsc.rs index c12313bd365..f165e7076e7 100644 --- a/tokio/src/sync/tests/loom_mpsc.rs +++ b/tokio/src/sync/tests/loom_mpsc.rs @@ -132,3 +132,59 @@ fn dropping_unbounded_tx() { assert!(v.is_none()); }); } + +#[test] +fn try_recv() { + loom::model(|| { + use crate::sync::{mpsc, Semaphore}; + use loom::sync::{Arc, Mutex}; + + const PERMITS: usize = 2; + const TASKS: usize = 2; + const CYCLES: usize = 1; + + struct Context { + sem: Arc, + tx: mpsc::Sender<()>, + rx: Mutex>, + } + + fn run(ctx: &Context) { + block_on(async { + let permit = ctx.sem.acquire().await; + assert_ok!(ctx.rx.lock().unwrap().try_recv()); + crate::task::yield_now().await; + assert_ok!(ctx.tx.clone().try_send(())); + drop(permit); + }); + } + + let (tx, rx) = mpsc::channel(PERMITS); + let sem = Arc::new(Semaphore::new(PERMITS)); + let ctx = Arc::new(Context { + sem, + tx, + rx: Mutex::new(rx), + }); + + for _ in 0..PERMITS { + assert_ok!(ctx.tx.clone().try_send(())); + } + + let mut ths = Vec::new(); + + for _ in 0..TASKS { + let ctx = ctx.clone(); + + ths.push(thread::spawn(move || { + run(&ctx); + })); + } + + run(&ctx); + + for th in ths { + th.join().unwrap(); + } + }); +} diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index cd43ad4381c..2ac743be261 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -5,7 +5,7 @@ use std::thread; use tokio::runtime::Runtime; use tokio::sync::mpsc; -use tokio::sync::mpsc::error::TrySendError; +use tokio::sync::mpsc::error::{TryRecvError, TrySendError}; use tokio_test::task; use tokio_test::{ assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, @@ -327,6 +327,27 @@ async fn try_send_fail() { assert!(rx.recv().await.is_none()); } +#[tokio::test] +async fn try_send_fail_with_try_recv() { + let (tx, mut rx) = mpsc::channel(1); + + tx.try_send("hello").unwrap(); + + // This should fail + match assert_err!(tx.try_send("fail")) { + TrySendError::Full(..) => {} + _ => panic!(), + } + + assert_eq!(rx.try_recv(), Ok("hello")); + + assert_ok!(tx.try_send("goodbye")); + drop(tx); + + assert_eq!(rx.try_recv(), Ok("goodbye")); + assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected)); +} + #[tokio::test] async fn try_reserve_fails() { let (tx, mut rx) = mpsc::channel(1); @@ -494,3 +515,83 @@ async fn permit_available_not_acquired_close() { drop(permit2); assert!(rx.recv().await.is_none()); } + +#[test] +fn try_recv_bounded() { + let (tx, mut rx) = mpsc::channel(5); + + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + assert!(tx.try_send("hello").is_err()); + + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + assert_eq!(Ok("hello"), rx.try_recv()); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + assert!(tx.try_send("hello").is_err()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + drop(tx); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); +} + +#[test] +fn try_recv_unbounded() { + for num in 0..100 { + let (tx, mut rx) = mpsc::unbounded_channel(); + + for i in 0..num { + tx.send(i).unwrap(); + } + + for i in 0..num { + assert_eq!(rx.try_recv(), Ok(i)); + } + + assert_eq!(rx.try_recv(), Err(TryRecvError::Empty)); + drop(tx); + assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected)); + } +} + +#[test] +fn try_recv_close_while_empty_bounded() { + let (tx, mut rx) = mpsc::channel::<()>(5); + + assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + drop(tx); + assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); +} + +#[test] +fn try_recv_close_while_empty_unbounded() { + let (tx, mut rx) = mpsc::unbounded_channel::<()>(); + + assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + drop(tx); + assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); +}