Skip to content

Commit 64d3cbf

Browse files
committed
fix(core): avoid unnecessary wakeups in try_stream!()
fixes #2834
1 parent 82fadce commit 64d3cbf

File tree

1 file changed

+92
-29
lines changed

1 file changed

+92
-29
lines changed

sqlx-core/src/ext/async_stream.rs

+92-29
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,133 @@
1+
//! A minimalist clone of the `async-stream` crate in 100% safe code, without proc macros.
2+
//!
3+
//! This was created initially to get around some weird compiler errors we were getting with
4+
//! `async-stream`, and now it'd just be more work to replace.
5+
16
use std::future::Future;
27
use std::pin::Pin;
8+
use std::sync::{Arc, Mutex};
39
use std::task::{Context, Poll};
410

5-
use futures_channel::mpsc;
611
use futures_core::future::BoxFuture;
712
use futures_core::stream::Stream;
8-
use futures_util::{pin_mut, FutureExt, SinkExt};
13+
use futures_core::FusedFuture;
14+
use futures_util::future::Fuse;
15+
use futures_util::FutureExt;
916

1017
use crate::error::Error;
1118

1219
pub struct TryAsyncStream<'a, T> {
13-
receiver: mpsc::Receiver<Result<T, Error>>,
14-
future: BoxFuture<'a, Result<(), Error>>,
20+
yielder: Yielder<T>,
21+
future: Fuse<BoxFuture<'a, Result<(), Error>>>,
1522
}
1623

1724
impl<'a, T> TryAsyncStream<'a, T> {
1825
pub fn new<F, Fut>(f: F) -> Self
1926
where
20-
F: FnOnce(mpsc::Sender<Result<T, Error>>) -> Fut + Send,
27+
F: FnOnce(Yielder<T>) -> Fut + Send,
2128
Fut: 'a + Future<Output = Result<(), Error>> + Send,
2229
T: 'a + Send,
2330
{
24-
let (mut sender, receiver) = mpsc::channel(0);
31+
let yielder = Yielder::new();
2532

26-
let future = f(sender.clone());
27-
let future = async move {
28-
if let Err(error) = future.await {
29-
let _ = sender.send(Err(error)).await;
30-
}
33+
let future = f(yielder.duplicate()).boxed().fuse();
34+
35+
Self { future, yielder }
36+
}
37+
}
38+
39+
pub struct Yielder<T> {
40+
// This mutex should never have any contention in normal operation.
41+
// It's just necessary to keep `Yielder` and thus `TryAsyncStream` send-able.
42+
value: Arc<Mutex<Option<T>>>,
43+
}
44+
45+
impl<T> Yielder<T> {
46+
fn new() -> Self {
47+
Yielder {
48+
value: Arc::new(Mutex::new(None)),
49+
}
50+
}
3151

32-
Ok(())
52+
// Don't want to expose a `Clone` impl
53+
fn duplicate(&self) -> Self {
54+
Yielder {
55+
value: self.value.clone(),
3356
}
34-
.fuse()
35-
.boxed();
57+
}
58+
59+
/// NOTE: may deadlock the task if called from outside the future passed to `TryAsyncStream`.
60+
pub async fn r#yield(&self, val: T) {
61+
let replaced = self
62+
.value
63+
.lock()
64+
.expect("BUG: panicked while holding a lock")
65+
.replace(val);
3666

37-
Self { future, receiver }
67+
debug_assert!(
68+
replaced.is_none(),
69+
"BUG: previously yielded value not taken"
70+
);
71+
72+
let mut yielded = false;
73+
74+
// Allows the generating future to suspend its execution without changing the task priority,
75+
// which would happen with `tokio::task::yield_now()`.
76+
//
77+
// Note that because this has no way to schedule a wakeup, this could deadlock the task
78+
// if called in the wrong place.
79+
futures_util::future::poll_fn(|_cx| {
80+
if !yielded {
81+
yielded = true;
82+
Poll::Pending
83+
} else {
84+
Poll::Ready(())
85+
}
86+
})
87+
.await
88+
}
89+
90+
fn take(&self) -> Option<T> {
91+
self.value
92+
.lock()
93+
.expect("BUG: panicked while holding a lock")
94+
.take()
3895
}
3996
}
4097

4198
impl<'a, T> Stream for TryAsyncStream<'a, T> {
4299
type Item = Result<T, Error>;
43100

44101
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
45-
let future = &mut self.future;
46-
pin_mut!(future);
47-
48-
// the future is fused so its safe to call forever
49-
// the future advances our "stream"
50-
// the future should be polled in tandem with the stream receiver
51-
let _ = future.poll(cx);
52-
53-
let receiver = &mut self.receiver;
54-
pin_mut!(receiver);
102+
if self.future.is_terminated() {
103+
return Poll::Ready(None);
104+
}
55105

56-
// then we check to see if we have anything to return
57-
receiver.poll_next(cx)
106+
match self.future.poll_unpin(cx) {
107+
Poll::Ready(Ok(())) => {
108+
// Future returned without yielding another value.
109+
Poll::Ready(None)
110+
}
111+
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
112+
Poll::Pending => self
113+
.yielder
114+
.take()
115+
.map_or(Poll::Pending, |val| Poll::Ready(Some(Ok(val)))),
116+
}
58117
}
59118
}
60119

61120
#[macro_export]
62121
macro_rules! try_stream {
63122
($($block:tt)*) => {
64-
crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move {
123+
crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move {
124+
// Anti-footgun: effectively pins `yielder` to this future to prevent any accidental
125+
// move to another task, which could deadlock.
126+
let ref yielder = yielder;
127+
65128
macro_rules! r#yield {
66129
($v:expr) => {{
67-
let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await;
130+
yielder.r#yield($v).await;
68131
}}
69132
}
70133

0 commit comments

Comments
 (0)