diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ac7026d..72cfa723 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `array_subset::ArrayStoreBytesError`, `store_bytes`, and `store_bytes_unchecked` - Added experimental `zfp` codec implementation behind `zfp` feature flag (disabled by default) - Added experimental `bitround` codec implementation behind `bitround` feature flag (disabled by default) + - This is similar to [numcodecs BitRound](https://numcodecs.readthedocs.io/en/stable/bitround.html#numcodecs.bitround.BitRound), but it also supports rounding integers from the most significant set bit - Added `ShardingCodecBuilder` - Added `ReadableListableStorage`, `ReadableListableStorageTraits`, `StorageTransformerExtension::create_readable_listable_transformer` - Added `ByteRange::to_range()` and `to_range_usize()` diff --git a/src/array/codec/array_to_array/bitround.rs b/src/array/codec/array_to_array/bitround.rs index b9eee263..5c2d6ea4 100644 --- a/src/array/codec/array_to_array/bitround.rs +++ b/src/array/codec/array_to_array/bitround.rs @@ -1,6 +1,7 @@ //! The bitround array to array codec. //! //! Rounds the mantissa of floating point data types to the specified number of bits. +//! Rounds integers to the specified number of bits, counted from the most significant set bit. //! //! This codec requires the `bitround` feature, which is disabled by default. //! 
@@ -43,8 +44,7 @@ fn create_codec_bitround(metadata: &Metadata) -> Result<Codec, PluginCreateError Ok(Codec::ArrayToArray(codec)) } -fn round_bits16(mut input: u16, keepbits: u32) -> u16 { - let maxbits = 10; +fn round_bits16(mut input: u16, keepbits: u32, maxbits: u32) -> u16 { if keepbits >= maxbits { input } else { @@ -58,8 +58,7 @@ fn round_bits16(mut input: u16, keepbits: u32) -> u16 { } } -fn round_bits32(mut input: u32, keepbits: u32) -> u32 { - let maxbits = 23; +fn round_bits32(mut input: u32, keepbits: u32, maxbits: u32) -> u32 { if keepbits >= maxbits { input } else { @@ -73,8 +72,7 @@ fn round_bits32(mut input: u32, keepbits: u32) -> u32 { } } -fn round_bits64(mut input: u64, keepbits: u32) -> u64 { - let maxbits = 52; +fn round_bits64(mut input: u64, keepbits: u32, maxbits: u32) -> u64 { if keepbits >= maxbits { input } else { @@ -93,7 +91,20 @@ fn round_bytes(bytes: &mut [u8], data_type: &DataType, keepbits: u32) -> Result< DataType::Float16 | DataType::BFloat16 => { let round = |chunk: &mut [u8]| { let element = u16::from_ne_bytes(chunk.try_into().unwrap()); - let element = u16::to_ne_bytes(round_bits16(element, keepbits)); + let element = u16::to_ne_bytes(round_bits16(element, keepbits, 10)); + chunk.copy_from_slice(&element); + }; + bytes.chunks_exact_mut(2).for_each(round); + Ok(()) + } + DataType::UInt16 | DataType::Int16 => { + let round = |chunk: &mut [u8]| { + let element = u16::from_ne_bytes(chunk.try_into().unwrap()); + let element = u16::to_ne_bytes(round_bits16( + element, + keepbits, + 16 - element.leading_zeros(), + )); chunk.copy_from_slice(&element); }; bytes.chunks_exact_mut(2).for_each(round); @@ -102,7 +113,20 @@ fn round_bytes(bytes: &mut [u8], data_type: &DataType, keepbits: u32) -> Result< DataType::Float32 | DataType::Complex64 => { let round = |chunk: &mut [u8]| { let element = u32::from_ne_bytes(chunk.try_into().unwrap()); - let element = u32::to_ne_bytes(round_bits32(element, keepbits)); + let element = 
u32::to_ne_bytes(round_bits32(element, keepbits, 23)); + chunk.copy_from_slice(&element); + }; + bytes.chunks_exact_mut(4).for_each(round); + Ok(()) + } + DataType::UInt32 | DataType::Int32 => { + let round = |chunk: &mut [u8]| { + let element = u32::from_ne_bytes(chunk.try_into().unwrap()); + let element = u32::to_ne_bytes(round_bits32( + element, + keepbits, + 32 - element.leading_zeros(), + )); chunk.copy_from_slice(&element); }; bytes.chunks_exact_mut(4).for_each(round); @@ -111,7 +135,20 @@ fn round_bytes(bytes: &mut [u8], data_type: &DataType, keepbits: u32) -> Result< DataType::Float64 | DataType::Complex128 => { let round = |chunk: &mut [u8]| { let element = u64::from_ne_bytes(chunk.try_into().unwrap()); - let element = u64::to_ne_bytes(round_bits64(element, keepbits)); + let element = u64::to_ne_bytes(round_bits64(element, keepbits, 52)); + chunk.copy_from_slice(&element); + }; + bytes.chunks_exact_mut(8).for_each(round); + Ok(()) + } + DataType::UInt64 | DataType::Int64 => { + let round = |chunk: &mut [u8]| { + let element = u64::from_ne_bytes(chunk.try_into().unwrap()); + let element = u64::to_ne_bytes(round_bits64( + element, + keepbits, + 64 - element.leading_zeros(), + )); chunk.copy_from_slice(&element); }; bytes.chunks_exact_mut(8).for_each(round); @@ -176,6 +213,31 @@ mod tests { assert_eq!(decoded_elements, &[0.0f32, 1.25f32, -8.0f32, 98304.0f32]); } + #[test] + fn codec_bitround_uint() { + const JSON: &'static str = r#"{ "keepbits": 3 }"#; + let array_representation = + ArrayRepresentation::new(vec![6], DataType::UInt32, 0u32.into()).unwrap(); + let elements: Vec<u32> = vec![0, 1024, 1280, 1664, 1685, 123145182]; + let bytes = safe_transmute::transmute_to_bytes(&elements).to_vec(); + + let codec_configuration: BitroundCodecConfiguration = serde_json::from_str(JSON).unwrap(); + let codec = BitroundCodec::new_with_configuration(&codec_configuration); + + let encoded = codec.encode(bytes.clone(), &array_representation).unwrap(); + let decoded = 
codec + .decode(encoded.clone(), &array_representation) + .unwrap(); + let decoded_elements = safe_transmute::transmute_many_permissive::<u32>(&decoded) + .unwrap() + .to_vec(); + for element in &decoded_elements { + println!("{element} -> {element:#b}"); + } + // keepbits = 3 keeps three bits starting at each value's most significant set bit, e.g. 1685 -> 1792. + assert_eq!(decoded_elements, &[0, 1024, 1280, 1536, 1792, 117440512]); + } + #[test] fn codec_bitround_partial_decode() { const JSON: &'static str = r#"{ "keepbits": 2 }"#;