Skip to content

Commit

Permalink
Add peek to TcpStream.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickgian authored and marcbowes committed Jan 28, 2025
1 parent 122c0e6 commit 0aa4600
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 2 deletions.
13 changes: 13 additions & 0 deletions src/net/tcp/split_owned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ impl OwnedReadHalf {
pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
reunite(self, other)
}

/// Attempt to receive data on the socket, without removing that data from the queue, registering the current task for wakeup if data is not yet available.
pub fn poll_peek(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_peek(cx, buf)
}

pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf).await
}
}

/// Owned write half of a `TcpStream`, created by `into_split`.
Expand Down
56 changes: 54 additions & 2 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use bytes::{Buf, Bytes};
use std::future::poll_fn;
use std::{
fmt::Debug,
io::{self, Error, Result},
Expand All @@ -6,8 +8,6 @@ use std::{
sync::Arc,
task::{ready, Context, Poll},
};

use bytes::{Buf, Bytes};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
runtime::Handle,
Expand Down Expand Up @@ -165,6 +165,16 @@ impl TcpStream {
pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
Ok(())
}

/// Receives data on the socket from the remote address to which it is connected,
/// without removing that data from the queue. On success, returns the number of bytes peeked.
pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
self.read_half.peek(buf).await
}

pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
self.read_half.poll_peek(cx, buf)
}
}

pub(crate) struct ReadHalf {
Expand Down Expand Up @@ -234,6 +244,48 @@ impl ReadHalf {
Some(avail)
}
}

pub(crate) fn poll_peek(
&mut self,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<Result<usize>> {
if self.is_closed || buf.capacity() == 0 {
return Poll::Ready(Ok(0));
}

// If we have buffered data, peek from it
if let Some(bytes) = &self.rx.buffer {
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
return Poll::Ready(Ok(len));
}

match ready!(self.rx.recv.poll_recv(cx)) {
Some(seg) => match seg {
SequencedSegment::Data(bytes) => {
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
self.rx.buffer = Some(bytes);

Poll::Ready(Ok(len))
}
SequencedSegment::Fin => {
self.is_closed = true;
Poll::Ready(Ok(0))
}
},
None => Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"Connection reset",
))),
}
}

pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut buf = ReadBuf::new(buf);
poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
}
}

impl Debug for ReadHalf {
Expand Down
139 changes: 139 additions & 0 deletions tests/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,145 @@ fn split() -> Result {
sim.run()
}

#[test]
fn peek_empty_buffer() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let _ = listener.accept().await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// no-op peek with empty buffer
let mut buf = [0; 0];
let n = s.peek(&mut buf).await?;
assert_eq!(0, n);

Ok(())
});

sim.run()
}

#[test]
fn peek_then_read() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let (mut s, _) = listener.accept().await?;

s.write_u64(1234).await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// peek full message
let mut peek_buf = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf).await?);
assert_eq!(1234u64, u64::from_be_bytes(peek_buf));

// peek again should see same data
let mut peek_buf2 = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf2).await?);
assert_eq!(1234u64, u64::from_be_bytes(peek_buf2));

// read should consume the data
assert_eq!(1234, s.read_u64().await?);
let mut buf = [0; 8];
assert!(matches!(s.read(&mut buf).await, Ok(0)));

Ok(())
});

sim.run()
}

#[test]
fn peek_partial() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let (mut s, _) = listener.accept().await?;

s.write_all(&[0, 0, 1, 1]).await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// peek with smaller buffer
let mut peek_buf = [0; 2];
assert_eq!(2, s.peek(&mut peek_buf).await?);
assert_eq!([0, 0], peek_buf);

// peek with larger buffer should still see all data
let mut peek_buf2 = [0; 4];
assert_eq!(4, s.peek(&mut peek_buf2).await?);
assert_eq!([0, 0, 1, 1], peek_buf2);

// read partial
let mut read_buf = [0; 2];
assert_eq!(2, s.read(&mut read_buf).await?);
assert_eq!([0, 0], read_buf);

// peek remaining
let mut peek_buf3 = [0; 2];
assert_eq!(2, s.peek(&mut peek_buf3).await?);
assert_eq!([1, 1], peek_buf3);

Ok(())
});

sim.run()
}

#[test]
fn peek_multiple_messages() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let (mut s, _) = listener.accept().await?;

s.write_u64(1234).await?;
s.write_u64(5678).await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// peek first message
let mut peek_buf = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf).await?);
assert_eq!(1234u64, u64::from_be_bytes(peek_buf));

// read first message
assert_eq!(1234, s.read_u64().await?);

// peek second message
let mut peek_buf2 = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf2).await?);
assert_eq!(5678u64, u64::from_be_bytes(peek_buf2));

// read second message
assert_eq!(5678, s.read_u64().await?);

Ok(())
});

sim.run()
}

// # IpVersion specific tests

#[test]
Expand Down

0 comments on commit 0aa4600

Please sign in to comment.