Skip to content

Commit

Permalink
add decode_len and decode_len_unsafe and update dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
JojiiOfficial committed Dec 28, 2023
1 parent c94d001 commit 0b9407e
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 5 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "varint-simd"
version = "0.4.0"
authors = ["Andrew Sun <[email protected]>"]
edition = "2018"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "SIMD-accelerated varint encoder and decoder"
repository = "https://github.com/as-com/varint-simd"
Expand All @@ -23,14 +23,14 @@ native-optimizations = []
dangerously-force-enable-pdep-since-i-really-know-what-im-doing = []

[dev-dependencies]
criterion = "0.3"
integer-encoding = "2.1"
criterion = "0.5"
integer-encoding = "4.0"
rand = "0.8"
bytes = "1" # prost-varint
lazy_static = "1.4.0"

[build-dependencies]
rustc_version = "0.3.0"
rustc_version = "0.4.0"

[[bench]]
name = "varint_bench"
Expand Down
50 changes: 49 additions & 1 deletion benches/varint_bench/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use bytes::Buf;
use criterion::{criterion_group, criterion_main, BatchSize, Criterion, Throughput};
use integer_encoding::VarInt;
use rand::distributions::{Distribution, Standard};
Expand All @@ -7,6 +6,8 @@ use varint_simd::{
decode,
decode_eight_u8_unsafe,
decode_four_unsafe,
decode_len,
decode_len_unsafe,
decode_two_unsafe, //decode_two_wide_unsafe,
decode_unsafe,
encode,
Expand Down Expand Up @@ -37,6 +38,32 @@ where
}
}

#[inline(always)]
fn decode_len_batched_varint_simd<T: VarIntTarget, const C: usize>(input: &mut (Vec<u8>, Vec<T>)) {
let data = &input.0;

let mut slice = &data[..];
for _ in 0..C {
// SAFETY: the input slice should have at least 16 bytes of allocated padding at the end
let len = decode_len::<T>(slice).unwrap();
slice = &slice[len..];
}
}

#[inline(always)]
fn decode_len_batched_varint_simd_unsafe<T: VarIntTarget, const C: usize>(
input: &mut (Vec<u8>, Vec<T>),
) {
let data = &input.0;

let mut slice = &data[..];
for _ in 0..C {
// SAFETY: the input slice should have at least 16 bytes of allocated padding at the end
let len = unsafe { decode_len_unsafe::<T>(slice.as_ptr()) };
slice = &slice[len..];
}
}

