Skip to content

Commit

Permalink
Add support for skippable frames
Browse files Browse the repository at this point in the history
  • Loading branch information
antoyo committed Jan 26, 2023
1 parent 1bed85b commit 511e7a7
Show file tree
Hide file tree
Showing 6 changed files with 453 additions and 1 deletion.
26 changes: 26 additions & 0 deletions src/stream/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ pub struct Status {
pub bytes_written: usize,
}

/// The magic variant encoded in a skippable frame.
pub struct MagicVariant(pub u8);

/// An in-memory decoder for streams of data.
pub struct Decoder<'a> {
context: zstd_safe::DCtx<'a>,
Expand Down Expand Up @@ -171,6 +174,29 @@ impl<'a> Decoder<'a> {
.map_err(map_error_code)?;
Ok(())
}

#[cfg(feature = "experimental")]
// TODO: remove self?
/// Read a skippable frame.
pub fn read_skippable_frame(&self, dest: &mut Vec<u8>, input: &[u8]) -> io::Result<(usize, MagicVariant)> {
use zstd_safe::DCtx;

let mut magic_variant = 0;
DCtx::read_skippable_frame(&mut OutBuffer::around(dest), &mut magic_variant, input)
.map(|written| (written, MagicVariant(magic_variant as u8)))
.map_err(map_error_code)
}

#[cfg(feature = "experimental")]
// TODO: remove self?
/// Check if a frame is skippable.
pub fn is_skippable_frame(&self, input: &[u8]) -> io::Result<bool> {
use zstd_safe::DCtx;

DCtx::is_skippable_frame(input)
.map(|is_skippable| is_skippable != 0)
.map_err(map_error_code)
}
}

impl Operation for Decoder<'_> {
Expand Down
219 changes: 218 additions & 1 deletion src/stream/read/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
//! Implement pull-based [`Read`] trait for both compressing and decompressing.
use std::io::{self, BufRead, BufReader, Read};
#[cfg(feature = "experimental")]
use std::cmp::min;
#[cfg(feature = "experimental")]
use std::io::SeekFrom;
use std::io::{self, BufRead, BufReader, Read, Seek};
#[cfg(feature = "experimental")]
use std::mem::size_of;

use crate::dict::{DecoderDictionary, EncoderDictionary};
use crate::stream::{raw, zio};
use zstd_safe;

#[cfg(feature = "experimental")]
use zstd_safe::{frame_header_size, MAGIC_SKIPPABLE_MASK, MAGIC_SKIPPABLE_START, SKIPPABLEHEADERSIZE};
#[cfg(feature = "experimental")]
use super::raw::MagicVariant;

#[cfg(test)]
mod tests;

#[cfg(feature = "experimental")]
const U24_SIZE: usize = size_of::<u16>() + size_of::<u8>();
#[cfg(feature = "experimental")]
const U32_SIZE: usize = size_of::<u32>();

/// A decoder that decompress input data from another `Read`.
///
/// This allows to read a stream of compressed data
Expand Down Expand Up @@ -45,6 +61,207 @@ impl<R: BufRead> Decoder<'static, R> {
Ok(Decoder { reader })
}
}

