Skip to content

Commit

Permalink
Initialize Rust decompression support (google#29)
Browse files Browse the repository at this point in the history
Rust support for compressed shard files.

---------

Co-authored-by: Julien Cretin <[email protected]>
  • Loading branch information
wsxrdv and ia0 authored Oct 8, 2024
1 parent 59c3a77 commit f05477f
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 51 deletions.
26 changes: 26 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ crate-type = ["cdylib", "rlib"]

[dependencies]
flatbuffers = "24.3"
lz4_flex = { version = "0.11.3", default-features = false , features = ["frame"] }
numpy = "0.21"
pyo3 = "0.21"
rand = "0.8"
Expand Down
47 changes: 43 additions & 4 deletions rust/src/example_iteration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::io::Read;

use yoke::Yoke;

pub use super::parallel_map::parallel_map;
Expand All @@ -25,6 +27,30 @@ pub struct ExampleIterator {
example_iterator: Box<dyn Iterator<Item = Example> + Send>,
}

/// Compression applied to a whole shard file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompressionType {
    /// The file bytes are stored as-is.
    Uncompressed,
    /// The file is an LZ4 stream.
    LZ4,
}

impl std::str::FromStr for CompressionType {
    type Err = String;

    /// Parse the textual name of a compression type.
    ///
    /// The empty string selects [`CompressionType::Uncompressed`] and `"LZ4"` selects
    /// [`CompressionType::LZ4`]. Matching is exact (case-sensitive).
    ///
    /// # Errors
    ///
    /// Any other input yields an error message naming the unsupported input.
    fn from_str(input: &str) -> Result<Self, Self::Err> {
        match input {
            "" => Ok(CompressionType::Uncompressed),
            "LZ4" => Ok(CompressionType::LZ4),
            // `format!` is required here: a plain string literal does not interpolate `{input}`.
            _ => Err(format!("{input} unimplemented")),
        }
    }
}

/// Location of a single shard file together with how its bytes are stored on disk.
#[derive(Clone, Debug)]
pub struct ShardInfo {
/// Path of the shard file; it is handed to `std::fs` when the shard is loaded.
pub file_path: String,
/// Compression of the file; decides how the file bytes are read when loading.
pub compression_type: CompressionType,
}

