From c60b00517795e5b4bfa9fb431664d67538baa248 Mon Sep 17 00:00:00 2001 From: John-John Tedro Date: Mon, 30 Dec 2019 18:27:08 +0100 Subject: [PATCH] sync: Add async mpsc::Sender::shared_send method This send variant creates a new permit on-the-fly to avoid taking a unique reference. It's useful for instances where the Sender is embedded into a larger struct which itself is shared across an application. The alternative is to clone the entire sender, which would involve cloning the inner `Arc>` nedlessly. Using `shared_send` instead does the minimal amount of work possible with the current algorithm. --- tokio/src/sync/mpsc/bounded.rs | 84 ++++++++++++++++++++++++++++++++++ tokio/src/sync/mpsc/chan.rs | 24 ++++++++-- 2 files changed, 104 insertions(+), 4 deletions(-) diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index da3bd6381de..0a15b3fea02 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -321,4 +321,88 @@ impl Sender { Err(TrySendError::Closed(value)) => Err(SendError(value)), } } + + /// Send a value through a shared reference, waiting until there is + /// capacity. + /// + /// Behaves exactly like `send`, except that it performs some extra work to + /// synchronize through a shared reference. This is accomplished by creating + /// a new permit for every value sent over the channel. If possible, you + /// should prefer to use `send` since it has lower overhead. + /// + /// See [`send`] for more documentation. + /// + /// [`send`]: Sender::send + /// + /// # Examples + /// + /// In the following example, each call to `shared_send` will block until the + /// previously sent value was received. + /// + /// ```rust + /// use tokio::sync::mpsc; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// { + /// let tx = Arc::new(tx); + /// + /// for i in 0..10i32 { + /// let tx = Arc::clone(&tx); + /// + /// let _ = tokio::spawn(async move { + /// if let Err(_) = tx.shared_send(i).await { + /// println!("receiver dropped"); + /// } + /// }); + /// } + /// } + /// + /// while let Some(i) = rx.recv().await { + /// println!("got = {}", i); + /// } + /// } + /// ``` + pub async fn shared_send(&self, value: T) -> Result<(), SendError> { + use self::chan::Semaphore as _; + use crate::future::poll_fn; + + let mut permit = Semaphore::new_permit(); + let permit = Guard(&self.chan.inner.semaphore, &mut permit); + + if poll_fn(|cx| self.chan.poll_ready_with_permit(cx, permit.1)) + .await + .is_err() + { + return Err(SendError(value)); + } + + return match self + .chan + .try_send_with_permit(value, permit.1) + .map_err(TrySendError::::from) + { + Ok(()) => Ok(()), + Err(TrySendError::Full(_)) => unreachable!(), + Err(TrySendError::Closed(value)) => Err(SendError(value)), + }; + + // A permit guard, making sure that the drop implementation is run as + // appropriate. + struct Guard<'a, S>(&'a S, &'a mut S::Permit) + where + S: chan::Semaphore; + + impl Drop for Guard<'_, S> + where + S: chan::Semaphore, + { + fn drop(&mut self) { + self.0.drop_permit(self.1); + } + } + } } diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 7a15e8b3adc..797379cb961 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -13,7 +13,7 @@ use std::task::{Context, Poll}; /// Channel sender pub(crate) struct Tx { - inner: Arc>, + pub(crate) inner: Arc>, permit: S::Permit, } @@ -91,19 +91,19 @@ pub(crate) trait Semaphore { /// A value was sent into the channel and the permit held by `tx` is /// dropped. In this case, the permit should not immeditely be returned to - /// the semaphore. Instead, the permit is returnred to the semaphore once + /// the semaphore. Instead, the permit is returned to the semaphore once /// the sent value is read by the rx handle. fn forget(&self, permit: &mut Self::Permit); fn close(&self); } -struct Chan { +pub(crate) struct Chan { /// Handle to the push half of the lock-free list. tx: list::Tx, /// Coordinates access to channel's capacity. - semaphore: S, + pub(crate) semaphore: S, /// Receiver waker. Notified when a value is pushed into the channel. rx_waker: AtomicWaker, @@ -190,10 +190,26 @@ where self.inner.semaphore.poll_acquire(cx, &mut self.permit) } + pub(crate) fn poll_ready_with_permit( + &self, + cx: &mut Context<'_>, + permit: &mut S::Permit, + ) -> Poll> { + self.inner.semaphore.poll_acquire(cx, permit) + } + /// Send a message and notify the receiver. pub(crate) fn try_send(&mut self, value: T) -> Result<(), (T, TrySendError)> { self.inner.try_send(value, &mut self.permit) } + + pub(crate) fn try_send_with_permit( + &self, + value: T, + permit: &mut S::Permit, + ) -> Result<(), (T, TrySendError)> { + self.inner.try_send(value, permit) + } } impl Tx {