#[inline(always)]
fn decode_batched_varint_simd_unsafe<T: VarIntTarget, const C: usize>(
input: &mut (Vec<u8>, Vec<T>),
Expand Down Expand Up @@ -215,6 +242,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {

let mut group = c.benchmark_group("varint-u8/decode");
group.throughput(Throughput::Elements(SEQUENCE_LEN as u64));

group.bench_function("integer-encoding", |b| {
b.iter_batched_ref(
create_batched_encoded_generator::<u8, _, SEQUENCE_LEN>(&mut rng),
Expand Down Expand Up @@ -280,6 +308,26 @@ pub fn criterion_benchmark(c: &mut Criterion) {
});
group.finish();

let mut group = c.benchmark_group("varint-u8/decode_len");
group.throughput(Throughput::Elements(SEQUENCE_LEN as u64));
group.bench_function("varint-simd/unsafe", |b| {
b.iter_batched_ref(
create_batched_encoded_generator::<u8, _, SEQUENCE_LEN>(&mut rng),
decode_len_batched_varint_simd_unsafe::<u8, SEQUENCE_LEN>,
BatchSize::SmallInput,
)
});

group.bench_function("varint-simd/safe", |b| {
b.iter_batched_ref(
create_batched_encoded_generator::<u8, _, SEQUENCE_LEN>(&mut rng),
decode_len_batched_varint_simd::<u8, SEQUENCE_LEN>,
BatchSize::SmallInput,
)
});

group.finish();

let mut group = c.benchmark_group("varint-u8/encode");
group.throughput(Throughput::Elements(1));
group.bench_function("integer-encoding", |b| {
Expand Down
6 changes: 6 additions & 0 deletions fuzz/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ name = "fuzz_target_2"
path = "fuzz_targets/fuzz_target_2.rs"
test = false
doc = false

[[bin]]
name = "fuzz_target_3"
path = "fuzz_targets/fuzz_target_3.rs"
test = false
doc = false
14 changes: 14 additions & 0 deletions fuzz/fuzz_targets/fuzz_target_3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#![no_main]
use libfuzzer_sys::fuzz_target;

use integer_encoding::VarInt;

fuzz_target!(|data: [u8; 16]| {
let reference = u64::decode_var(&data);

let len = unsafe { varint_simd::decode_len_unsafe::<u64>(data.as_ptr()) };

if let Some(reference) = reference {
assert_eq!(reference.1, len);
}
});
55 changes: 55 additions & 0 deletions src/decode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,35 @@ pub fn decode<T: VarIntTarget>(bytes: &[u8]) -> Result<(T, usize), VarIntDecodeE
}
}

/// Decodes only the length of a single variant from the input slice.
///
/// # Examples
/// ```
/// use varint_simd::{decode_len, VarIntDecodeError};
///
/// fn main() -> Result<(), VarIntDecodeError> {
/// let decoded = decode_len::<u32>(&[185, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])?;
/// assert_eq!(decoded, 2);
/// Ok(())
/// }
/// ```
#[inline]
pub fn decode_len<T: VarIntTarget>(bytes: &[u8]) -> Result<usize, VarIntDecodeError> {
let result = if bytes.len() >= 16 {
unsafe { decode_len_unsafe::<T>(bytes.as_ptr()) }
} else if !bytes.is_empty() {
let mut data = [0u8; 16];
let len = min(16, bytes.len());
// unsafe { core::ptr::copy_nonoverlapping(bytes.as_ptr(), data.as_mut_ptr(), len); }
data[..len].copy_from_slice(&bytes[..len]);
unsafe { decode_len_unsafe::<T>(data.as_ptr()) }
} else {
return Err(VarIntDecodeError::NotEnoughBytes);
};

Ok(result)
}

/// Convenience function for decoding a single varint in ZigZag format from the input slice.
/// See also: [`decode`]
///
Expand All @@ -69,6 +98,32 @@ pub fn decode_zigzag<T: SignedVarIntTarget>(bytes: &[u8]) -> Result<(T, usize),
decode::<T::Unsigned>(bytes).map(|r| (r.0.unzigzag(), r.1))
}

/// Decodes the length of the next integer
///
/// # Safety
/// Same as `decode_unsafe`
#[inline]
pub unsafe fn decode_len_unsafe<T: VarIntTarget>(bytes: *const u8) -> usize {
if T::MAX_VARINT_BYTES <= 5 {
let b = bytes.cast::<u64>().read_unaligned();
let msbs = !b & !0x7f7f7f7f7f7f7f7f;
let len = msbs.trailing_zeros() + 1; // in bits
(len / 8) as usize
} else {
let b0 = bytes.cast::<u64>().read_unaligned();
let b1 = bytes.cast::<u64>().add(1).read_unaligned();

let msbs0 = !b0 & !0x7f7f7f7f7f7f7f7f;
let msbs1 = !b1 & !0x7f7f7f7f7f7f7f7f;

let len0 = msbs0.trailing_zeros() + 1;
let len1 = msbs1.trailing_zeros() + 1;

let len = if msbs0 == 0 { len1 + 64 } else { len0 };
len as usize / 8
}
}

/// Decodes a single varint from the input pointer. Returns a tuple containing the decoded number
/// and the number of bytes read.
///
Expand Down

0 comments on commit 0b9407e

Please sign in to comment.