Skip to content

Commit

Permalink
feat: add method read_size
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn committed Feb 2, 2023
1 parent e7ce384 commit bcc66f9
Showing 1 changed file with 50 additions and 5 deletions.
55 changes: 50 additions & 5 deletions src/bao_slice_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ impl SliceIter {
self.len = len;
}

/// get the length of the slice
pub fn len(&self) -> Option<u64> {
if self.res.is_some() {
Some(self.len)
} else {
None
}
}

// todo: it is easy to make this a proper iterator, and even possible
// to make it an iterator without any state, but let's just keep it simple
// for now.
Expand Down Expand Up @@ -512,13 +521,24 @@ impl<R: tokio::io::AsyncRead + Unpin> AsyncSliceDecoder<R> {
}
}

/// Read the size. This only does something if we are before the header,
/// otherwise it just returns the already known size.
pub async fn read_size(&mut self) -> io::Result<u64> {
if self.inner.iter.len().is_none() {
let mut tgt = ReadBuf::new(&mut []);
futures::future::poll_fn(|cx| Self::poll_read_inner(Pin::new(self), cx, &mut tgt))
.await?;
}
Ok(self.inner.iter.len().unwrap())
}

pub fn into_inner(self) -> R {
self.inner.into_inner()
}
}

impl<R: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for AsyncSliceDecoder<R> {
fn poll_read(
/// This is the poll_read implementation, except that it will make progress
/// even when passed an empty buffer.
fn poll_read_inner(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
tgt: &mut ReadBuf<'_>,
Expand Down Expand Up @@ -552,13 +572,25 @@ impl<R: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for AsyncSliceDecoder
self.current_item = None;
self.inner.buf_start = 0;
}
debug_assert!(n > 0, "we should have read something");
break Ok(());
});
res
}
}

impl<R: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for AsyncSliceDecoder<R> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
tgt: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if tgt.remaining() == 0 {
return Poll::Ready(Ok(()));
}
self.poll_read_inner(cx, tgt)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -637,9 +669,20 @@ mod tests {
// check that we have read the entire slice
assert_eq!(cursor.position(), encoded.len() as u64);

// test validation and reading
// test once with and once without reading size, to make sure that calling size is not required to drive
// the internal state machine

// test validation and reading - without reading size
let mut cursor = std::io::Cursor::new(&encoded);
let mut reader = AsyncSliceDecoder::new(&mut cursor, hash, 0, len);
let mut data = vec![];
reader.read_to_end(&mut data).await.unwrap();
assert_eq!(data, test_data);

// test validation and reading - with reading size
let mut cursor = std::io::Cursor::new(&encoded);
let mut reader = AsyncSliceDecoder::new(&mut cursor, hash, 0, len);
assert_eq!(reader.read_size().await.unwrap(), test_data.len() as u64);
let mut data = vec![];
reader.read_to_end(&mut data).await.unwrap();
assert_eq!(data, test_data);
Expand All @@ -654,6 +697,7 @@ mod tests {

// check that reading with a size > end works
let mut reader = AsyncSliceDecoder::new(&mut cursor, hash, 0, u64::MAX);
assert_eq!(reader.read_size().await.unwrap(), test_data.len() as u64);
let mut data = vec![];
reader.read_to_end(&mut data).await.unwrap();
assert_eq!(data, test_data);
Expand Down Expand Up @@ -709,6 +753,7 @@ mod tests {

let mut cursor = std::io::Cursor::new(&slice);
let mut reader = AsyncSliceDecoder::new(&mut cursor, hash, slice_start, slice_len);
assert_eq!(reader.read_size().await.unwrap(), test_data.len() as u64);
let mut data = vec![];
reader.read_to_end(&mut data).await.unwrap();
// check that we have read the entire slice
Expand Down

0 comments on commit bcc66f9

Please sign in to comment.