Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

read with padding and specialize u8 #2

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 165 additions & 12 deletions src/impls/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,122 @@ use core::convert::TryInto;
#[cfg(feature = "alloc")]
use alloc::format;

// TODO: duplicate code, separate macros for bits vs bytes? we only want to specialize reading a byte u8...
impl DekuRead<'_, (Endian, BitSize)> for u8 {
fn read(
input: &BitSlice<Msb0, u8>,
(endian, size): (Endian, BitSize),
) -> Result<(&BitSlice<Msb0, u8>, Self), DekuError> {
let max_type_bits: usize = BitSize::of::<u8>().0;
let bit_size: usize = size.0;

let input_is_le = endian.is_le();

if bit_size > max_type_bits {
return Err(DekuError::Parse(format!(
"too much data: container of {} bits cannot hold {} bits",
max_type_bits, bit_size
)));
}

if input.len() < bit_size {
return Err(DekuError::Incomplete(crate::error::NeedSize::new(bit_size)));
}

let (bit_slice, rest) = input.split_at(bit_size);

let pad = 8 * ((bit_slice.len() + 7) / 8) - bit_slice.len();

let value = if pad == 0
&& bit_slice.len() == max_type_bits
&& bit_slice.as_raw_slice().len() * 8 == max_type_bits
{
// if everything is aligned, just read the value
let bytes: &[u8] = bit_slice.as_raw_slice();

// Read value
if input_is_le {
<u8>::from_le_bytes(bytes.try_into()?)
} else {
<u8>::from_be_bytes(bytes.try_into()?)
}
} else {
// Create a new BitVec from the slice and pad un-aligned chunks
// i.e. [10010110, 1110] -> [10010110, 00001110]
let bits: BitVec<Msb0, u8> = {
let mut bits = BitVec::with_capacity(bit_slice.len() + pad);

// Copy bits to new BitVec
bits.extend_from_bitslice(bit_slice);

// Force align
//i.e. [1110, 10010110] -> [11101001, 0110]
bits.force_align();

// Some padding to next byte
let index = if input_is_le {
bits.len() - (8 - pad)
} else {
0
};
for _ in 0..pad {
bits.insert(index, false);
}

// Pad up-to size of type
for _ in 0..(max_type_bits - bits.len()) {
if input_is_le {
bits.push(false);
} else {
bits.insert(0, false);
}
}

bits
};

let bytes: &[u8] = bits.as_raw_slice();

// Read value
if input_is_le {
<u8>::from_le_bytes(bytes.try_into()?)
} else {
<u8>::from_be_bytes(bytes.try_into()?)
}
};
Ok((rest, value))
}
}

// specialize u8
impl DekuRead<'_, (Endian, ByteSize)> for u8 {
fn read(
input: &BitSlice<Msb0, u8>,
(_, size): (Endian, ByteSize),
) -> Result<(&BitSlice<Msb0, u8>, Self), DekuError> {
let max_type_bits: usize = BitSize::of::<u8>().0;
let bit_size: usize = size.0 * 8;

if bit_size > max_type_bits {
return Err(DekuError::Parse(format!(
"too much data: container of {} bits cannot hold {} bits",
max_type_bits, bit_size
)));
}

if input.len() < bit_size {
return Err(DekuError::Incomplete(crate::error::NeedSize::new(bit_size)));
}

let (bit_slice, rest) = input.split_at(bit_size);
let bytes: &[u8] = bit_slice.as_raw_slice();

Ok((rest, bytes[0]))
}
}