/// Read and discard `bytes_count` bytes in the reader.
#[cfg(feature = "experimental")]
fn consume<R: Read + ?Sized>(this: &mut R, mut bytes_count: usize) -> io::Result<()> {
let mut buf = [0; 100];
while bytes_count > 0 {
let end = min(buf.len(), bytes_count);
match this.read(&mut buf[..end]) {
Ok(0) => break,
Ok(n) => bytes_count -= n,
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {},
Err(e) => return Err(e),
}
}
if bytes_count > 0 {
Err(io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer"))
} else {
Ok(())
}
}

/// Like Read::read_exact(), but seek back to the starting position of the reader in case of an
/// error.
#[cfg(feature = "experimental")]
fn read_exact_or_seek_back<R: Read + Seek + ?Sized>(this: &mut R, mut buf: &mut [u8]) -> io::Result<()> {
let mut bytes_read = 0;
while !buf.is_empty() {
match this.read(buf) {
Ok(0) => break,
Ok(n) => {
bytes_read += n as i64;
let tmp = buf;
buf = &mut tmp[n..];
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => {
if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) {
panic!("Error while seeking back to the start: {}", error);
}
return Err(e)
},
}
}
if !buf.is_empty() {
if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) {
panic!("Error while seeking back to the start: {}", error);
}
Err(io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer"))
} else {
Ok(())
}
}

impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
#[cfg(feature = "experimental")]
fn read_skippable_frame_size(&mut self) -> io::Result<usize> {
let mut magic_buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?;

// Read skippable frame size.
let mut buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
let content_size = u32::from_le_bytes(buffer) as usize;

self.seek_back(U32_SIZE * 2);

Ok(content_size + SKIPPABLEHEADERSIZE as usize)
}

#[cfg(feature = "experimental")]
fn seek_back(&mut self, bytes_count: usize) {
if let Err(error) = self.reader.reader_mut().seek(SeekFrom::Current(-(bytes_count as i64))) {
panic!("Error while seeking back to the start: {}", error);
}
}

#[cfg(feature = "experimental")]
/// Attempt to read a skippable frame and write its content to `dest`.
/// If it cannot read a skippable frame, the reader will be back to its starting position.
pub fn read_skippable_frame(&mut self, dest: &mut Vec<u8>) -> io::Result<(usize, MagicVariant)> {
let mut bytes_to_seek = 0;

let res = (|| {
let mut magic_buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?;
let magic_number = u32::from_le_bytes(magic_buffer);

// Read skippable frame size.
let mut buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
let content_size = u32::from_le_bytes(buffer) as usize;

let op = self.reader.operation();
// FIXME: I feel like we should do that check right after reading the magic number, but
// ZSTD does it after reading the content size.
if !op.is_skippable_frame(&magic_buffer)? {
bytes_to_seek = U32_SIZE * 2;
return Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter"));
}
if content_size > dest.capacity() {
bytes_to_seek = U32_SIZE * 2;
return Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small"));
}

if content_size > 0 {
dest.resize(content_size, 0);
read_exact_or_seek_back(self.reader.reader_mut(), dest)?;
}

Ok((magic_number, content_size))
})();

let (magic_number, content_size) =
match res {
Ok(data) => data,
Err(err) => {
if bytes_to_seek != 0 {
self.seek_back(bytes_to_seek);
}
return Err(err);
},
};

let magic_variant = magic_number - MAGIC_SKIPPABLE_START;

Ok((content_size, MagicVariant(magic_variant as u8)))
}

#[cfg(feature = "experimental")]
fn get_block_size(&mut self) -> io::Result<(usize, bool)> {
let mut buffer = [0u8; U24_SIZE];
self.reader.reader_mut().read_exact(&mut buffer)?;
let buffer = [buffer[0], buffer[1], buffer[2], 0];
let block_header = u32::from_le_bytes(buffer);
let compressed_size = block_header >> 3;
let last_block = block_header & 1;
self.seek_back(U24_SIZE);
Ok((compressed_size as usize, last_block != 0))
}

#[cfg(feature = "experimental")]
fn find_frame_compressed_size(&mut self) -> io::Result<usize> {
const ZSTD_BLOCK_HEADER_SIZE: usize = 3;

// TODO: should we support legacy format?
let mut magic_buffer = [0u8; U32_SIZE];
self.reader.reader_mut().read_exact(&mut magic_buffer)?;
let magic_number = u32::from_le_bytes(magic_buffer);
self.seek_back(U32_SIZE);
if magic_number & MAGIC_SKIPPABLE_MASK == MAGIC_SKIPPABLE_START {
self.read_skippable_frame_size()
}
else {
let mut bytes_read = 0;
let (header_size, checksum_flag) = self.frame_header_size()?;
bytes_read += header_size;
consume(self.reader.reader_mut(), header_size)?;

loop {
let (compressed_size, last_block) = self.get_block_size()?;
let block_size = ZSTD_BLOCK_HEADER_SIZE + compressed_size;
consume(self.reader.reader_mut(), block_size)?;
bytes_read += block_size;
if last_block {
break;
}
}

self.seek_back(bytes_read);

if checksum_flag {
bytes_read += 4;
}

Ok(bytes_read)
}
}

#[cfg(feature = "experimental")]
fn frame_header_size(&mut self) -> io::Result<(usize, bool)> {
use crate::map_error_code;
const MAX_FRAME_HEADER_SIZE_PREFIX: usize = 5;
let mut buffer = [0u8; MAX_FRAME_HEADER_SIZE_PREFIX];
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
let size = frame_header_size(&buffer)
.map_err(map_error_code)?;
let byte = buffer[MAX_FRAME_HEADER_SIZE_PREFIX - 1];
let checksum_flag = (byte >> 2) & 1;
self.seek_back(MAX_FRAME_HEADER_SIZE_PREFIX);
Ok((size, checksum_flag != 0))
}

#[cfg(feature = "experimental")]
/// Skip over a frame, without decompressing it.
pub fn skip_frame(&mut self) -> io::Result<()> {
let size = self.find_frame_compressed_size()?;
consume(self.reader.reader_mut(), size)?;
Ok(())
}
}

impl<'a, R: BufRead> Decoder<'a, R> {
/// Sets this `Decoder` to stop after the first frame.
///
Expand Down
12 changes: 12 additions & 0 deletions src/stream/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ impl<W: Write> Encoder<'static, W> {
let writer = zio::Writer::new(writer, encoder);
Ok(Encoder { writer })
}

/// Write a skippable frame.
#[cfg(feature = "experimental")]
pub fn write_skippable_frame(&mut self, buf: &[u8], magic_variant: u32) -> io::Result<()> {
self.writer.write_skippable_frame(buf, magic_variant)
}
}

impl<'a, W: Write> Encoder<'a, W> {
Expand Down Expand Up @@ -259,6 +265,12 @@ impl<'a, W: Write> Encoder<'a, W> {
self.try_finish().map_err(|(_, err)| err)
}

/// Useful to get back the writer after calling write_skippable_frame. You don't want to call
/// finish because this will create yet another frame.
pub fn into_inner(self) -> W {
self.writer.into_inner().0
}

/// **Required**: Attempts to finish the stream.
///
/// You *need* to finish the stream when you're done writing, either with
Expand Down
5 changes: 5 additions & 0 deletions src/stream/zio/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ impl<R, D> Reader<R, D> {
self.single_frame = true;
}

/// Returns a reference to the underlying operation.
pub fn operation(&self) -> &D {
&self.operation
}

/// Returns a mutable reference to the underlying operation.
pub fn operation_mut(&mut self) -> &mut D {
&mut self.operation
Expand Down
Loading

0 comments on commit 511e7a7

Please sign in to comment.