From 1eeec111b92dffcca97d3f4f35d7178918af8dae Mon Sep 17 00:00:00 2001 From: Nick Giannarakis Date: Mon, 30 Dec 2024 19:54:42 +0200 Subject: [PATCH 1/2] Add peek to TcpStream. --- src/net/tcp/split_owned.rs | 13 ++++ src/net/tcp/stream.rs | 56 ++++++++++++++- tests/tcp.rs | 139 +++++++++++++++++++++++++++++++++++++ 3 files changed, 206 insertions(+), 2 deletions(-) diff --git a/src/net/tcp/split_owned.rs b/src/net/tcp/split_owned.rs index 0c7bbb1..5f3b83b 100644 --- a/src/net/tcp/split_owned.rs +++ b/src/net/tcp/split_owned.rs @@ -36,6 +36,19 @@ impl OwnedReadHalf { pub fn reunite(self, other: OwnedWriteHalf) -> Result { reunite(self, other) } + + /// Attempt to receive data on the socket, without removing that data from the queue, registering the current task for wakeup if data is not yet available. + pub fn poll_peek( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll> { + Pin::new(&mut self.inner).poll_peek(cx, buf) + } + + pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { + self.inner.peek(buf).await + } } /// Owned write half of a `TcpStream`, created by `into_split`. diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 73b37fa..78c1701 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -1,3 +1,5 @@ +use bytes::{Buf, Bytes}; +use std::future::poll_fn; use std::{ fmt::Debug, io::{self, Error, Result}, @@ -6,8 +8,6 @@ use std::{ sync::Arc, task::{ready, Context, Poll}, }; - -use bytes::{Buf, Bytes}; use tokio::{ io::{AsyncRead, AsyncWrite, ReadBuf}, runtime::Handle, @@ -165,6 +165,16 @@ impl TcpStream { pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> { Ok(()) } + + /// Receives data on the socket from the remote address to which it is connected, + /// without removing that data from the queue. On success, returns the number of bytes peeked. + pub async fn peek(&mut self, buf: &mut [u8]) -> Result { + self.read_half.peek(buf).await + } + + pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll> { + self.read_half.poll_peek(cx, buf) + } } pub(crate) struct ReadHalf { @@ -234,6 +244,48 @@ impl ReadHalf { Some(avail) } } + + pub(crate) fn poll_peek( + &mut self, + cx: &mut Context<'_>, + buf: &mut ReadBuf, + ) -> Poll> { + if self.is_closed || buf.capacity() == 0 { + return Poll::Ready(Ok(0)); + } + + // If we have buffered data, peek from it + if let Some(bytes) = &self.rx.buffer { + let len = std::cmp::min(bytes.len(), buf.remaining()); + buf.put_slice(&bytes[..len]); + return Poll::Ready(Ok(len)); + } + + match ready!(self.rx.recv.poll_recv(cx)) { + Some(seg) => match seg { + SequencedSegment::Data(bytes) => { + let len = std::cmp::min(bytes.len(), buf.remaining()); + buf.put_slice(&bytes[..len]); + self.rx.buffer = Some(bytes); + + Poll::Ready(Ok(len)) + } + SequencedSegment::Fin => { + self.is_closed = true; + Poll::Ready(Ok(0)) + } + }, + None => Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "Connection reset", + ))), + } + } + + pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result { + let mut buf = ReadBuf::new(buf); + poll_fn(|cx| self.poll_peek(cx, &mut buf)).await + } } impl Debug for ReadHalf { diff --git a/tests/tcp.rs b/tests/tcp.rs index 54baa37..8ce5aa6 100644 --- a/tests/tcp.rs +++ b/tests/tcp.rs @@ -749,6 +749,145 @@ fn split() -> Result { sim.run() } +#[test] +fn peek_empty_buffer() -> Result { + let mut sim = Builder::new().build(); + + sim.client("server", async move { + let listener = bind().await?; + let _ = listener.accept().await?; + Ok(()) + }); + + sim.client("client", async move { + let mut s = TcpStream::connect(("server", PORT)).await?; + + // no-op peek with empty buffer + let mut buf = [0; 0]; + let n = s.peek(&mut buf).await?; + assert_eq!(0, n); + + Ok(()) + }); + + sim.run() +} + +#[test] +fn peek_then_read() -> Result { + let mut sim = Builder::new().build(); + + sim.client("server", async move { + let listener = bind().await?; + let (mut s, _) = listener.accept().await?; + + s.write_u64(1234).await?; + Ok(()) + }); + + sim.client("client", async move { + let mut s = TcpStream::connect(("server", PORT)).await?; + + // peek full message + let mut peek_buf = [0; 8]; + assert_eq!(8, s.peek(&mut peek_buf).await?); + assert_eq!(1234u64, u64::from_be_bytes(peek_buf)); + + // peek again should see same data + let mut peek_buf2 = [0; 8]; + assert_eq!(8, s.peek(&mut peek_buf2).await?); + assert_eq!(1234u64, u64::from_be_bytes(peek_buf2)); + + // read should consume the data + assert_eq!(1234, s.read_u64().await?); + let mut buf = [0; 8]; + assert!(matches!(s.read(&mut buf).await, Ok(0))); + + Ok(()) + }); + + sim.run() +} + +#[test] +fn peek_partial() -> Result { + let mut sim = Builder::new().build(); + + sim.client("server", async move { + let listener = bind().await?; + let (mut s, _) = listener.accept().await?; + + s.write_all(&[0, 0, 1, 1]).await?; + Ok(()) + }); + + sim.client("client", async move { + let mut s = TcpStream::connect(("server", PORT)).await?; + + // peek with smaller buffer + let mut peek_buf = [0; 2]; + assert_eq!(2, s.peek(&mut peek_buf).await?); + assert_eq!([0, 0], peek_buf); + + // peek with larger buffer should still see all data + let mut peek_buf2 = [0; 4]; + assert_eq!(4, s.peek(&mut peek_buf2).await?); + assert_eq!([0, 0, 1, 1], peek_buf2); + + // read partial + let mut read_buf = [0; 2]; + assert_eq!(2, s.read(&mut read_buf).await?); + assert_eq!([0, 0], read_buf); + + // peek remaining + let mut peek_buf3 = [0; 2]; + assert_eq!(2, s.peek(&mut peek_buf3).await?); + assert_eq!([1, 1], peek_buf3); + + Ok(()) + }); + + sim.run() +} + +#[test] +fn peek_multiple_messages() -> Result { + let mut sim = Builder::new().build(); + + sim.client("server", async move { + let listener = bind().await?; + let (mut s, _) = listener.accept().await?; + + s.write_u64(1234).await?; + s.write_u64(5678).await?; + Ok(()) + }); + + sim.client("client", async move { + let mut s = TcpStream::connect(("server", PORT)).await?; + + // peek first message + let mut peek_buf = [0; 8]; + assert_eq!(8, s.peek(&mut peek_buf).await?); + assert_eq!(1234u64, u64::from_be_bytes(peek_buf)); + + // read first message + assert_eq!(1234, s.read_u64().await?); + + // peek second message + let mut peek_buf2 = [0; 8]; + assert_eq!(8, s.peek(&mut peek_buf2).await?); + assert_eq!(5678u64, u64::from_be_bytes(peek_buf2)); + + // read second message + assert_eq!(5678, s.read_u64().await?); + + Ok(()) + }); + + sim.run() +} + // # IpVersion specific tests #[test] From 5d88920d9bf290c1b1555f9760c8ce8513c5fdb7 Mon Sep 17 00:00:00 2001 From: Marc Bowes <15209+marcbowes@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:54:58 -0800 Subject: [PATCH 2/2] Minor updates to peek - Ensure both peek + poll_peek have docs, and those docs match the tokio ones. - Add a trace point for peek --- src/net/tcp/split_owned.rs | 9 ++++++++- src/net/tcp/stream.rs | 36 +++++++++++++++++++++++------------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/net/tcp/split_owned.rs b/src/net/tcp/split_owned.rs index 5f3b83b..5145c16 100644 --- a/src/net/tcp/split_owned.rs +++ b/src/net/tcp/split_owned.rs @@ -37,7 +37,9 @@ impl OwnedReadHalf { reunite(self, other) } - /// Attempt to receive data on the socket, without removing that data from the queue, registering the current task for wakeup if data is not yet available. + /// Attempts to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. pub fn poll_peek( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -46,6 +48,11 @@ impl OwnedReadHalf { Pin::new(&mut self.inner).poll_peek(cx, buf) } + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// Successive calls return the same data. pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result { self.inner.peek(buf).await } diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 78c1701..fd9ca4f 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -166,12 +166,18 @@ impl TcpStream { Ok(()) } - /// Receives data on the socket from the remote address to which it is connected, - /// without removing that data from the queue. On success, returns the number of bytes peeked. + /// Receives data on the socket from the remote address to which it is + /// connected, without removing that data from the queue. On success, + /// returns the number of bytes peeked. + /// + /// Successive calls return the same data. pub async fn peek(&mut self, buf: &mut [u8]) -> Result { self.read_half.peek(buf).await } + /// Attempts to receive data on the socket, without removing that data from + /// the queue, registering the current task for wakeup if data is not yet + /// available. pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll> { self.read_half.poll_peek(cx, buf) } @@ -262,19 +268,23 @@ impl ReadHalf { } match ready!(self.rx.recv.poll_recv(cx)) { - Some(seg) => match seg { - SequencedSegment::Data(bytes) => { - let len = std::cmp::min(bytes.len(), buf.remaining()); - buf.put_slice(&bytes[..len]); - self.rx.buffer = Some(bytes); + Some(seg) => { + tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Peek"); - Poll::Ready(Ok(len)) - } - SequencedSegment::Fin => { - self.is_closed = true; - Poll::Ready(Ok(0)) + match seg { + SequencedSegment::Data(bytes) => { + let len = std::cmp::min(bytes.len(), buf.remaining()); + buf.put_slice(&bytes[..len]); + self.rx.buffer = Some(bytes); + + Poll::Ready(Ok(len)) + } + SequencedSegment::Fin => { + self.is_closed = true; + Poll::Ready(Ok(0)) + } } - }, + } None => Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, "Connection reset",