Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: implement try_recv for mpsc channels #4113

Merged
merged 5 commits into from
Sep 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError};
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::{SendError, TrySendError};
use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError};

cfg_time! {
use crate::sync::mpsc::error::SendTimeoutError;
Expand Down Expand Up @@ -187,6 +187,46 @@ impl<T> Receiver<T> {
poll_fn(|cx| self.chan.recv(cx)).await
}

/// Try to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
///
/// This method returns the [`Disconnected`] error if the channel is
/// currently empty, and there are no outstanding [senders] or [permits].
///
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
/// [senders]: crate::sync::mpsc::Sender
/// [permits]: crate::sync::mpsc::Permit
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use tokio::sync::mpsc::error::TryRecvError;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(100);
///
/// tx.send("hello").await.unwrap();
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
///
/// tx.send("hello").await.unwrap();
/// // Drop the last sender, closing the channel.
/// drop(tx);
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.chan.try_recv()
}

/// Blocking receive to call outside of asynchronous contexts.
///
/// This method returns `None` if the channel has been closed and there are
Expand Down
48 changes: 48 additions & 0 deletions tokio/src/sync/mpsc/chan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::future::AtomicWaker;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Arc;
use crate::park::thread::CachedParkThread;
use crate::park::Park;
use crate::sync::mpsc::error::TryRecvError;
use crate::sync::mpsc::list;
use crate::sync::notify::Notify;

Expand Down Expand Up @@ -263,6 +266,51 @@ impl<T, S: Semaphore> Rx<T, S> {
}
})
}

/// Try to receive the next value.
pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
use super::list::TryPopResult;

self.inner.rx_fields.with_mut(|rx_fields_ptr| {
let rx_fields = unsafe { &mut *rx_fields_ptr };

macro_rules! try_recv {
() => {
match rx_fields.list.try_pop(&self.inner.tx) {
TryPopResult::Ok(value) => {
self.inner.semaphore.add_permit();
return Ok(value);
}
TryPopResult::Closed => return Err(TryRecvError::Disconnected),
TryPopResult::Empty => return Err(TryRecvError::Empty),
TryPopResult::Busy => {} // fall through
}
};
}

try_recv!();

// If a previous `poll_recv` call has set a waker, we wake it here.
// This allows us to put our own CachedParkThread waker in the
// AtomicWaker slot instead.
//
// This is not a spurious wakeup to `poll_recv` since we just got a
// Busy from `try_pop`, which only happens if there are messages in
// the queue.
self.inner.rx_waker.wake();

// Park the thread until the problematic send has completed.
let mut park = CachedParkThread::new();
let waker = park.unpark().into_waker();
loop {
self.inner.rx_waker.register_by_ref(&waker);
// It is possible that the problematic send has now completed,
// so we have to check for messages again.
try_recv!();
park.park().expect("park failed");
}
})
}
}

impl<T, S: Semaphore> Drop for Rx<T, S> {
Expand Down
24 changes: 24 additions & 0 deletions tokio/src/sync/mpsc/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,30 @@ impl<T> From<SendError<T>> for TrySendError<T> {
}
}

// ===== TryRecvError =====

/// Error returned by `try_recv`.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryRecvError {
/// This **channel** is currently empty, but the **Sender**(s) have not yet
/// disconnected, so data may yet become available.
Empty,
/// The **channel**'s sending half has become disconnected, and there will
/// never be any more data received on it.
Disconnected,
}

impl fmt::Display for TryRecvError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
TryRecvError::Empty => "receiving on an empty channel".fmt(fmt),
TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt),
}
}
}

impl Error for TryRecvError {}

// ===== RecvError =====

/// Error returned by `Receiver`.
Expand Down
42 changes: 37 additions & 5 deletions tokio/src/sync/mpsc/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,35 @@ pub(crate) struct Tx<T> {
/// Tail in the `Block` mpmc list.
block_tail: AtomicPtr<Block<T>>,

/// Position to push the next message. This reference a block and offset
/// Position to push the next message. This references a block and offset
/// into the block.
tail_position: AtomicUsize,
}

/// List queue receive handle
pub(crate) struct Rx<T> {
/// Pointer to the block being processed
/// Pointer to the block being processed.
head: NonNull<Block<T>>,

/// Next slot index to process
/// Next slot index to process.
index: usize,

/// Pointer to the next block pending release
/// Pointer to the next block pending release.
free_head: NonNull<Block<T>>,
}