impl ExampleIterator {
/// Takes a vector of file names of shards and creates an ExampleIterator over those. We assume
/// that all shard file names fit in memory. Alternatives to be re-evaluated:
Expand All @@ -33,7 +59,7 @@ impl ExampleIterator {
/// - Iterate over the shards in Rust. This would require having the shard filtering being
/// allowed to be called from Rust. But then we could pass an iterator of the following form:
/// `files: impl Iterator<Item = &str>`.
pub fn new(files: Vec<String>, repeat: bool, threads: usize) -> Self {
pub fn new(files: Vec<ShardInfo>, repeat: bool, threads: usize) -> Self {
assert!(!repeat, "Not implemented yet: repeat=true");
let example_iterator = Box::new(
parallel_map(|x| get_shard_progress(&x), files.into_iter(), threads).flatten(),
Expand All @@ -57,10 +83,23 @@ struct ShardProgress {
shard: LoadedShard,
}

/// Return a vector of bytes with the file content, decompressed according to
/// `shard_info.compression_type`.
///
/// # Panics
///
/// Panics when the file cannot be opened or read, or when decoding the LZ4 frame fails. The
/// panic message names the offending shard file and the underlying error, so that a failure
/// among many shards can be traced back to a file.
fn get_file_bytes(shard_info: &ShardInfo) -> Vec<u8> {
    let path = &shard_info.file_path;
    match shard_info.compression_type {
        CompressionType::Uncompressed => std::fs::read(path)
            .unwrap_or_else(|err| panic!("failed to read shard file {path}: {err}")),
        CompressionType::LZ4 => {
            let file = std::fs::File::open(path)
                .unwrap_or_else(|err| panic!("failed to open shard file {path}: {err}"));
            let mut file_bytes = Vec::new();
            lz4_flex::frame::FrameDecoder::new(file)
                .read_to_end(&mut file_bytes)
                .unwrap_or_else(|err| {
                    panic!("failed to decompress LZ4 shard file {path}: {err}")
                });
            file_bytes
        }
    }
}

/// Get ShardProgress.
fn get_shard_progress(file_path: &str) -> ShardProgress {
// TODO compressed file support.
let file_bytes = std::fs::read(file_path).unwrap();
fn get_shard_progress(shard_info: &ShardInfo) -> ShardProgress {
let file_bytes = get_file_bytes(shard_info);

// A shard is a vector of examples (positive number -- invariant kept by Python code).
// An example is vector of attributes (the same number of attributes in each example of each
Expand Down
12 changes: 9 additions & 3 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ mod shard_generated;
/// Python wrappers around `example_iteration`.
mod static_iter {
use std::collections::HashMap;
use std::str::FromStr;

use numpy::IntoPyArray;
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyRefMut};

use super::example_iteration::ExampleIterator;
use super::example_iteration::{CompressionType, ExampleIterator, ShardInfo};

/// Implementation details: The goal is to own the ExampleIterator in Rust and only send
/// examples to Python. This helps with concurrent reading and parsing of shard files.
Expand Down Expand Up @@ -92,10 +93,15 @@ mod static_iter {
#[pymethods]
impl RustIter {
#[new]
fn new(files: Vec<String>, repeat: bool, threads: usize) -> Self {
fn new(files: Vec<String>, repeat: bool, threads: usize, compression: String) -> Self {
let static_index = rand::random();
let mut hash_map = STATIC_ITERATORS.lock().unwrap();
hash_map.insert(static_index, ExampleIterator::new(files, repeat, threads));
let compression_type = CompressionType::from_str(&compression).unwrap();
let shard_infos = files
.into_iter()
.map(|file_path| ShardInfo { file_path: file_path.clone(), compression_type })
.collect();
hash_map.insert(static_index, ExampleIterator::new(shard_infos, repeat, threads));

RustIter { static_index, can_iterate: false }
}
Expand Down
47 changes: 12 additions & 35 deletions rust/src/shard_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ pub mod sedpack {
type Inner = Attribute<'a>;
#[inline]
unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
Self {
_tab: flatbuffers::Table::new(buf, loc),
}
Self { _tab: flatbuffers::Table::new(buf, loc) }
}
}

Expand Down Expand Up @@ -102,8 +100,7 @@ pub mod sedpack {
impl flatbuffers::Verifiable for Attribute<'_> {
#[inline]
fn run_verifier(
v: &mut flatbuffers::Verifier,
pos: usize,
v: &mut flatbuffers::Verifier, pos: usize,
) -> Result<(), flatbuffers::InvalidFlatbuffer> {
use self::flatbuffers::Verifiable;
v.visit_table(pos)?
Expand All @@ -119,9 +116,7 @@ pub mod sedpack {
impl<'a> Default for AttributeArgs<'a> {
#[inline]
fn default() -> Self {
AttributeArgs {
attribute_bytes: None,
}
AttributeArgs { attribute_bytes: None }
}
}

Expand All @@ -145,10 +140,7 @@ pub mod sedpack {
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
) -> AttributeBuilder<'a, 'b, A> {
let start = _fbb.start_table();
AttributeBuilder {
fbb_: _fbb,
start_: start,
}
AttributeBuilder { fbb_: _fbb, start_: start }
}
#[inline]
pub fn finish(self) -> flatbuffers::WIPOffset<Attribute<'a>> {
Expand All @@ -175,9 +167,7 @@ pub mod sedpack {
type Inner = Example<'a>;
#[inline]
unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
Self {
_tab: flatbuffers::Table::new(buf, loc),
}
Self { _tab: flatbuffers::Table::new(buf, loc) }
}
}

Expand Down Expand Up @@ -224,8 +214,7 @@ pub mod sedpack {
impl flatbuffers::Verifiable for Example<'_> {
#[inline]
fn run_verifier(
v: &mut flatbuffers::Verifier,
pos: usize,
v: &mut flatbuffers::Verifier, pos: usize,
) -> Result<(), flatbuffers::InvalidFlatbuffer> {
use self::flatbuffers::Verifiable;
v.visit_table(pos)?
Expand Down Expand Up @@ -272,10 +261,7 @@ pub mod sedpack {
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
) -> ExampleBuilder<'a, 'b, A> {
let start = _fbb.start_table();
ExampleBuilder {
fbb_: _fbb,
start_: start,
}
ExampleBuilder { fbb_: _fbb, start_: start }
}
#[inline]
pub fn finish(self) -> flatbuffers::WIPOffset<Example<'a>> {
Expand All @@ -293,7 +279,6 @@ pub mod sedpack {
}
pub enum ShardOffset {}


// Added Yokeable to the autogenerated code.
#[derive(Copy, Clone, PartialEq, yoke::Yokeable)]
pub struct Shard<'a> {
Expand All @@ -304,9 +289,7 @@ pub mod sedpack {
type Inner = Shard<'a>;
#[inline]
unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
Self {
_tab: flatbuffers::Table::new(buf, loc),
}
Self { _tab: flatbuffers::Table::new(buf, loc) }
}
}

Expand Down Expand Up @@ -353,8 +336,7 @@ pub mod sedpack {
impl flatbuffers::Verifiable for Shard<'_> {
#[inline]
fn run_verifier(
v: &mut flatbuffers::Verifier,
pos: usize,
v: &mut flatbuffers::Verifier, pos: usize,
) -> Result<(), flatbuffers::InvalidFlatbuffer> {
use self::flatbuffers::Verifiable;
v.visit_table(pos)?
Expand Down Expand Up @@ -401,10 +383,7 @@ pub mod sedpack {
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>,
) -> ShardBuilder<'a, 'b, A> {
let start = _fbb.start_table();
ShardBuilder {
fbb_: _fbb,
start_: start,
}
ShardBuilder { fbb_: _fbb, start_: start }
}
#[inline]
pub fn finish(self) -> flatbuffers::WIPOffset<Shard<'a>> {
Expand Down Expand Up @@ -450,8 +429,7 @@ pub mod sedpack {
/// previous, unchecked, behavior use
/// `root_as_shard_unchecked`.
pub fn root_as_shard_with_opts<'b, 'o>(
opts: &'o flatbuffers::VerifierOptions,
buf: &'b [u8],
opts: &'o flatbuffers::VerifierOptions, buf: &'b [u8],
) -> Result<Shard<'b>, flatbuffers::InvalidFlatbuffer> {
flatbuffers::root_with_opts::<Shard<'b>>(opts, buf)
}
Expand All @@ -463,8 +441,7 @@ pub mod sedpack {
/// previous, unchecked, behavior use
/// `root_as_shard_unchecked`.
pub fn size_prefixed_root_as_shard_with_opts<'b, 'o>(
opts: &'o flatbuffers::VerifierOptions,
buf: &'b [u8],
opts: &'o flatbuffers::VerifierOptions, buf: &'b [u8],
) -> Result<Shard<'b>, flatbuffers::InvalidFlatbuffer> {
flatbuffers::size_prefixed_root_with_opts::<Shard<'b>>(opts, buf)
}
Expand Down
16 changes: 10 additions & 6 deletions src/sedpack/io/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _shard_info_iterator(
for child in shard_list.children_shard_lists:
yield from self._shard_info_iterator(child)

def shard_info_iterator(self, split: SplitT) -> Iterator[ShardInfo]:
def shard_info_iterator(self, split: SplitT | None) -> Iterator[ShardInfo]:
"""Iterate all `ShardInfo` in the split.
Args:
Expand All @@ -151,10 +151,14 @@ def shard_info_iterator(self, split: SplitT) -> Iterator[ShardInfo]:
Raises: ValueError when the split is not present. A split not being
present is different from there not being any shard.
"""
if split not in self._dataset_info.splits:
# Split not present.
raise ValueError(f"There is no shard in {split}.")
if split:
if split not in self._dataset_info.splits:
# Split not present.
raise ValueError(f"There is no shard in {split}.")

shard_list_info: ShardListInfo = self._dataset_info.splits[split]
shard_list_info: ShardListInfo = self._dataset_info.splits[split]

yield from self._shard_info_iterator(shard_list_info)
yield from self._shard_info_iterator(shard_list_info)
else:
for shard_list_info in self._dataset_info.splits.values():
yield from self._shard_info_iterator(shard_list_info)
9 changes: 6 additions & 3 deletions src/sedpack/io/dataset_iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,9 +658,12 @@ def to_dict(example):
)
return result

with _sedpack_rs.RustIter(files=shard_paths,
repeat=repeat,
threads=file_parallelism) as rust_iter:
with _sedpack_rs.RustIter(
files=shard_paths,
repeat=repeat,
threads=file_parallelism,
compression=self.dataset_structure.compression,
) as rust_iter:
example_iterator = map(to_dict, iter(rust_iter))
if process_record:
yield from map(process_record, example_iterator)
Expand Down
9 changes: 9 additions & 0 deletions tests/io/test_rust_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,12 @@ def test_end2end_as_numpy_iterator_fb(tmpdir: Union[str, Path]) -> None:
shard_file_type="fb",
compression="",
)


# Runs the shared `end2end` helper for float32 data stored in "fb" shards with
# compression "LZ4" (the compressed counterpart of the compression="" case above).
def test_end2end_as_numpy_iterator_fb_lz4(tmpdir: Union[str, Path]) -> None:
end2end(
tmpdir=tmpdir,
dtype="float32",
shard_file_type="fb",
compression="LZ4",
)

0 comments on commit f05477f

Please sign in to comment.