From 0b9407ed4ed25617c86f214a91ca3d201477a165 Mon Sep 17 00:00:00 2001 From: jojii Date: Thu, 28 Dec 2023 18:45:24 +0100 Subject: [PATCH] add decode_len and decode_len_unsafe and update dependencies --- Cargo.toml | 8 ++--- benches/varint_bench/main.rs | 50 ++++++++++++++++++++++++++- fuzz/Cargo.toml | 6 ++++ fuzz/fuzz_targets/fuzz_target_3.rs | 14 ++++++++ src/decode/mod.rs | 55 ++++++++++++++++++++++++++++++ 5 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 fuzz/fuzz_targets/fuzz_target_3.rs diff --git a/Cargo.toml b/Cargo.toml index f10b1c4..4335d9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "varint-simd" version = "0.4.0" authors = ["Andrew Sun "] -edition = "2018" +edition = "2021" license = "MIT OR Apache-2.0" description = "SIMD-accelerated varint encoder and decoder" repository = "https://github.com/as-com/varint-simd" @@ -23,14 +23,14 @@ native-optimizations = [] dangerously-force-enable-pdep-since-i-really-know-what-im-doing = [] [dev-dependencies] -criterion = "0.3" -integer-encoding = "2.1" +criterion = "0.5" +integer-encoding = "4.0" rand = "0.8" bytes = "1" # prost-varint lazy_static = "1.4.0" [build-dependencies] -rustc_version = "0.3.0" +rustc_version = "0.4.0" [[bench]] name = "varint_bench" diff --git a/benches/varint_bench/main.rs b/benches/varint_bench/main.rs index 067b89a..74c6ceb 100644 --- a/benches/varint_bench/main.rs +++ b/benches/varint_bench/main.rs @@ -1,4 +1,3 @@ -use bytes::Buf; use criterion::{criterion_group, criterion_main, BatchSize, Criterion, Throughput}; use integer_encoding::VarInt; use rand::distributions::{Distribution, Standard}; @@ -7,6 +6,8 @@ use varint_simd::{ decode, decode_eight_u8_unsafe, decode_four_unsafe, + decode_len, + decode_len_unsafe, decode_two_unsafe, //decode_two_wide_unsafe, decode_unsafe, encode, @@ -37,6 +38,32 @@ where } } +#[inline(always)] +fn decode_len_batched_varint_simd(input: &mut (Vec, Vec)) { + let data = &input.0; + + let mut slice = &data[..]; + for _ in 0..C { + // SAFETY: the input slice should have at least 16 bytes of allocated padding at the end + let len = decode_len::(slice).unwrap(); + slice = &slice[len..]; + } +} + +#[inline(always)] +fn decode_len_batched_varint_simd_unsafe( + input: &mut (Vec, Vec), +) { + let data = &input.0; + + let mut slice = &data[..]; + for _ in 0..C { + // SAFETY: the input slice should have at least 16 bytes of allocated padding at the end + let len = unsafe { decode_len_unsafe::(slice.as_ptr()) }; + slice = &slice[len..]; + } +} + #[inline(always)] fn decode_batched_varint_simd_unsafe( input: &mut (Vec, Vec), @@ -215,6 +242,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("varint-u8/decode"); group.throughput(Throughput::Elements(SEQUENCE_LEN as u64)); + group.bench_function("integer-encoding", |b| { b.iter_batched_ref( create_batched_encoded_generator::(&mut rng), @@ -280,6 +308,26 @@ pub fn criterion_benchmark(c: &mut Criterion) { }); group.finish(); + let mut group = c.benchmark_group("varint-u8/decode_len"); + group.throughput(Throughput::Elements(SEQUENCE_LEN as u64)); + group.bench_function("varint-simd/unsafe", |b| { + b.iter_batched_ref( + create_batched_encoded_generator::(&mut rng), + decode_len_batched_varint_simd_unsafe::, + BatchSize::SmallInput, + ) + }); + + group.bench_function("varint-simd/safe", |b| { + b.iter_batched_ref( + create_batched_encoded_generator::(&mut rng), + decode_len_batched_varint_simd::, + BatchSize::SmallInput, + ) + }); + + group.finish(); + let mut group = c.benchmark_group("varint-u8/encode"); group.throughput(Throughput::Elements(1)); group.bench_function("integer-encoding", |b| { diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index bf26af8..24d9e56 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -32,3 +32,9 @@ name = "fuzz_target_2" path = "fuzz_targets/fuzz_target_2.rs" test = false doc = false + +[[bin]] +name = "fuzz_target_3" +path = "fuzz_targets/fuzz_target_3.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/fuzz_target_3.rs b/fuzz/fuzz_targets/fuzz_target_3.rs new file mode 100644 index 0000000..bcb18a6 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_target_3.rs @@ -0,0 +1,14 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +use integer_encoding::VarInt; + +fuzz_target!(|data: [u8; 16]| { + let reference = u64::decode_var(&data); + + let len = unsafe { varint_simd::decode_len_unsafe::(data.as_ptr()) }; + + if let Some(reference) = reference { + assert_eq!(reference.1, len); + } +}); diff --git a/src/decode/mod.rs b/src/decode/mod.rs index fddf305..47805e5 100644 --- a/src/decode/mod.rs +++ b/src/decode/mod.rs @@ -51,6 +51,35 @@ pub fn decode(bytes: &[u8]) -> Result<(T, usize), VarIntDecodeE } } +/// Decodes only the length of a single variant from the input slice. +/// +/// # Examples +/// ``` +/// use varint_simd::{decode_len, VarIntDecodeError}; +/// +/// fn main() -> Result<(), VarIntDecodeError> { +/// let decoded = decode_len::(&[185, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])?; +/// assert_eq!(decoded, 2); +/// Ok(()) +/// } +/// ``` +#[inline] +pub fn decode_len(bytes: &[u8]) -> Result { + let result = if bytes.len() >= 16 { + unsafe { decode_len_unsafe::(bytes.as_ptr()) } + } else if !bytes.is_empty() { + let mut data = [0u8; 16]; + let len = min(16, bytes.len()); + // unsafe { core::ptr::copy_nonoverlapping(bytes.as_ptr(), data.as_mut_ptr(), len); } + data[..len].copy_from_slice(&bytes[..len]); + unsafe { decode_len_unsafe::(data.as_ptr()) } + } else { + return Err(VarIntDecodeError::NotEnoughBytes); + }; + + Ok(result) +} + /// Convenience function for decoding a single varint in ZigZag format from the input slice. /// See also: [`decode`] /// @@ -69,6 +98,32 @@ pub fn decode_zigzag(bytes: &[u8]) -> Result<(T, usize), decode::(bytes).map(|r| (r.0.unzigzag(), r.1)) } +/// Decodes the length of the next integer +/// +/// # Safety +/// Same as `decode_unsafe` +#[inline] +pub unsafe fn decode_len_unsafe(bytes: *const u8) -> usize { + if T::MAX_VARINT_BYTES <= 5 { + let b = bytes.cast::().read_unaligned(); + let msbs = !b & !0x7f7f7f7f7f7f7f7f; + let len = msbs.trailing_zeros() + 1; // in bits + (len / 8) as usize + } else { + let b0 = bytes.cast::().read_unaligned(); + let b1 = bytes.cast::().add(1).read_unaligned(); + + let msbs0 = !b0 & !0x7f7f7f7f7f7f7f7f; + let msbs1 = !b1 & !0x7f7f7f7f7f7f7f7f; + + let len0 = msbs0.trailing_zeros() + 1; + let len1 = msbs1.trailing_zeros() + 1; + + let len = if msbs0 == 0 { len1 + 64 } else { len0 }; + len as usize / 8 + } +} + /// Decodes a single varint from the input pointer. Returns a tuple containing the decoded number /// and the number of bytes read. ///