diff --git a/Cargo.lock b/Cargo.lock index 303d21a0abe..4cad1d24765 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2772,7 +2772,7 @@ dependencies = [ [[package]] name = "libp2p-relay" -version = "0.15.1" +version = "0.15.2" dependencies = [ "asynchronous-codec", "bytes", diff --git a/protocols/relay/CHANGELOG.md b/protocols/relay/CHANGELOG.md index c39e7014bb6..25cc8af2395 100644 --- a/protocols/relay/CHANGELOG.md +++ b/protocols/relay/CHANGELOG.md @@ -1,3 +1,10 @@ +## 0.15.2 - unreleased + +- As a relay, when forwarding data between relay-connection-source and -destination and vice versa, flush write side when read currently has no more data available. + See [PR 3765]. + +[PR 3765]: https://github.com/libp2p/rust-libp2p/pull/3765 + ## 0.15.1 - Migrate from `prost` to `quick-protobuf`. This removes `protoc` dependency. See [PR 3312]. diff --git a/protocols/relay/Cargo.toml b/protocols/relay/Cargo.toml index 357b1b65174..9bb92bb8961 100644 --- a/protocols/relay/Cargo.toml +++ b/protocols/relay/Cargo.toml @@ -3,7 +3,7 @@ name = "libp2p-relay" edition = "2021" rust-version = "1.62.0" description = "Communications relaying for libp2p" -version = "0.15.1" +version = "0.15.2" authors = ["Parity Technologies ", "Max Inden "] license = "MIT" repository = "https://github.com/libp2p/rust-libp2p" diff --git a/protocols/relay/src/copy_future.rs b/protocols/relay/src/copy_future.rs index 12a8c486d3a..2f9eabed349 100644 --- a/protocols/relay/src/copy_future.rs +++ b/protocols/relay/src/copy_future.rs @@ -132,7 +132,14 @@ fn forward_data( mut dst: &mut D, cx: &mut Context<'_>, ) -> Poll> { - let buffer = ready!(Pin::new(&mut src).poll_fill_buf(cx))?; + let buffer = match Pin::new(&mut src).poll_fill_buf(cx)? 
{ + Poll::Ready(buffer) => buffer, + Poll::Pending => { + let _ = Pin::new(&mut dst).poll_flush(cx)?; + return Poll::Pending; + } + }; + if buffer.is_empty() { ready!(Pin::new(&mut dst).poll_flush(cx))?; ready!(Pin::new(&mut dst).poll_close(cx))?; @@ -150,95 +157,59 @@ fn forward_data( #[cfg(test)] mod tests { - use super::CopyFuture; + use super::*; use futures::executor::block_on; - use futures::io::{AsyncRead, AsyncWrite}; + use futures::io::{AsyncRead, AsyncWrite, BufReader, BufWriter}; use quickcheck::QuickCheck; use std::io::ErrorKind; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; - struct Connection { - read: Vec, - write: Vec, - } - - impl AsyncWrite for Connection { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.write).poll_write(cx, buf) - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.write).poll_flush(cx) - } - - fn poll_close( - mut self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.write).poll_close(cx) - } - } - - impl AsyncRead for Connection { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let n = std::cmp::min(self.read.len(), buf.len()); - buf[0..n].copy_from_slice(&self.read[0..n]); - self.read = self.read.split_off(n); - Poll::Ready(Ok(n)) + #[test] + fn quickcheck() { + struct Connection { + read: Vec, + write: Vec, } - } - - struct PendingConnection {} - impl AsyncWrite for PendingConnection { - fn poll_write( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - Poll::Pending - } + impl AsyncWrite for Connection { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.write).poll_write(cx, buf) + } - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: 
&mut Context<'_>, - ) -> Poll> { - Poll::Pending - } + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.write).poll_flush(cx) + } - fn poll_close( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Pending + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.write).poll_close(cx) + } } - } - impl AsyncRead for PendingConnection { - fn poll_read( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut [u8], - ) -> Poll> { - Poll::Pending + impl AsyncRead for Connection { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let n = std::cmp::min(self.read.len(), buf.len()); + buf[0..n].copy_from_slice(&self.read[0..n]); + self.read = self.read.split_off(n); + Poll::Ready(Ok(n)) + } } - } - #[test] - fn quickcheck() { fn prop(a: Vec, b: Vec, max_circuit_bytes: u64) { let connection_a = Connection { read: a.clone(), @@ -275,6 +246,42 @@ mod tests { #[test] fn max_circuit_duration() { + struct PendingConnection {} + + impl AsyncWrite for PendingConnection { + fn poll_write( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Pending + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Pending + } + } + + impl AsyncRead for PendingConnection { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut [u8], + ) -> Poll> { + Poll::Pending + } + } + let copy_future = CopyFuture::new( PendingConnection {}, PendingConnection {}, @@ -288,4 +295,124 @@ mod tests { block_on(copy_future).expect_err("Expect maximum circuit duration to be reached."); assert_eq!(error.kind(), ErrorKind::TimedOut); } + + #[test] + fn 
forward_data_should_flush_on_pending_source() { + struct NeverEndingSource { + read: Vec, + } + + impl AsyncRead for NeverEndingSource { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + if let Some(b) = self.read.pop() { + buf[0] = b; + return Poll::Ready(Ok(1)); + } + + Poll::Pending + } + } + + struct RecordingDestination { + method_calls: Vec, + } + + #[derive(Debug, PartialEq)] + enum Method { + Write(Vec), + Flush, + Close, + } + + impl AsyncWrite for RecordingDestination { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.method_calls.push(Method::Write(buf.to_vec())); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.method_calls.push(Method::Flush); + Poll::Ready(Ok(())) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.method_calls.push(Method::Close); + Poll::Ready(Ok(())) + } + } + + // The source has two reads available, handing them out on `AsyncRead::poll_read` one by one. + let mut source = BufReader::new(NeverEndingSource { read: vec![1, 2] }); + + // The destination is wrapped by a `BufWriter` with a capacity of `3`, i.e. one larger than + // the available reads of the source. Without an explicit `AsyncWrite::poll_flush` the two + // reads would thus never make it to the destination, but instead be stuck in the buffer of + // the `BufWrite`. + let mut destination = BufWriter::with_capacity( + 3, + RecordingDestination { + method_calls: vec![], + }, + ); + + let mut cx = Context::from_waker(futures::task::noop_waker_ref()); + + assert!( + matches!( + forward_data(&mut source, &mut destination, &mut cx), + Poll::Ready(Ok(1)), + ), + "Expect `forward_data` to forward one read from the source to the wrapped destination." 
+        );
+        assert_eq!(
+            destination.get_ref().method_calls.as_slice(), &[],
+            "Given that destination is wrapped with a `BufWrite`, the write doesn't (yet) make it to \
+            the destination. The source might have more data available, thus `forward_data` has not \
+            yet flushed.",
+        );
+
+        assert!(
+            matches!(
+                forward_data(&mut source, &mut destination, &mut cx),
+                Poll::Ready(Ok(1)),
+            ),
+            "Expect `forward_data` to forward one read from the source to the wrapped destination."
+        );
+        assert_eq!(
+            destination.get_ref().method_calls.as_slice(), &[],
+            "Given that destination is wrapped with a `BufWrite`, the write doesn't (yet) make it to \
+            the destination. The source might have more data available, thus `forward_data` has not \
+            yet flushed.",
+        );
+
+        assert!(
+            matches!(
+                forward_data(&mut source, &mut destination, &mut cx),
+                Poll::Pending,
+            ),
+            "The source has no more reads available, but does not close i.e. does not return \
+            `Poll::Ready(Ok(0))` but instead `Poll::Pending`. Thus `forward_data` returns \
+            `Poll::Pending` as well."
+        );
+        assert_eq!(
+            destination.get_ref().method_calls.as_slice(),
+            &[Method::Write(vec![2, 1]), Method::Flush],
+            "Given that source had no more reads, `forward_data` calls flush, thus instructing the \
+            `BufWriter` to flush the two buffered writes down to the destination."
+        );
+    }
 }