macro_rules! ImplDekuRead {
($typ:ty) => {
($typ:ty, $inner:ty) => {
impl DekuRead<'_, (Endian, BitSize)> for $typ {
fn read(
input: &BitSlice<Msb0, u8>,
Expand Down Expand Up @@ -116,14 +230,41 @@ macro_rules! ImplDekuRead {

let (bit_slice, rest) = input.split_at(bit_size);

let bytes: &[u8] = bit_slice.as_raw_slice();
let pad = 8 * ((bit_slice.len() + 7) / 8) - bit_slice.len();

// TODO: mention that this is slow? a conditional should be fast...
Copy link
Author

@sharksforarms sharksforarms Sep 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi, I don't think this path is slow (it's just comparing numbers), I think it's more-so the allocation that we save with the specialization of bytes (try running test_alloc with the changes in this branch on master, we save an alloc + realloc)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree

let value = if pad == 0
&& bit_slice.len() == max_type_bits
&& bit_slice.as_raw_slice().len() * 8 == max_type_bits
{
// if everything is aligned, just read the value
let bytes: &[u8] = bit_slice.as_raw_slice();

// Read value
let value = if input_is_le {
<$typ>::from_le_bytes(bytes.try_into()?)
// Read value
if input_is_le {
<$typ>::from_le_bytes(bytes.try_into()?)
} else {
<$typ>::from_be_bytes(bytes.try_into()?)
}
} else {
<$typ>::from_be_bytes(bytes.try_into()?)
let bytes: &[u8] = bit_slice.as_raw_slice();

// cannot use from_X_bytes as we don't have enough bytes for $typ
// read manually
let mut res: $inner = 0;
for b in bytes.iter().rev() {
res <<= 8 as $inner;
res |= *b as $inner;
}

if input_is_le {
res as $typ
} else {
res = res.swap_bytes();
res as $typ
}
};

Ok((rest, value))
}
}
Expand Down Expand Up @@ -177,9 +318,9 @@ macro_rules! ForwardDekuRead {

// Since we don't have a #[bits] or [bytes], check if we can use bytes for perf
if (bit_size.0 % 8) == 0 {
<$typ>::read(input, (endian, bit_size))
} else {
<$typ>::read(input, (endian, ByteSize(bit_size.0 / 8)))
} else {
<$typ>::read(input, (endian, bit_size))
}
}
}
Expand Down Expand Up @@ -363,7 +504,14 @@ macro_rules! ForwardDekuWrite {

macro_rules! ImplDekuTraits {
($typ:ty) => {
ImplDekuRead!($typ);
ImplDekuRead!($typ, $typ);
ForwardDekuRead!($typ);

ImplDekuWrite!($typ);
ForwardDekuWrite!($typ);
};
($typ:ty, $inner:ty) => {
ImplDekuRead!($typ, $inner);
ForwardDekuRead!($typ);

ImplDekuWrite!($typ);
Expand All @@ -381,7 +529,12 @@ macro_rules! ImplDekuTraitsSignExtend {
};
}

ImplDekuTraits!(u8);
// ImplDekuTraits!(u8);
// TODO: separate macros for bits vs bytes? we only want to specialize reading a byte u8...
ForwardDekuRead!(u8);
ImplDekuWrite!(u8);
ForwardDekuWrite!(u8);

ImplDekuTraits!(u16);
ImplDekuTraits!(u32);
ImplDekuTraits!(u64);
Expand All @@ -395,8 +548,8 @@ ImplDekuTraitsSignExtend!(i64, u64);
ImplDekuTraitsSignExtend!(i128, u128);
ImplDekuTraitsSignExtend!(isize, usize);

ImplDekuTraits!(f32);
ImplDekuTraits!(f64);
ImplDekuTraits!(f32, u32);
ImplDekuTraits!(f64, u64);

#[cfg(test)]
mod tests {
Expand Down
10 changes: 9 additions & 1 deletion tests/test_alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ enum NestedEnum {
VarA(u8),
}

#[derive(Debug, PartialEq, DekuRead, DekuWrite)]
#[deku(type = "u32", bytes = "2", ctx = "_endian: Endian")]
enum NestedEnum2 {
#[deku(id = "0x01")]
VarA(u8),
}

#[derive(Debug, PartialEq, DekuRead, DekuWrite)]
#[deku(endian = "big")]
struct TestDeku {
Expand All @@ -34,6 +41,7 @@ struct TestDeku {
field_g: u8, // 1 alloc (bits read)
#[deku(bits = "5")]
field_h: u8, // 1 alloc (bits read)
field_i: NestedEnum2,
}

mod tests {
Expand All @@ -45,7 +53,7 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)]
fn test_simple() {
let input = hex!("aabbbbcc0102ddffffffaa");
let input = hex!("aa_bbbb_cc_0102_dd_ffffff_aa_0100ff");

assert_eq!(
count_alloc(|| {
Expand Down