/// Return value of `Rx::try_pop`.
pub(crate) enum TryPopResult<T> {
/// Successfully popped a value.
Ok(T),
/// The channel is empty.
Empty,
/// The channel is empty and closed.
Closed,
/// The channel is not empty, but the first value is being written.
Busy,
}

pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) {
// Create the initial block shared between the tx and rx halves.
let initial_block = Box::new(Block::new(0));
Expand Down Expand Up @@ -218,7 +230,7 @@ impl<T> fmt::Debug for Tx<T> {
}

impl<T> Rx<T> {
/// Pops the next value off the queue
/// Pops the next value off the queue.
pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> {
// Advance `head`, if needed
if !self.try_advancing_head() {
Expand All @@ -240,6 +252,26 @@ impl<T> Rx<T> {
}
}

/// Pops the next value off the queue, detecting whether the block
/// is busy or empty on failure.
///
/// This function exists because `Rx::pop` can return `None` even if the
/// channel's queue contains a message that has been completely written.
/// This can happen if the fully delivered message is behind another message
/// that is in the middle of being written to the block, since the channel
/// can't return the messages out of order.
pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> {
let tail_position = tx.tail_position.load(Acquire);
let result = self.pop(tx);

match result {
Some(block::Read::Value(t)) => TryPopResult::Ok(t),
Some(block::Read::Closed) => TryPopResult::Closed,
None if tail_position == self.index => TryPopResult::Empty,
None => TryPopResult::Busy,
}
}

/// Tries advancing the block pointer to the block referenced by `self.index`.
///
/// Returns `true` if successful, `false` if there is no next block to load.
Expand Down
42 changes: 41 additions & 1 deletion tokio/src/sync/mpsc/unbounded.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::loom::sync::atomic::AtomicUsize;
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::SendError;
use crate::sync::mpsc::error::{SendError, TryRecvError};

use std::fmt;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -129,6 +129,46 @@ impl<T> UnboundedReceiver<T> {
poll_fn(|cx| self.poll_recv(cx)).await
}

/// Try to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
///
/// This method returns the [`Disconnected`] error if the channel is
/// currently empty, and there are no outstanding [senders] or [permits].
///
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
/// [senders]: crate::sync::mpsc::Sender
/// [permits]: crate::sync::mpsc::Permit
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use tokio::sync::mpsc::error::TryRecvError;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::unbounded_channel();
///
/// tx.send("hello").unwrap();
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
///
/// tx.send("hello").unwrap();
/// // Drop the last sender, closing the channel.
/// drop(tx);
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.chan.try_recv()
}

/// Blocking receive to call outside of asynchronous contexts.
///
/// # Panics
Expand Down
56 changes: 56 additions & 0 deletions tokio/src/sync/tests/loom_mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,59 @@ fn dropping_unbounded_tx() {
assert!(v.is_none());
});
}

#[test]
fn try_recv() {
loom::model(|| {
use crate::sync::{mpsc, Semaphore};
use loom::sync::{Arc, Mutex};

const PERMITS: usize = 2;
const TASKS: usize = 2;
const CYCLES: usize = 1;

struct Context {
sem: Arc<Semaphore>,
tx: mpsc::Sender<()>,
rx: Mutex<mpsc::Receiver<()>>,
}

fn run(ctx: &Context) {
block_on(async {
let permit = ctx.sem.acquire().await;
assert_ok!(ctx.rx.lock().unwrap().try_recv());
crate::task::yield_now().await;
assert_ok!(ctx.tx.clone().try_send(()));
drop(permit);
});
}

let (tx, rx) = mpsc::channel(PERMITS);
let sem = Arc::new(Semaphore::new(PERMITS));
let ctx = Arc::new(Context {
sem,
tx,
rx: Mutex::new(rx),
});

for _ in 0..PERMITS {
assert_ok!(ctx.tx.clone().try_send(()));
}

let mut ths = Vec::new();

for _ in 0..TASKS {
let ctx = ctx.clone();

ths.push(thread::spawn(move || {
run(&ctx);
}));
}

run(&ctx);

for th in ths {
th.join().unwrap();
}
});
}
Loading