Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add peek to TcpStream. #206

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/net/tcp/split_owned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,26 @@ impl OwnedReadHalf {
pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
reunite(self, other)
}

/// Attempts to receive data on the socket, without removing that data from
/// the queue, registering the current task for wakeup if data is not yet
/// available.
pub fn poll_peek(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_peek(cx, buf)
}

/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data.
pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf).await
}
}

/// Owned write half of a `TcpStream`, created by `into_split`.
Expand Down
66 changes: 64 additions & 2 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use bytes::{Buf, Bytes};
use std::future::poll_fn;
use std::{
fmt::Debug,
io::{self, Error, Result},
Expand All @@ -6,8 +8,6 @@ use std::{
sync::Arc,
task::{ready, Context, Poll},
};

use bytes::{Buf, Bytes};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
runtime::Handle,
Expand Down Expand Up @@ -165,6 +165,22 @@ impl TcpStream {
pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
Ok(())
}

/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// Successive calls return the same data.
pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
self.read_half.peek(buf).await
}

/// Attempts to receive data on the socket, without removing that data from
/// the queue, registering the current task for wakeup if data is not yet
/// available.
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
self.read_half.poll_peek(cx, buf)
}
}

pub(crate) struct ReadHalf {
Expand Down Expand Up @@ -234,6 +250,52 @@ impl ReadHalf {
Some(avail)
}
}

pub(crate) fn poll_peek(
&mut self,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<Result<usize>> {
if self.is_closed || buf.capacity() == 0 {
return Poll::Ready(Ok(0));
}

// If we have buffered data, peek from it
if let Some(bytes) = &self.rx.buffer {
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
return Poll::Ready(Ok(len));
}

match ready!(self.rx.recv.poll_recv(cx)) {
Some(seg) => {
tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Peek");

match seg {
SequencedSegment::Data(bytes) => {
let len = std::cmp::min(bytes.len(), buf.remaining());
buf.put_slice(&bytes[..len]);
self.rx.buffer = Some(bytes);

Poll::Ready(Ok(len))
}
SequencedSegment::Fin => {
self.is_closed = true;
Poll::Ready(Ok(0))
}
}
}
None => Poll::Ready(Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"Connection reset",
))),
}
}

pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut buf = ReadBuf::new(buf);
poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
}
}

impl Debug for ReadHalf {
Expand Down
139 changes: 139 additions & 0 deletions tests/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,145 @@ fn split() -> Result {
sim.run()
}

#[test]
fn peek_empty_buffer() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let _ = listener.accept().await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// no-op peek with empty buffer
let mut buf = [0; 0];
let n = s.peek(&mut buf).await?;
assert_eq!(0, n);

Ok(())
});

sim.run()
}

#[test]
fn peek_then_read() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let (mut s, _) = listener.accept().await?;

s.write_u64(1234).await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// peek full message
let mut peek_buf = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf).await?);
assert_eq!(1234u64, u64::from_be_bytes(peek_buf));

// peek again should see same data
let mut peek_buf2 = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf2).await?);
assert_eq!(1234u64, u64::from_be_bytes(peek_buf2));

// read should consume the data
assert_eq!(1234, s.read_u64().await?);
let mut buf = [0; 8];
assert!(matches!(s.read(&mut buf).await, Ok(0)));

Ok(())
});

sim.run()
}

#[test]
fn peek_partial() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let (mut s, _) = listener.accept().await?;

s.write_all(&[0, 0, 1, 1]).await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// peek with smaller buffer
let mut peek_buf = [0; 2];
assert_eq!(2, s.peek(&mut peek_buf).await?);
assert_eq!([0, 0], peek_buf);

// peek with larger buffer should still see all data
let mut peek_buf2 = [0; 4];
assert_eq!(4, s.peek(&mut peek_buf2).await?);
assert_eq!([0, 0, 1, 1], peek_buf2);

// read partial
let mut read_buf = [0; 2];
assert_eq!(2, s.read(&mut read_buf).await?);
assert_eq!([0, 0], read_buf);

// peek remaining
let mut peek_buf3 = [0; 2];
assert_eq!(2, s.peek(&mut peek_buf3).await?);
assert_eq!([1, 1], peek_buf3);

Ok(())
});

sim.run()
}

#[test]
fn peek_multiple_messages() -> Result {
let mut sim = Builder::new().build();

sim.client("server", async move {
let listener = bind().await?;
let (mut s, _) = listener.accept().await?;

s.write_u64(1234).await?;
s.write_u64(5678).await?;
Ok(())
});

sim.client("client", async move {
let mut s = TcpStream::connect(("server", PORT)).await?;

// peek first message
let mut peek_buf = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf).await?);
assert_eq!(1234u64, u64::from_be_bytes(peek_buf));

// read first message
assert_eq!(1234, s.read_u64().await?);

// peek second message
let mut peek_buf2 = [0; 8];
assert_eq!(8, s.peek(&mut peek_buf2).await?);
assert_eq!(5678u64, u64::from_be_bytes(peek_buf2));

// read second message
assert_eq!(5678, s.read_u64().await?);

Ok(())
});

sim.run()
}

// # IpVersion specific tests

#[test]
Expand Down