From 9979cc33bb964b4ee36898773a01f546d2c6487a Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Wed, 28 Feb 2024 07:25:06 -0700 Subject: [PATCH 1/6] Keep morsels as separate bytes ~6% speedup on decode_slice/3 --- src/engine/general_purpose/decode_suffix.rs | 37 ++++++++++----------- src/lib.rs | 1 - 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index e1e005d..6f4a10e 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -17,14 +17,14 @@ pub(crate) fn decode_suffix( decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { - // Decode any leftovers that aren't a complete input block of 8 bytes. + // Decode any leftovers that might not be a complete input chunk of 8 bytes. // Use a u64 as a stack-resident 8 byte buffer. - let mut leftover_bits: u64 = 0; let mut morsels_in_leftover = 0; let mut padding_bytes = 0; let mut first_padding_index: usize = 0; let mut last_symbol = 0_u8; let start_of_leftovers = input_index; + let mut morsels = [0_u8; 8]; for (i, &b) in input[start_of_leftovers..].iter().enumerate() { // '=' padding @@ -83,13 +83,12 @@ pub(crate) fn decode_suffix( // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. // Pack the leftovers from left to right. - let shift = 64 - (morsels_in_leftover + 1) * 6; let morsel = decode_table[b as usize]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); } - leftover_bits |= (morsel as u64) << shift; + morsels[morsels_in_leftover] = morsel; morsels_in_leftover += 1; } @@ -121,23 +120,23 @@ pub(crate) fn decode_suffix( // useless since there are no more symbols to provide the necessary 4 additional bits // to finish the second original byte. - let leftover_bits_ready_to_append = match morsels_in_leftover { - 0 => 0, - 2 => 8, - 3 => 16, - 4 => 24, - 6 => 32, - 7 => 40, - 8 => 48, - // can also be detected as case #2 bad padding above - _ => unreachable!( - "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths" - ), - }; + // TODO how do we know this? + debug_assert!(morsels_in_leftover != 1 && morsels_in_leftover != 5); + let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; + let leftover_bits_to_append = leftover_bytes_to_append * 8; + // A couple percent speedup from nudging these ORs to use more ILP with a two-way split + let leftover_bits = ((u64::from(morsels[0]) << 58) + | (u64::from(morsels[1]) << 52) + | (u64::from(morsels[2]) << 46) + | (u64::from(morsels[3]) << 40)) + | ((u64::from(morsels[4]) << 34) + | (u64::from(morsels[5]) << 28) + | (u64::from(morsels[6]) << 22) + | (u64::from(morsels[7]) << 16)); // if there are bits set outside the bits we care about, last symbol encodes trailing bits that // will not be included in the output - let mask = !0 >> leftover_bits_ready_to_append; + let mask = !0 >> leftover_bits_to_append; if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 { // last morsel is at `morsels_in_leftover` - 1 return Err(DecodeError::InvalidLastSymbol( @@ -148,7 +147,7 @@ pub(crate) fn decode_suffix( // TODO benchmark simply converting to big endian bytes let mut leftover_bits_appended_to_buf = 0; - while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { + while leftover_bits_appended_to_buf < leftover_bits_to_append { // `as` simply truncates the higher bits, which is what we want here let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; output[output_index] = selected_bits; diff --git a/src/lib.rs b/src/lib.rs index 6ec3c12..6b5cccb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -229,7 +229,6 @@ unused_import_braces, unused_results, variant_size_differences, - warnings )] #![forbid(unsafe_code)] // Allow globally until https://github.com/rust-lang/rust-clippy/issues/8768 is resolved. From a25be0667c63460827cfadd71d1630acb442bb09 Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Wed, 28 Feb 2024 08:00:34 -0700 Subject: [PATCH 2/6] Simplify leftover output writes No perf impact --- src/engine/general_purpose/decode_suffix.rs | 24 ++++++++++----------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index 6f4a10e..9fbb0d5 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -123,9 +123,9 @@ pub(crate) fn decode_suffix( // TODO how do we know this? debug_assert!(morsels_in_leftover != 1 && morsels_in_leftover != 5); let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; - let leftover_bits_to_append = leftover_bytes_to_append * 8; - // A couple percent speedup from nudging these ORs to use more ILP with a two-way split - let leftover_bits = ((u64::from(morsels[0]) << 58) + // Put the up to 6 complete bytes as the high bytes. + // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split. + let mut leftover_num = ((u64::from(morsels[0]) << 58) | (u64::from(morsels[1]) << 52) | (u64::from(morsels[2]) << 46) | (u64::from(morsels[3]) << 40)) @@ -136,8 +136,8 @@ pub(crate) fn decode_suffix( // if there are bits set outside the bits we care about, last symbol encodes trailing bits that // will not be included in the output - let mask = !0 >> leftover_bits_to_append; - if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 { + let mask = !0 >> (leftover_bytes_to_append * 8); + if !decode_allow_trailing_bits && (leftover_num & mask) != 0 { // last morsel is at `morsels_in_leftover` - 1 return Err(DecodeError::InvalidLastSymbol( start_of_leftovers + morsels_in_leftover - 1, @@ -145,15 +145,13 @@ pub(crate) fn decode_suffix( )); } - // TODO benchmark simply converting to big endian bytes - let mut leftover_bits_appended_to_buf = 0; - while leftover_bits_appended_to_buf < leftover_bits_to_append { - // `as` simply truncates the higher bits, which is what we want here - let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; - output[output_index] = selected_bits; + // Strangely, this approach benchmarks better than writing bytes one at a time, + // or copy_from_slice into output. + for _ in 0..leftover_bytes_to_append { + let hi_byte = (leftover_num >> 56) as u8; + leftover_num <<= 8; + output[output_index] = hi_byte; output_index += 1; - - leftover_bits_appended_to_buf += 8; } Ok(DecodeMetadata::new( From a8a60f43c56597259558261353b5bf7e953eed36 Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:42:43 -0700 Subject: [PATCH 3/6] Decode main loop improvements - Rearrange main decoding loops to handle chunks of 32 bytes at a time, then 4 bytes at a time, meaning that `decode_suffix` need only handle 0-4 bytes, simplifying its code. Moderate speed gains of around 5-10%. - Improve error precision. `InvalidLength` now has a `usize` length indicating how many valid symbols were found, but that the count of those symbols was invalid. Before, it just did `input % 4 == `, which was harder to reason about, as there might be padding etc. DecoderReader now also precisely reports the suitable InvalidByte if an earlier block of decoding found padding that was valid in that context, but more padding was found later, rendering that earlier padding invalid. - Tidy up decode tests. There were some duplicated scenarios, and certain aspects are now tested in more detail. --- src/decode.rs | 26 +- src/engine/general_purpose/decode.rs | 366 +++++------- src/engine/general_purpose/decode_suffix.rs | 76 +-- src/engine/naive.rs | 13 +- src/engine/tests.rs | 625 +++++++------------- src/read/decoder.rs | 70 ++- 6 files changed, 466 insertions(+), 710 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 5230fd3..0f66c74 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -9,18 +9,20 @@ use std::error; #[derive(Clone, Debug, PartialEq, Eq)] pub enum DecodeError { /// An invalid byte was found in the input. The offset and offending byte are provided. - /// Padding characters (`=`) interspersed in the encoded form will be treated as invalid bytes. + /// + /// Padding characters (`=`) interspersed in the encoded form are invalid, as they may only + /// be present as the last 0-2 bytes of input. + /// + /// This error may also indicate that extraneous trailing input bytes are present, causing + /// otherwise valid padding to no longer be the last bytes of input. InvalidByte(usize, u8), - /// The length of the input is invalid. - /// A typical cause of this is stray trailing whitespace or other separator bytes. - /// In the case where excess trailing bytes have produced an invalid length *and* the last byte - /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte` - /// will be emitted instead of `InvalidLength` to make the issue easier to debug. - InvalidLength, + /// The length of the input, as measured in valid base64 symbols, is invalid. + /// There must be 2-4 symbols in the last input quad. + InvalidLength(usize), /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. /// This is indicative of corrupted or truncated Base64. - /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for - /// symbols that are in the alphabet but represent nonsensical encodings. + /// Unlike [DecodeError::InvalidByte], which reports symbols that aren't in the alphabet, + /// this error is for symbols that are in the alphabet but represent nonsensical encodings. InvalidLastSymbol(usize, u8), /// The nature of the padding was not as configured: absent or incorrect when it must be /// canonical, or present when it must be absent, etc. @@ -30,8 +32,10 @@ pub enum DecodeError { impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Self::InvalidByte(index, byte) => write!(f, "Invalid byte {}, offset {}.", byte, index), - Self::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."), + Self::InvalidByte(index, byte) => { + write!(f, "Invalid symbol {}, offset {}.", byte, index) + } + Self::InvalidLength(len) => write!(f, "Invalid input length: {}", len), Self::InvalidLastSymbol(index, byte) => { write!(f, "Invalid last symbol {}, offset {}.", byte, index) } diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index 21a386f..31c289e 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -3,45 +3,25 @@ use crate::{ DecodeError, PAD_BYTE, }; -// decode logic operates on chunks of 8 input bytes without padding -const INPUT_CHUNK_LEN: usize = 8; -const DECODED_CHUNK_LEN: usize = 6; - -// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last -// 2 bytes of any output u64 should not be counted as written to (but must be available in a -// slice). -const DECODED_CHUNK_SUFFIX: usize = 2; - -// how many u64's of input to handle at a time -const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; - -const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; - -// includes the trailing 2 bytes for the final u64 write -const DECODED_BLOCK_LEN: usize = - CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; - #[doc(hidden)] pub struct GeneralPurposeEstimate { - /// Total number of decode chunks, including a possibly partial last chunk - num_chunks: usize, - decoded_len_estimate: usize, + rem: usize, + conservative_len: usize, } impl GeneralPurposeEstimate { pub(crate) fn new(encoded_len: usize) -> Self { - // Formulas that won't overflow + let rem = encoded_len % 4; Self { - num_chunks: encoded_len / INPUT_CHUNK_LEN - + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, - decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, + rem, + conservative_len: (encoded_len / 4 + (rem > 0) as usize) * 3, } } } impl DecodeEstimate for GeneralPurposeEstimate { fn decoded_len_estimate(&self) -> usize { - self.decoded_len_estimate + self.conservative_len } } @@ -59,264 +39,237 @@ pub(crate) fn decode_helper( decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { - let remainder_len = input.len() % INPUT_CHUNK_LEN; - - // Because the fast decode loop writes in groups of 8 bytes (unrolled to - // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of - // which only 6 are valid data), we need to be sure that we stop using the fast decode loop - // soon enough that there will always be 2 more bytes of valid data written after that loop. - let trailing_bytes_to_skip = match remainder_len { - // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, - // and the fast decode logic cannot handle padding - 0 => INPUT_CHUNK_LEN, - // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte - 1 | 5 => { - // trailing whitespace is so common that it's worth it to check the last byte to - // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } - } - - return Err(DecodeError::InvalidLength); + // detect a trailing invalid byte, like a newline, as a user convenience + if estimate.rem == 1 { + let last_byte = input[input.len() - 1]; + // exclude pad bytes; might be part of padding that extends from earlier in the input + if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); } - // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes - // written by the fast decode loop. So, we have to ignore both these 2 bytes and the - // previous chunk. - 2 => INPUT_CHUNK_LEN + 2, - // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this - // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail - // with an error, not panic from going past the bounds of the output slice, so we let it - // use stage 3 + 4. - 3 => INPUT_CHUNK_LEN + 3, - // This can also decode to one output byte because it may be 2 input chars + 2 padding - // chars, which would decode to 1 byte. - 4 => INPUT_CHUNK_LEN + 4, - // Everything else is a legal decode len (given that we don't require padding), and will - // decode to at least 2 bytes of output. - _ => remainder_len, - }; - - // rounded up to include partial chunks - let mut remaining_chunks = estimate.num_chunks; - - let mut input_index = 0; - let mut output_index = 0; + } + // skip last quad, even if it's complete, as it may have padding + let input_complete_nonterminal_quads_len = input + .len() + .saturating_sub(estimate.rem) + // if rem was 0, subtract 4 to avoid padding + .saturating_sub((estimate.rem == 0) as usize * 4); + debug_assert!( + input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) + ); + + const UNROLLED_INPUT_CHUNK_SIZE: usize = 32; + const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3; + + let input_complete_quads_after_unrolled_chunks_len = + input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE; + + let input_unrolled_loop_len = + input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len; + + // chunks of 32 bytes + for (chunk_index, chunk) in input[..input_unrolled_loop_len] + .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE) + .enumerate() { - let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); - - // Fast loop, stage 1 - // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { - while input_index <= max_start_index { - let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; - let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; - - decode_chunk( - &input_slice[0..], - input_index, - decode_table, - &mut output_slice[0..], - )?; - decode_chunk( - &input_slice[8..], - input_index + 8, - decode_table, - &mut output_slice[6..], - )?; - decode_chunk( - &input_slice[16..], - input_index + 16, - decode_table, - &mut output_slice[12..], - )?; - decode_chunk( - &input_slice[24..], - input_index + 24, - decode_table, - &mut output_slice[18..], - )?; - - input_index += INPUT_BLOCK_LEN; - output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; - remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; - } - } - - // Fast loop, stage 2 (aka still pretty fast loop) - // 8 bytes at a time for whatever we didn't do in stage 1. - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { - while input_index < max_start_index { - decode_chunk( - &input[input_index..(input_index + INPUT_CHUNK_LEN)], - input_index, - decode_table, - &mut output - [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], - )?; - - output_index += DECODED_CHUNK_LEN; - input_index += INPUT_CHUNK_LEN; - remaining_chunks -= 1; - } - } - } + let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE; + let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE + ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE]; - // Stage 3 - // If input length was such that a chunk had to be deferred until after the fast loop - // because decoding it would have produced 2 trailing bytes that wouldn't then be - // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 - // trailing bytes. - // However, we still need to avoid the last chunk (partial or complete) because it could - // have padding, so we always do 1 fewer to avoid the last chunk. - for _ in 1..remaining_chunks { - decode_chunk_precise( - &input[input_index..], + decode_chunk_8( + &chunk[0..8], input_index, decode_table, - &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], + &mut chunk_output[0..6], + )?; + decode_chunk_8( + &chunk[8..16], + input_index + 8, + decode_table, + &mut chunk_output[6..12], + )?; + decode_chunk_8( + &chunk[16..24], + input_index + 16, + decode_table, + &mut chunk_output[12..18], + )?; + decode_chunk_8( + &chunk[24..32], + input_index + 24, + decode_table, + &mut chunk_output[18..24], )?; - - input_index += INPUT_CHUNK_LEN; - output_index += DECODED_CHUNK_LEN; } - // always have one more (possibly partial) block of 8 input - debug_assert!(input.len() - input_index > 1 || input.is_empty()); - debug_assert!(input.len() - input_index <= 8); + // remaining quads, except for the last possibly partial one, as it may have padding + let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3; + let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3; + { + let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len]; + + for (chunk_index, chunk) in input + [input_unrolled_loop_len..input_complete_nonterminal_quads_len] + .chunks_exact(4) + .enumerate() + { + let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3]; + + decode_chunk_4( + chunk, + input_unrolled_loop_len + chunk_index * 4, + decode_table, + chunk_output, + )?; + } + } super::decode_suffix::decode_suffix( input, - input_index, + input_complete_nonterminal_quads_len, output, - output_index, + output_complete_quad_len, decode_table, decode_allow_trailing_bits, padding_mode, ) } -/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the -/// first 6 of those contain meaningful data. +/// Decode 8 bytes of input into 6 bytes of output. /// -/// `input` is the bytes to decode, of which the first 8 bytes will be processed. +/// `input` is the 8 bytes to decode. /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors /// accurately) /// `decode_table` is the lookup table for the particular base64 alphabet. -/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded -/// data. +/// `output` will have its first 6 bytes overwritten // yes, really inline (worth 30-50% speedup) #[inline(always)] -fn decode_chunk( +fn decode_chunk_8( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let morsel = decode_table[input[0] as usize]; + let morsel = decode_table[usize::from(input[0])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); } - let mut accum = (morsel as u64) << 58; + let mut accum = u64::from(morsel) << 58; - let morsel = decode_table[input[1] as usize]; + let morsel = decode_table[usize::from(input[1])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 1, input[1], )); } - accum |= (morsel as u64) << 52; + accum |= u64::from(morsel) << 52; - let morsel = decode_table[input[2] as usize]; + let morsel = decode_table[usize::from(input[2])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 2, input[2], )); } - accum |= (morsel as u64) << 46; + accum |= u64::from(morsel) << 46; - let morsel = decode_table[input[3] as usize]; + let morsel = decode_table[usize::from(input[3])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 3, input[3], )); } - accum |= (morsel as u64) << 40; + accum |= u64::from(morsel) << 40; - let morsel = decode_table[input[4] as usize]; + let morsel = decode_table[usize::from(input[4])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 4, input[4], )); } - accum |= (morsel as u64) << 34; + accum |= u64::from(morsel) << 34; - let morsel = decode_table[input[5] as usize]; + let morsel = decode_table[usize::from(input[5])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 5, input[5], )); } - accum |= (morsel as u64) << 28; + accum |= u64::from(morsel) << 28; - let morsel = decode_table[input[6] as usize]; + let morsel = decode_table[usize::from(input[6])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 6, input[6], )); } - accum |= (morsel as u64) << 22; + accum |= u64::from(morsel) << 22; - let morsel = decode_table[input[7] as usize]; + let morsel = decode_table[usize::from(input[7])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 7, input[7], )); } - accum |= (morsel as u64) << 16; + accum |= u64::from(morsel) << 16; - write_u64(output, accum); + output[..6].copy_from_slice(&accum.to_be_bytes()[..6]); Ok(()) } -/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 -/// trailing garbage bytes. -#[inline] -fn decode_chunk_precise( +/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output. +#[inline(always)] +fn decode_chunk_4( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let mut tmp_buf = [0_u8; 8]; + let morsel = decode_table[usize::from(input[0])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = u32::from(morsel) << 26; - decode_chunk( - input, - index_at_start_of_input, - decode_table, - &mut tmp_buf[..], - )?; + let morsel = decode_table[usize::from(input[1])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= u32::from(morsel) << 20; + + let morsel = decode_table[usize::from(input[2])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= u32::from(morsel) << 14; + + let morsel = decode_table[usize::from(input[3])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= u32::from(morsel) << 8; - output[0..6].copy_from_slice(&tmp_buf[0..6]); + output[..3].copy_from_slice(&accum.to_be_bytes()[..3]); Ok(()) } -#[inline] -fn write_u64(output: &mut [u8], value: u64) { - output[..8].copy_from_slice(&value.to_be_bytes()); -} - #[cfg(test)] mod tests { use super::*; @@ -324,37 +277,36 @@ mod tests { use crate::engine::general_purpose::STANDARD; #[test] - fn decode_chunk_precise_writes_only_6_bytes() { + fn decode_chunk_8_writes_only_6_bytes() { let input = b"Zm9vYmFy"; // "foobar" let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; - decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); } #[test] - fn decode_chunk_writes_8_bytes() { - let input = b"Zm9vYmFy"; // "foobar" - let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + fn decode_chunk_4_writes_only_3_bytes() { + let input = b"Zm9v"; // "foobar" + let mut output = [0_u8, 1, 2, 3]; - decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); - assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); + decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', 3], &output); } #[test] fn estimate_short_lengths() { - for (range, (num_chunks, decoded_len_estimate)) in [ - (0..=0, (0, 0)), - (1..=4, (1, 3)), - (5..=8, (1, 6)), - (9..=12, (2, 9)), - (13..=16, (2, 12)), - (17..=20, (3, 15)), + for (range, decoded_len_estimate) in [ + (0..=0, 0), + (1..=4, 3), + (5..=8, 6), + (9..=12, 9), + (13..=16, 12), + (17..=20, 15), ] { for encoded_len in range { let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!(num_chunks, estimate.num_chunks); - assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate); + assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate()); } } } @@ -369,15 +321,7 @@ mod tests { let len_128 = encoded_len as u128; let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!( - ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128)) - as usize, - estimate.num_chunks - ); - assert_eq!( - ((len_128 + 3) / 4 * 3) as usize, - estimate.decoded_len_estimate - ); + assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_len as u128); }) } } diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index 9fbb0d5..3d52ae5 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -3,7 +3,7 @@ use crate::{ DecodeError, PAD_BYTE, }; -/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided +/// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided /// parameters. /// /// Returns the decode metadata representing the total number of bytes decoded, including the ones @@ -17,16 +17,18 @@ pub(crate) fn decode_suffix( decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result { + debug_assert!((input.len() - input_index) <= 4); + // Decode any leftovers that might not be a complete input chunk of 8 bytes. // Use a u64 as a stack-resident 8 byte buffer. let mut morsels_in_leftover = 0; - let mut padding_bytes = 0; - let mut first_padding_index: usize = 0; + let mut padding_bytes_count = 0; + // offset from input_index + let mut first_padding_offset: usize = 0; let mut last_symbol = 0_u8; - let start_of_leftovers = input_index; - let mut morsels = [0_u8; 8]; + let mut morsels = [0_u8; 4]; - for (i, &b) in input[start_of_leftovers..].iter().enumerate() { + for (leftover_index, &b) in input[input_index..].iter().enumerate() { // '=' padding if b == PAD_BYTE { // There can be bad padding bytes in a few ways: @@ -41,30 +43,30 @@ pub(crate) fn decode_suffix( // Per config, non-canonical but still functional non- or partially-padded base64 // may be treated as an error condition. - if i % 4 < 2 { + if leftover_index < 2 { // Check for case #2. - let bad_padding_index = start_of_leftovers - + if padding_bytes > 0 { + let bad_padding_index = input_index + + if padding_bytes_count > 0 { // If we've already seen padding, report the first padding index. // This is to be consistent with the normal decode logic: it will report an // error on the first padding character (since it doesn't expect to see // anything but actual encoded data). // This could only happen if the padding started in the previous quad since - // otherwise this case would have been hit at i % 4 == 0 if it was the same + // otherwise this case would have been hit at i == 4 if it was the same // quad. - first_padding_index + first_padding_offset } else { // haven't seen padding before, just use where we are now - i + leftover_index }; return Err(DecodeError::InvalidByte(bad_padding_index, b)); } - if padding_bytes == 0 { - first_padding_index = i; + if padding_bytes_count == 0 { + first_padding_offset = leftover_index; } - padding_bytes += 1; + padding_bytes_count += 1; continue; } @@ -72,9 +74,9 @@ pub(crate) fn decode_suffix( // To make '=' handling consistent with the main loop, don't allow // non-suffix '=' in trailing chunk either. Report error as first // erroneous padding. - if padding_bytes > 0 { + if padding_bytes_count > 0 { return Err(DecodeError::InvalidByte( - start_of_leftovers + first_padding_index, + input_index + first_padding_offset, PAD_BYTE, )); } @@ -85,22 +87,31 @@ pub(crate) fn decode_suffix( // Pack the leftovers from left to right. let morsel = decode_table[b as usize]; if morsel == INVALID_VALUE { - return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); + return Err(DecodeError::InvalidByte(input_index + leftover_index, b)); } morsels[morsels_in_leftover] = morsel; morsels_in_leftover += 1; } + // If there was 1 trailing byte, and it was valid, and we got to this point without hitting + // an invalid byte, now we can report invalid length + if !input.is_empty() && morsels_in_leftover < 2 { + return Err(DecodeError::InvalidLength( + input_index + morsels_in_leftover, + )); + } + match padding_mode { DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } DecodePaddingMode::RequireCanonical => { - if (padding_bytes + morsels_in_leftover) % 4 != 0 { + // allow empty input + if (padding_bytes_count + morsels_in_leftover) % 4 != 0 { return Err(DecodeError::InvalidPadding); } } DecodePaddingMode::RequireNone => { - if padding_bytes > 0 { + if padding_bytes_count > 0 { // check at the end to make sure we let the cases of padding that should be InvalidByte // get hit return Err(DecodeError::InvalidPadding); @@ -120,27 +131,21 @@ pub(crate) fn decode_suffix( // useless since there are no more symbols to provide the necessary 4 additional bits // to finish the second original byte. - // TODO how do we know this? - debug_assert!(morsels_in_leftover != 1 && morsels_in_leftover != 5); let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; // Put the up to 6 complete bytes as the high bytes. // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split. - let mut leftover_num = ((u64::from(morsels[0]) << 58) - | (u64::from(morsels[1]) << 52) - | (u64::from(morsels[2]) << 46) - | (u64::from(morsels[3]) << 40)) - | ((u64::from(morsels[4]) << 34) - | (u64::from(morsels[5]) << 28) - | (u64::from(morsels[6]) << 22) - | (u64::from(morsels[7]) << 16)); + let mut leftover_num = (u32::from(morsels[0]) << 26) + | (u32::from(morsels[1]) << 20) + | (u32::from(morsels[2]) << 14) + | (u32::from(morsels[3]) << 8); // if there are bits set outside the bits we care about, last symbol encodes trailing bits that // will not be included in the output - let mask = !0 >> (leftover_bytes_to_append * 8); + let mask = !0_u32 >> (leftover_bytes_to_append * 8); if !decode_allow_trailing_bits && (leftover_num & mask) != 0 { // last morsel is at `morsels_in_leftover` - 1 return Err(DecodeError::InvalidLastSymbol( - start_of_leftovers + morsels_in_leftover - 1, + input_index + morsels_in_leftover - 1, last_symbol, )); } @@ -148,16 +153,17 @@ pub(crate) fn decode_suffix( // Strangely, this approach benchmarks better than writing bytes one at a time, // or copy_from_slice into output. for _ in 0..leftover_bytes_to_append { - let hi_byte = (leftover_num >> 56) as u8; + let hi_byte = (leftover_num >> 24) as u8; leftover_num <<= 8; + // TODO use checked writes output[output_index] = hi_byte; output_index += 1; } Ok(DecodeMetadata::new( output_index, - if padding_bytes > 0 { - Some(input_index + first_padding_index) + if padding_bytes_count > 0 { + Some(input_index + first_padding_offset) } else { None }, diff --git a/src/engine/naive.rs b/src/engine/naive.rs index 6a50cbe..2546a6f 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -115,15 +115,12 @@ impl Engine for Naive { if estimate.rem == 1 { // trailing whitespace is so common that it's worth it to check the last byte to // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE - && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE - { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } + let last_byte = input[input.len() - 1]; + if last_byte != PAD_BYTE + && self.decode_table[usize::from(last_byte)] == general_purpose::INVALID_VALUE + { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); } - - return Err(DecodeError::InvalidLength); } let mut input_index = 0_usize; diff --git a/src/engine/tests.rs b/src/engine/tests.rs index b048005..b73f108 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -365,26 +365,49 @@ fn decode_detect_invalid_last_symbol(engine_wrapper: E) { } #[apply(all_engines)] -fn decode_detect_invalid_last_symbol_when_length_is_also_invalid( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // check across enough lengths that it would likely cover any implementation's various internal - // small/large input division +fn decode_detect_1_valid_symbol_in_last_quad_invalid_length(engine_wrapper: E) { for len in (0_usize..256).map(|len| len * 4 + 1) { - let engine = E::random_alphabet(&mut rng, &STANDARD); + for mode in all_pad_modes() { + let mut input = vec![b'A'; len]; - let mut input = vec![b'A'; len]; + let engine = E::standard_with_pad_mode(true, mode); - // with a valid last char, it's InvalidLength - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&input)); - // after mangling the last char, it's InvalidByte - input[len - 1] = b'"'; - assert_eq!( - Err(DecodeError::InvalidByte(len - 1, b'"')), - engine.decode(&input) - ); + assert_eq!(Err(DecodeError::InvalidLength(len)), engine.decode(&input)); + // if we add padding, then the first pad byte in the quad is invalid because it should + // be the second symbol + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(len, PAD_BYTE)), + engine.decode(&input) + ); + } + } + } +} + +#[apply(all_engines)] +fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte(engine_wrapper: E) { + for prefix_len in (0_usize..256).map(|len| len * 4) { + for mode in all_pad_modes() { + let mut input = vec![b'A'; prefix_len]; + input.push(b'*'); + + let engine = E::standard_with_pad_mode(true, mode); + + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + // adding padding doesn't matter + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + } + } } } @@ -471,8 +494,10 @@ fn decode_detect_invalid_last_symbol_every_possible_three_symbols(engine_wrapper: E) { /// Any amount of padding anywhere before the final non padding character = invalid byte at first /// pad byte. -/// From this, we know padding must extend to the end of the input. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_before_final_non_padding_char_error_invalid_byte( +/// From this and [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix_all_modes], +/// we know padding must extend contiguously to the end of the input. +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes< + E: EngineWrapper, +>( engine_wrapper: E, ) { - let mut rng = seeded_rng(); + // Different amounts of padding, w/ offset from end for the last non-padding char. + // Only canonical padding, so Canonical mode will work. + let suffixes = &[("AA==", 2), ("AAA=", 1), ("AAAA", 0)]; - // the different amounts of proper padding, w/ offset from end for the last non-padding char - let suffixes = [("/w==", 2), ("iYu=", 1), ("zzzz", 0)]; + for mode in pad_modes_allowing_padding() { + // We don't encode, so we don't care about encode padding. + let engine = E::standard_with_pad_mode(true, mode); - let prefix_quads_range = distributions::Uniform::from(0..=256); + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ); + } +} - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); +/// See [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes] +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix< + E: EngineWrapper, +>( + engine_wrapper: E, +) { + // Different amounts of padding, w/ offset from end for the last non-padding char, and + // non-canonical padding. + let suffixes = [ + ("AA==", 2), + ("AA=", 1), + ("AA", 0), + ("AAA=", 1), + ("AAA", 0), + ("AAAA", 0), + ]; - for _ in 0..100_000 { - for (suffix, offset) in suffixes.iter() { - let mut s = "ABCD".repeat(prefix_quads_range.sample(&mut rng)); - s.push_str(suffix); - let mut encoded = s.into_bytes(); + // We don't encode, so we don't care about encode padding. + // Decoding is indifferent so that we don't get caught by missing padding on the last quad + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); + + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ) +} + +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine: impl Engine, + suffixes: &[(&str, usize)], +) { + let mut rng = seeded_rng(); - // calculate a range to write padding into that leaves at least one non padding char - let last_non_padding_offset = encoded.len() - 1 - offset; + let prefix_quads_range = distributions::Uniform::from(0..=256); - // don't include last non padding char as it must stay not padding - let padding_end = rng.gen_range(0..last_non_padding_offset); + for _ in 0..100_000 { + for (suffix, suffix_offset) in suffixes.iter() { + let mut s = "AAAA".repeat(prefix_quads_range.sample(&mut rng)); + s.push_str(suffix); + let mut encoded = s.into_bytes(); - // don't use more than 100 bytes of padding, but also use shorter lengths when - // padding_end is near the start of the encoded data to avoid biasing to padding - // the entire prefix on short lengths - let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); - let padding_start = padding_end.saturating_sub(padding_len); + // calculate a range to write padding into that leaves at least one non padding char + let last_non_padding_offset = encoded.len() - 1 - suffix_offset; - encoded[padding_start..=padding_end].fill(PAD_BYTE); + // don't include last non padding char as it must stay not padding + let padding_end = rng.gen_range(0..last_non_padding_offset); - assert_eq!( - Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), - engine.decode(&encoded), - ); - } + // don't use more than 100 bytes of padding, but also use shorter lengths when + // padding_end is near the start of the encoded data to avoid biasing to padding + // the entire prefix on short lengths + let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); + let padding_start = padding_end.saturating_sub(padding_len); + + encoded[padding_start..=padding_end].fill(PAD_BYTE); + + // should still have non-padding before any final padding + assert_ne!(PAD_BYTE, encoded[last_non_padding_offset]); + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + "len: {}, input: {}", + encoded.len(), + String::from_utf8(encoded).unwrap() + ); } } } -/// Any amount of padding before final chunk that crosses over into final chunk with 2-4 bytes = +/// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes = /// invalid byte at first pad byte. -/// From this and [decode_padding_starts_before_final_chunk_error_invalid_length] we know the -/// padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_byte( +/// From this we know the padding must start in the final chunk. +#[apply(all_engines)] +fn decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad( engine_wrapper: E, ) { let mut rng = seeded_rng(); // must have at least one prefix quad let prefix_quads_range = distributions::Uniform::from(1..256); - // excluding 1 since we don't care about invalid length in this test - let suffix_pad_len_range = distributions::Uniform::from(2..=4); - for mode in all_pad_modes() { + let suffix_pad_len_range = distributions::Uniform::from(1..=4); + // don't use no-padding mode, as the reader decode might decode a block that ends with + // valid padding, which should then be referenced when encountering the later invalid byte + for mode in pad_modes_allowing_padding() { // we don't encode so we don't care about encode padding let engine = E::standard_with_pad_mode(true, mode); for _ in 0..100_000 { let suffix_len = suffix_pad_len_range.sample(&mut rng); - let mut encoded = "ABCD" + // all 0 bits so we don't hit InvalidLastSymbol with the reader decoder + let mut encoded = "AAAA" .repeat(prefix_quads_range.sample(&mut rng)) .into_bytes(); encoded.resize(encoded.len() + suffix_len, PAD_BYTE); @@ -705,40 +774,6 @@ fn decode_padding_starts_before_final_chunk_error_invalid_byte } } -/// Any amount of padding before final chunk that crosses over into final chunk with 1 byte = -/// invalid length. -/// From this we know the padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_length( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // must have at least one prefix quad - let prefix_quads_range = distributions::Uniform::from(1..256); - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - for _ in 0..100_000 { - let mut encoded = "ABCD" - .repeat(prefix_quads_range.sample(&mut rng)) - .into_bytes(); - encoded.resize(encoded.len() + 1, PAD_BYTE); - - // amount of padding must be long enough to extend back from suffix into previous - // quads - let padding_len = rng.gen_range(1 + 1..encoded.len()); - // no non-padding after padding in this test, so padding goes to the end - let padding_start = encoded.len() - padding_len; - encoded[padding_start..].fill(PAD_BYTE); - - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } - } -} - /// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding /// is not valid data (consistent with error for pad bytes in earlier chunks). /// From this we know there must be 2-3 bytes of data before padding @@ -756,29 +791,22 @@ fn decode_too_little_data_before_padding_error_invalid_byte(en let suffix_data_len = suffix_data_len_range.sample(&mut rng); let prefix_quad_len = prefix_quads_range.sample(&mut rng); - // ensure there is a suffix quad - let min_padding = usize::from(suffix_data_len == 0); - // for all possible padding lengths - for padding_len in min_padding..=(4 - suffix_data_len) { + for padding_len in 1..=(4 - suffix_data_len) { let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes(); encoded.resize(encoded.len() + suffix_data_len, b'A'); encoded.resize(encoded.len() + padding_len, PAD_BYTE); - if suffix_data_len + padding_len == 1 { - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } else { - assert_eq!( - Err(DecodeError::InvalidByte( - prefix_quad_len * 4 + suffix_data_len, - PAD_BYTE, - )), - engine.decode(&encoded), - "suffix data len {} pad len {}", - suffix_data_len, - padding_len - ); - } + assert_eq!( + Err(DecodeError::InvalidByte( + prefix_quad_len * 4 + suffix_data_len, + PAD_BYTE, + )), + engine.decode(&encoded), + "suffix data len {} pad len {}", + suffix_data_len, + padding_len + ); } } } @@ -918,258 +946,64 @@ fn decode_pad_mode_indifferent_padding_accepts_anything(engine ); } -//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3 -// DecoderReader pseudo-engine finds the first padding, but doesn't report it as an error, -// because in the next decode it finds more padding, which is reported as InvalidByte, just -// with an offset at its position in the second decode, rather than being linked to the start -// of the padding that was first seen in the previous decode. -#[apply(all_engines_except_decoder_reader)] -fn decode_pad_byte_in_penultimate_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave room for at least one pad byte in penultimate quad - for num_valid_bytes_penultimate_quad in 0..4 { - // can't have 1 or it would be invalid length - for num_pad_bytes_in_final_quad in 2..=4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // varying amounts of padding in the penultimate quad - for _ in 0..num_valid_bytes_penultimate_quad { - s.push('A'); - } - // finish penultimate quad with padding - for _ in num_valid_bytes_penultimate_quad..4 { - s.push('='); - } - // and more padding in the final quad - for _ in 0..num_pad_bytes_in_final_quad { - s.push('='); - } - - // padding should be an invalid byte before the final quad. - // Could argue that the *next* padding byte (in the next quad) is technically the first - // erroneous one, but reporting that accurately is more complex and probably nobody cares - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + num_valid_bytes_penultimate_quad, - b'=', - ), - engine.decode(&s).unwrap_err(), - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_bytes_after_padding_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave at least one byte in the quad for padding - for bytes_after_padding in 1..4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // every invalid padding position with a 3-byte final quad: 1 to 3 bytes after padding - for _ in 0..(3 - bytes_after_padding) { - s.push('A'); - } - s.push('='); - for _ in 0..bytes_after_padding { - s.push('A'); - } - - // First (and only) padding byte is invalid. - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + (3 - bytes_after_padding), - b'=' - ), - engine.decode(&s).unwrap_err() - ); - } - } - } -} - -#[apply(all_engines)] -fn decode_absurd_pad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("==Y=Wx===pY=2U====="); - - // first padding byte - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_too_much_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // add enough padding to ensure that we'll hit all decode stages at the different lengths - for pad_bytes in 1..=64 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - - if pad_bytes % 4 == 1 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_followed_by_non_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - for pad_bytes in 0..=32 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - s.push('E'); - - if pad_bytes % 4 == 0 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_one_char_in_final_quad_with_padding_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("E="); - - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - // more padding doesn't change the error - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -#[apply(all_engines)] -fn decode_too_few_symbols_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // <2 is invalid - for final_quad_symbols in 0..2 { - for padding_symbols in 0..=(4 - final_quad_symbols) { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - for _ in 0..final_quad_symbols { - s.push('A'); - } - for _ in 0..padding_symbols { - s.push('='); - } - - match final_quad_symbols + padding_symbols { - 0 => continue, - 1 => { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } - _ => { - // error reported at first padding byte - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + final_quad_symbols, - b'=', - ), - engine.decode(&s).unwrap_err() - ); - } - } - } - } - } - } -} - +/// 1 trailing byte that's not padding is detected as invalid byte even though there's padding +/// in the middle of the input. This is essentially mandating the eager check for 1 trailing byte +/// to catch the \n suffix case. // DecoderReader pseudo-engine can't handle DecodePaddingMode::RequireNone since it will decode // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_bytes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_all_pad_modes_invalid_byte(engine_wrapper: E) { for mode in all_pad_modes() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } #[apply(all_engines)] -fn decode_invalid_trailing_bytes_all_modes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_invalid_byte(engine_wrapper: E) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } +fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { + for last_byte in [b'*', b'\n'] { + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("Cg=="); + let mut input = s.into_bytes(); + input.push(last_byte); + + // The case of trailing newlines is common enough to warrant a test for a good error + // message. + assert_eq!( + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + 4, + last_byte + )), + engine.decode(&input), + "mode: {:?}, input: {}", + mode, + String::from_utf8(input).unwrap() + ); + } + } +} +/// When there's 1 trailing byte, but it's padding, it's only InvalidByte if there isn't padding +/// earlier. #[apply(all_engines)] -fn decode_invalid_trailing_padding_as_invalid_length(engine_wrapper: E) { +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_pad_byte( + engine_wrapper: E, +) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } @@ -1177,48 +1011,36 @@ fn decode_invalid_trailing_padding_as_invalid_length(engine_wr // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_padding_as_invalid_length_all_modes( +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_byte_all_modes( engine_wrapper: E, ) { for mode in all_pad_modes() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } - -#[apply(all_engines)] -fn decode_wrong_length_error(engine_wrapper: E) { - let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); - +fn do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + engine: impl Engine, + mode: DecodePaddingMode, +) { for num_prefix_quads in 0..256 { - // at least one token, otherwise it wouldn't be a final quad - for num_tokens_final_quad in 1..=4 { - for num_padding in 0..=(4 - num_tokens_final_quad) { - let mut s: String = "IIII".repeat(num_prefix_quads); - for _ in 0..num_tokens_final_quad { - s.push('g'); - } - for _ in 0..num_padding { - s.push('='); - } + for (suffix, pad_offset) in [("AA===", 2), ("AAA==", 3), ("AAAA=", 4)] { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str(suffix); - let res = engine.decode(&s); - if num_tokens_final_quad >= 2 { - assert!(res.is_ok()); - } else if num_tokens_final_quad == 1 && num_padding > 0 { - // = is invalid if it's too early - assert_eq!( - Err(DecodeError::InvalidByte( - num_prefix_quads * 4 + num_tokens_final_quad, - 61 - )), - res - ); - } else if num_padding > 2 { - assert_eq!(Err(DecodeError::InvalidPadding), res); - } else { - assert_eq!(Err(DecodeError::InvalidLength), res); - } - } + assert_eq!( + // pad after `g`, not the last one + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + pad_offset, + PAD_BYTE + )), + engine.decode(&s), + "mode: {:?}, input: {}", + mode, + s + ); } } } @@ -1248,14 +1070,23 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); decode_buf.resize(input_len, 0); - // decode into the non-empty buf let decode_bytes_written = engine .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]) .unwrap(); - assert_eq!(orig_data.len(), decode_bytes_written); assert_eq!(orig_data, decode_buf); + + // TODO + // same for checked variant + // decode_buf.clear(); + // decode_buf.resize(input_len, 0); + // // decode into the non-empty buf + // let decode_bytes_written = engine + // .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..]) + // .unwrap(); + // assert_eq!(orig_data.len(), decode_bytes_written); + // assert_eq!(orig_data, decode_buf); } } @@ -1355,38 +1186,6 @@ fn estimate_via_u128_inflation(engine_wrapper: E) { }) } -fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==\n"); - - // The case of trailing newlines is common enough to warrant a test for a good error - // message. - assert_eq!( - Err(DecodeError::InvalidByte(num_prefix_quads * 4 + 4, b'\n')), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); - } -} - -fn do_invalid_trailing_padding_as_invalid_length(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==="); - - assert_eq!( - Err(DecodeError::InvalidLength), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); - } -} - /// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding. /// /// Vecs provided should be empty. diff --git a/src/read/decoder.rs b/src/read/decoder.rs index b656ae3..125eeab 100644 --- a/src/read/decoder.rs +++ b/src/read/decoder.rs @@ -35,37 +35,39 @@ pub struct DecoderReader<'e, E: Engine, R: io::Read> { /// Where b64 data is read from inner: R, - // Holds b64 data read from the delegate reader. + /// Holds b64 data read from the delegate reader. b64_buffer: [u8; BUF_SIZE], - // The start of the pending buffered data in b64_buffer. + /// The start of the pending buffered data in `b64_buffer`. b64_offset: usize, - // The amount of buffered b64 data. + /// The amount of buffered b64 data after `b64_offset` in `b64_len`. b64_len: usize, - // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a - // decoded chunk in to, we have to be able to hang on to a few decoded bytes. - // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to - // decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest - // into here, which seems like a lot of complexity for 1 extra byte of storage. - decoded_buffer: [u8; DECODED_CHUNK_SIZE], - // index of start of decoded data + /// Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a + /// decoded chunk in to, we have to be able to hang on to a few decoded bytes. + /// Technically we only need to hold 2 bytes, but then we'd need a separate temporary buffer to + /// decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest + /// into here, which seems like a lot of complexity for 1 extra byte of storage. + decoded_chunk_buffer: [u8; DECODED_CHUNK_SIZE], + /// Index of start of decoded data in `decoded_chunk_buffer` decoded_offset: usize, - // length of decoded data + /// Length of decoded data after `decoded_offset` in `decoded_chunk_buffer` decoded_len: usize, - // used to provide accurate offsets in errors - total_b64_decoded: usize, - // offset of previously seen padding, if any + /// Input length consumed so far. + /// Used to provide accurate offsets in errors + input_consumed_len: usize, + /// offset of previously seen padding, if any padding_offset: Option, } +// exclude b64_buffer as it's uselessly large impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("DecoderReader") .field("b64_offset", &self.b64_offset) .field("b64_len", &self.b64_len) - .field("decoded_buffer", &self.decoded_buffer) + .field("decoded_chunk_buffer", &self.decoded_chunk_buffer) .field("decoded_offset", &self.decoded_offset) .field("decoded_len", &self.decoded_len) - .field("total_b64_decoded", &self.total_b64_decoded) + .field("input_consumed_len", &self.input_consumed_len) .field("padding_offset", &self.padding_offset) .finish() } @@ -80,10 +82,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { b64_buffer: [0; BUF_SIZE], b64_offset: 0, b64_len: 0, - decoded_buffer: [0; DECODED_CHUNK_SIZE], + decoded_chunk_buffer: [0; DECODED_CHUNK_SIZE], decoded_offset: 0, decoded_len: 0, - total_b64_decoded: 0, + input_consumed_len: 0, padding_offset: None, } } @@ -100,7 +102,7 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { debug_assert!(copy_len <= self.decoded_len); buf[..copy_len].copy_from_slice( - &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len], + &self.decoded_chunk_buffer[self.decoded_offset..self.decoded_offset + copy_len], ); self.decoded_offset += copy_len; @@ -146,18 +148,22 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { ) .map_err(|e| match e { DecodeError::InvalidByte(offset, byte) => { - // This can be incorrect, but not in a way that probably matters to anyone: - // if there was padding handled in a previous decode, and we are now getting - // InvalidByte due to more padding, we should arguably report InvalidByte with - // PAD_BYTE at the original padding position (`self.padding_offset`), but we - // don't have a good way to tie those two cases together, so instead we - // just report the invalid byte as if the previous padding, and its possibly - // related downgrade to a now invalid byte, didn't happen. - DecodeError::InvalidByte(self.total_b64_decoded + offset, byte) + match (byte, self.padding_offset) { + // if there was padding in a previous block of decoding that happened to + // be correct, and we now find more padding that happens to be incorrect, + // to be consistent with non-reader decodes, record the error at the first + // padding + (PAD_BYTE, Some(first_pad_offset)) => { + DecodeError::InvalidByte(first_pad_offset, PAD_BYTE) + } + _ => DecodeError::InvalidByte(self.input_consumed_len + offset, byte), + } + } + DecodeError::InvalidLength(len) => { + DecodeError::InvalidLength(self.input_consumed_len + len) } - DecodeError::InvalidLength => DecodeError::InvalidLength, DecodeError::InvalidLastSymbol(offset, byte) => { - DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte) + DecodeError::InvalidLastSymbol(self.input_consumed_len + offset, byte) } DecodeError::InvalidPadding => DecodeError::InvalidPadding, }) @@ -176,8 +182,8 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { self.padding_offset = self.padding_offset.or(decode_metadata .padding_offset - .map(|offset| self.total_b64_decoded + offset)); - self.total_b64_decoded += b64_len_to_decode; + .map(|offset| self.input_consumed_len + offset)); + self.input_consumed_len += b64_len_to_decode; self.b64_offset += b64_len_to_decode; self.b64_len -= b64_len_to_decode; @@ -283,7 +289,7 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> { let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE); let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?; - self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); + self.decoded_chunk_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); self.decoded_offset = 0; self.decoded_len = decoded; From 9e9c7abe65fed78c35a1e94e11446d66ff118c25 Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Fri, 1 Mar 2024 16:52:32 -0700 Subject: [PATCH 4/6] Engine::internal_decode now returns DecodeSliceError Implementations must now precisely, not conservatively, return an error when the output length is too small. --- benches/benchmarks.rs | 3 +- src/decode.rs | 4 +- src/engine/general_purpose/decode.rs | 75 +++++++++++------ src/engine/general_purpose/decode_suffix.rs | 59 ++++++------- src/engine/general_purpose/mod.rs | 6 +- src/engine/mod.rs | 41 ++++++---- src/engine/naive.rs | 70 ++++++---------- src/engine/tests.rs | 91 +++++++++++++++++---- src/read/decoder.rs | 47 +++++++---- 9 files changed, 237 insertions(+), 159 deletions(-) diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs index 802c8cc..8f04185 100644 --- a/benches/benchmarks.rs +++ b/benches/benchmarks.rs @@ -102,9 +102,8 @@ fn do_encode_bench_slice(b: &mut Bencher, &size: &usize) { fn do_encode_bench_stream(b: &mut Bencher, &size: &usize) { let mut v: Vec = Vec::with_capacity(size); fill(&mut v); - let mut buf = Vec::new(); + let mut buf = Vec::with_capacity(size * 2); - buf.reserve(size * 2); b.iter(|| { buf.clear(); let mut stream_enc = write::EncoderWriter::new(&mut buf, &STANDARD); diff --git a/src/decode.rs b/src/decode.rs index 0f66c74..d042b09 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -52,9 +52,7 @@ impl error::Error for DecodeError {} pub enum DecodeSliceError { /// A [DecodeError] occurred DecodeError(DecodeError), - /// The provided slice _may_ be too small. - /// - /// The check is conservative (assumes the last triplet of output bytes will all be needed). + /// The provided slice is too small. OutputSliceTooSmall, } diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index 31c289e..98ce043 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -1,12 +1,13 @@ use crate::{ engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode}, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, PAD_BYTE, }; #[doc(hidden)] pub struct GeneralPurposeEstimate { + /// input len % 4 rem: usize, - conservative_len: usize, + conservative_decoded_len: usize, } impl GeneralPurposeEstimate { @@ -14,14 +15,14 @@ impl GeneralPurposeEstimate { let rem = encoded_len % 4; Self { rem, - conservative_len: (encoded_len / 4 + (rem > 0) as usize) * 3, + conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3, } } } impl DecodeEstimate for GeneralPurposeEstimate { fn decoded_len_estimate(&self) -> usize { - self.conservative_len + self.conservative_decoded_len } } @@ -38,25 +39,9 @@ pub(crate) fn decode_helper( decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, -) -> Result { - // detect a trailing invalid byte, like a newline, as a user convenience - if estimate.rem == 1 { - let last_byte = input[input.len() - 1]; - // exclude pad bytes; might be part of padding that extends from earlier in the input - if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { - return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); - } - } - - // skip last quad, even if it's complete, as it may have padding - let input_complete_nonterminal_quads_len = input - .len() - .saturating_sub(estimate.rem) - // if rem was 0, subtract 4 to avoid padding - .saturating_sub((estimate.rem == 0) as usize * 4); - debug_assert!( - input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) - ); +) -> Result { + let input_complete_nonterminal_quads_len = + complete_quads_len(input, estimate.rem, output.len(), decode_table)?; const UNROLLED_INPUT_CHUNK_SIZE: usize = 32; const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3; @@ -135,6 +120,48 @@ pub(crate) fn decode_helper( ) } +/// Returns the length of complete quads, except for the last one, even if it is complete. +/// +/// Returns an error if the output len is not big enough for decoding those complete quads, or if +/// the input % 4 == 1, and that last byte is an invalid value other than a pad byte. +/// +/// - `input` is the base64 input +/// - `input_len_rem` is input len % 4 +/// - `output_len` is the length of the output slice +pub(crate) fn complete_quads_len( + input: &[u8], + input_len_rem: usize, + output_len: usize, + decode_table: &[u8; 256], +) -> Result { + debug_assert!(input.len() % 4 == input_len_rem); + + // detect a trailing invalid byte, like a newline, as a user convenience + if input_len_rem == 1 { + let last_byte = input[input.len() - 1]; + // exclude pad bytes; might be part of padding that extends from earlier in the input + if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into()); + } + }; + + // skip last quad, even if it's complete, as it may have padding + let input_complete_nonterminal_quads_len = input + .len() + .saturating_sub(input_len_rem) + // if rem was 0, subtract 4 to avoid padding + .saturating_sub((input_len_rem == 0) as usize * 4); + debug_assert!( + input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) + ); + + // check that everything except the last quad handled by decode_suffix will fit + if output_len < input_complete_nonterminal_quads_len / 4 * 3 { + return Err(DecodeSliceError::OutputSliceTooSmall); + }; + Ok(input_complete_nonterminal_quads_len) +} + /// Decode 8 bytes of input into 6 bytes of output. /// /// `input` is the 8 bytes to decode. @@ -321,7 +348,7 @@ mod tests { let len_128 = encoded_len as u128; let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_len as u128); + assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_decoded_len as u128); }) } } diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index 3d52ae5..02aaf51 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -1,6 +1,6 @@ use crate::{ engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode}, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, PAD_BYTE, }; /// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided @@ -16,11 +16,11 @@ pub(crate) fn decode_suffix( decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, -) -> Result { +) -> Result { debug_assert!((input.len() - input_index) <= 4); - // Decode any leftovers that might not be a complete input chunk of 8 bytes. - // Use a u64 as a stack-resident 8 byte buffer. + // Decode any leftovers that might not be a complete input chunk of 4 bytes. + // Use a u32 as a stack-resident 4 byte buffer. let mut morsels_in_leftover = 0; let mut padding_bytes_count = 0; // offset from input_index @@ -44,22 +44,14 @@ pub(crate) fn decode_suffix( // may be treated as an error condition. if leftover_index < 2 { - // Check for case #2. - let bad_padding_index = input_index - + if padding_bytes_count > 0 { - // If we've already seen padding, report the first padding index. - // This is to be consistent with the normal decode logic: it will report an - // error on the first padding character (since it doesn't expect to see - // anything but actual encoded data). - // This could only happen if the padding started in the previous quad since - // otherwise this case would have been hit at i == 4 if it was the same - // quad. - first_padding_offset - } else { - // haven't seen padding before, just use where we are now - leftover_index - }; - return Err(DecodeError::InvalidByte(bad_padding_index, b)); + // Check for error #2. + // Either the previous byte was padding, in which case we would have already hit + // this case, or it wasn't, in which case this is the first such error. + debug_assert!( + leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0) + ); + let bad_padding_index = input_index + leftover_index; + return Err(DecodeError::InvalidByte(bad_padding_index, b).into()); } if padding_bytes_count == 0 { @@ -75,10 +67,9 @@ pub(crate) fn decode_suffix( // non-suffix '=' in trailing chunk either. Report error as first // erroneous padding. if padding_bytes_count > 0 { - return Err(DecodeError::InvalidByte( - input_index + first_padding_offset, - PAD_BYTE, - )); + return Err( + DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(), + ); } last_symbol = b; @@ -87,7 +78,7 @@ pub(crate) fn decode_suffix( // Pack the leftovers from left to right. let morsel = decode_table[b as usize]; if morsel == INVALID_VALUE { - return Err(DecodeError::InvalidByte(input_index + leftover_index, b)); + return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into()); } morsels[morsels_in_leftover] = morsel; @@ -97,9 +88,7 @@ pub(crate) fn decode_suffix( // If there was 1 trailing byte, and it was valid, and we got to this point without hitting // an invalid byte, now we can report invalid length if !input.is_empty() && morsels_in_leftover < 2 { - return Err(DecodeError::InvalidLength( - input_index + morsels_in_leftover, - )); + return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into()); } match padding_mode { @@ -107,14 +96,14 @@ pub(crate) fn decode_suffix( DecodePaddingMode::RequireCanonical => { // allow empty input if (padding_bytes_count + morsels_in_leftover) % 4 != 0 { - return Err(DecodeError::InvalidPadding); + return Err(DecodeError::InvalidPadding.into()); } } DecodePaddingMode::RequireNone => { if padding_bytes_count > 0 { // check at the end to make sure we let the cases of padding that should be InvalidByte // get hit - return Err(DecodeError::InvalidPadding); + return Err(DecodeError::InvalidPadding.into()); } } } @@ -127,7 +116,7 @@ pub(crate) fn decode_suffix( // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a // mask based on how many bits are used for just the canonical encoding, and optionally // error if any other bits are set. In the example of one encoded byte -> 2 symbols, - // 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and + // 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and // useless since there are no more symbols to provide the necessary 4 additional bits // to finish the second original byte. @@ -147,7 +136,8 @@ pub(crate) fn decode_suffix( return Err(DecodeError::InvalidLastSymbol( input_index + morsels_in_leftover - 1, last_symbol, - )); + ) + .into()); } // Strangely, this approach benchmarks better than writing bytes one at a time, @@ -155,8 +145,9 @@ pub(crate) fn decode_suffix( for _ in 0..leftover_bytes_to_append { let hi_byte = (leftover_num >> 24) as u8; leftover_num <<= 8; - // TODO use checked writes - output[output_index] = hi_byte; + *output + .get_mut(output_index) + .ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte; output_index += 1; } diff --git a/src/engine/general_purpose/mod.rs b/src/engine/general_purpose/mod.rs index e0227f3..6fe9580 100644 --- a/src/engine/general_purpose/mod.rs +++ b/src/engine/general_purpose/mod.rs @@ -3,11 +3,11 @@ use crate::{ alphabet, alphabet::Alphabet, engine::{Config, DecodeMetadata, DecodePaddingMode}, - DecodeError, + DecodeSliceError, }; use core::convert::TryInto; -mod decode; +pub(crate) mod decode; pub(crate) mod decode_suffix; pub use decode::GeneralPurposeEstimate; @@ -173,7 +173,7 @@ impl super::Engine for GeneralPurpose { input: &[u8], output: &mut [u8], estimate: Self::DecodeEstimate, - ) -> Result { + ) -> Result { decode::decode_helper( input, estimate, diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 16c05d7..77dcd14 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -83,17 +83,13 @@ pub trait Engine: Send + Sync { /// /// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as /// errors unless the engine is configured otherwise. - /// - /// # Panics - /// - /// Panics if `output` is too small. #[doc(hidden)] fn internal_decode( &self, input: &[u8], output: &mut [u8], decode_estimate: Self::DecodeEstimate, - ) -> Result; + ) -> Result; /// Returns the config for this engine. fn config(&self) -> &Self::Config; @@ -253,7 +249,13 @@ pub trait Engine: Send + Sync { let mut buffer = vec![0; estimate.decoded_len_estimate()]; let bytes_written = engine - .internal_decode(input_bytes, &mut buffer, estimate)? + .internal_decode(input_bytes, &mut buffer, estimate) + .map_err(|e| match e { + DecodeSliceError::DecodeError(e) => e, + DecodeSliceError::OutputSliceTooSmall => { + unreachable!("Vec is sized conservatively") + } + })? .decoded_len; buffer.truncate(bytes_written); @@ -318,7 +320,13 @@ pub trait Engine: Send + Sync { let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..]; let bytes_written = engine - .internal_decode(input_bytes, buffer_slice, estimate)? + .internal_decode(input_bytes, buffer_slice, estimate) + .map_err(|e| match e { + DecodeSliceError::DecodeError(e) => e, + DecodeSliceError::OutputSliceTooSmall => { + unreachable!("Vec is sized conservatively") + } + })? .decoded_len; buffer.truncate(starting_output_len + bytes_written); @@ -354,15 +362,12 @@ pub trait Engine: Send + Sync { where E: Engine + ?Sized, { - let estimate = engine.internal_decoded_len_estimate(input_bytes.len()); - - if output.len() < estimate.decoded_len_estimate() { - return Err(DecodeSliceError::OutputSliceTooSmall); - } - engine - .internal_decode(input_bytes, output, estimate) - .map_err(|e| e.into()) + .internal_decode( + input_bytes, + output, + engine.internal_decoded_len_estimate(input_bytes.len()), + ) .map(|dm| dm.decoded_len) } @@ -400,6 +405,12 @@ pub trait Engine: Send + Sync { engine.internal_decoded_len_estimate(input_bytes.len()), ) .map(|dm| dm.decoded_len) + .map_err(|e| match e { + DecodeSliceError::DecodeError(e) => e, + DecodeSliceError::OutputSliceTooSmall => { + panic!("Output slice is too small") + } + }) } inner(self, input.as_ref(), output) diff --git a/src/engine/naive.rs b/src/engine/naive.rs index 2546a6f..15c07cc 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -4,7 +4,7 @@ use crate::{ general_purpose::{self, decode_table, encode_table}, Config, DecodeEstimate, DecodeMetadata, DecodePaddingMode, Engine, }, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, }; use std::ops::{BitAnd, BitOr, Shl, Shr}; @@ -111,60 +111,36 @@ impl Engine for Naive { input: &[u8], output: &mut [u8], estimate: Self::DecodeEstimate, - ) -> Result { - if estimate.rem == 1 { - // trailing whitespace is so common that it's worth it to check the last byte to - // possibly return a better error message - let last_byte = input[input.len() - 1]; - if last_byte != PAD_BYTE - && self.decode_table[usize::from(last_byte)] == general_purpose::INVALID_VALUE - { - return Err(DecodeError::InvalidByte(input.len() - 1, last_byte)); - } - } + ) -> Result { + let complete_nonterminal_quads_len = + general_purpose::decode::complete_quads_len(input, estimate.rem, output.len(), &self.decode_table)?; - let mut input_index = 0_usize; - let mut output_index = 0_usize; const BOTTOM_BYTE: u32 = 0xFF; - // can only use the main loop on non-trailing chunks - if input.len() > Self::DECODE_INPUT_CHUNK_SIZE { - // skip the last chunk, whether it's partial or full, since it might - // have padding, and start at the beginning of the chunk before that - let last_complete_chunk_start_index = estimate.complete_chunk_len - - if estimate.rem == 0 { - // Trailing chunk is also full chunk, so there must be at least 2 chunks, and - // this won't underflow - Self::DECODE_INPUT_CHUNK_SIZE * 2 - } else { - // Trailing chunk is partial, so it's already excluded in - // complete_chunk_len - Self::DECODE_INPUT_CHUNK_SIZE - }; - - while input_index <= last_complete_chunk_start_index { - let chunk = &input[input_index..input_index + Self::DECODE_INPUT_CHUNK_SIZE]; - let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) - | self - .decode_byte_into_u32(input_index + 1, chunk[1])? - .shl(12) - | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) - | self.decode_byte_into_u32(input_index + 3, chunk[3])?; - - output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; - output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; - output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; - - input_index += Self::DECODE_INPUT_CHUNK_SIZE; - output_index += 3; - } + for (chunk_index, chunk) in input[..complete_nonterminal_quads_len] + .chunks_exact(4) + .enumerate() + { + let input_index = chunk_index * 4; + let output_index = chunk_index * 3; + + let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) + | self + .decode_byte_into_u32(input_index + 1, chunk[1])? + .shl(12) + | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) + | self.decode_byte_into_u32(input_index + 3, chunk[3])?; + + output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; } general_purpose::decode_suffix::decode_suffix( input, - input_index, + complete_nonterminal_quads_len, output, - output_index, + complete_nonterminal_quads_len / 4 * 3, &self.decode_table, self.config.decode_allow_trailing_bits, self.config.decode_padding_mode, diff --git a/src/engine/tests.rs b/src/engine/tests.rs index b73f108..72bbf4b 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -19,7 +19,7 @@ use crate::{ }, read::DecoderReader, tests::{assert_encode_sanity, random_alphabet, random_config}, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, PAD_BYTE, }; // the case::foo syntax includes the "foo" in the generated test method names @@ -803,7 +803,8 @@ fn decode_too_little_data_before_padding_error_invalid_byte(en PAD_BYTE, )), engine.decode(&encoded), - "suffix data len {} pad len {}", + "input {} suffix data len {} pad len {}", + String::from_utf8(encoded).unwrap(), suffix_data_len, padding_len ); @@ -1077,16 +1078,15 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap assert_eq!(orig_data.len(), decode_bytes_written); assert_eq!(orig_data, decode_buf); - // TODO // same for checked variant - // decode_buf.clear(); - // decode_buf.resize(input_len, 0); - // // decode into the non-empty buf - // let decode_bytes_written = engine - // .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..]) - // .unwrap(); - // assert_eq!(orig_data.len(), decode_bytes_written); - // assert_eq!(orig_data, decode_buf); + decode_buf.clear(); + decode_buf.resize(input_len, 0); + // decode into the non-empty buf + let decode_bytes_written = engine + .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..]) + .unwrap(); + assert_eq!(orig_data.len(), decode_bytes_written); + assert_eq!(orig_data, decode_buf); } } @@ -1118,7 +1118,10 @@ fn inner_decode_reports_padding_position(engine_wrapper: E) { if pad_position % 4 < 2 { // impossible padding assert_eq!( - Err(DecodeError::InvalidByte(pad_position, PAD_BYTE)), + Err(DecodeSliceError::DecodeError(DecodeError::InvalidByte( + pad_position, + PAD_BYTE + ))), decode_res ); } else { @@ -1186,6 +1189,63 @@ fn estimate_via_u128_inflation(engine_wrapper: E) { }) } +#[apply(all_engines)] +fn decode_slice_checked_fails_gracefully_at_all_output_lengths( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + for original_len in 0..1000 { + let mut original = vec![0; original_len]; + rng.fill(&mut original[..]); + + for mode in all_pad_modes() { + let engine = E::standard_with_pad_mode( + match mode { + DecodePaddingMode::Indifferent | DecodePaddingMode::RequireCanonical => true, + DecodePaddingMode::RequireNone => false, + }, + mode, + ); + + let encoded = engine.encode(&original); + let mut decode_buf = Vec::with_capacity(original_len); + for decode_buf_len in 0..original_len { + decode_buf.resize(decode_buf_len, 0); + assert_eq!( + DecodeSliceError::OutputSliceTooSmall, + engine + .decode_slice(&encoded, &mut decode_buf[..]) + .unwrap_err(), + "original len: {}, encoded len: {}, buf len: {}, mode: {:?}", + original_len, + encoded.len(), + decode_buf_len, + mode + ); + // internal method works the same + assert_eq!( + DecodeSliceError::OutputSliceTooSmall, + engine + .internal_decode( + encoded.as_bytes(), + &mut decode_buf[..], + engine.internal_decoded_len_estimate(encoded.len()) + ) + .unwrap_err() + ); + } + + decode_buf.resize(original_len, 0); + rng.fill(&mut decode_buf[..]); + assert_eq!( + original_len, + engine.decode_slice(&encoded, &mut decode_buf[..]).unwrap() + ); + assert_eq!(original, decode_buf); + } + } +} + /// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding. /// /// Vecs provided should be empty. @@ -1346,7 +1406,7 @@ impl EngineWrapper for NaiveWrapper { naive::Naive::new( &STANDARD, naive::NaiveConfig { - encode_padding: false, + encode_padding: encode_pad, decode_allow_trailing_bits: false, decode_padding_mode: decode_pad_mode, }, @@ -1415,7 +1475,7 @@ impl Engine for DecoderReaderEngine { input: &[u8], output: &mut [u8], decode_estimate: Self::DecodeEstimate, - ) -> Result { + ) -> Result { let mut reader = DecoderReader::new(input, &self.engine); let mut buf = vec![0; input.len()]; // to avoid effects like not detecting invalid length due to progressively growing @@ -1434,6 +1494,9 @@ impl Engine for DecoderReaderEngine { .and_then(|inner| inner.downcast::().ok()) .unwrap() })?; + if output.len() < buf.len() { + return Err(DecodeSliceError::OutputSliceTooSmall); + } output[..buf.len()].copy_from_slice(&buf); Ok(DecodeMetadata::new( buf.len(), diff --git a/src/read/decoder.rs b/src/read/decoder.rs index 125eeab..781f6f8 100644 --- a/src/read/decoder.rs +++ b/src/read/decoder.rs @@ -1,4 +1,4 @@ -use crate::{engine::Engine, DecodeError, PAD_BYTE}; +use crate::{engine::Engine, DecodeError, DecodeSliceError, PAD_BYTE}; use std::{cmp, fmt, io}; // This should be large, but it has to fit on the stack. @@ -133,6 +133,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { /// caller's responsibility to choose the number of b64 bytes to decode correctly. /// /// Returns a Result with the number of decoded bytes written to `buf`. + /// + /// # Panics + /// + /// panics if `buf` is too small fn decode_to_buf(&mut self, b64_len_to_decode: usize, buf: &mut [u8]) -> io::Result { debug_assert!(self.b64_len >= b64_len_to_decode); debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE); @@ -146,26 +150,35 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { buf, self.engine.internal_decoded_len_estimate(b64_len_to_decode), ) - .map_err(|e| match e { - DecodeError::InvalidByte(offset, byte) => { - match (byte, self.padding_offset) { - // if there was padding in a previous block of decoding that happened to - // be correct, and we now find more padding that happens to be incorrect, - // to be consistent with non-reader decodes, record the error at the first - // padding - (PAD_BYTE, Some(first_pad_offset)) => { - DecodeError::InvalidByte(first_pad_offset, PAD_BYTE) + .map_err(|dse| match dse { + DecodeSliceError::DecodeError(de) => { + match de { + DecodeError::InvalidByte(offset, byte) => { + match (byte, self.padding_offset) { + // if there was padding in a previous block of decoding that happened to + // be correct, and we now find more padding that happens to be incorrect, + // to be consistent with non-reader decodes, record the error at the first + // padding + (PAD_BYTE, Some(first_pad_offset)) => { + DecodeError::InvalidByte(first_pad_offset, PAD_BYTE) + } + _ => { + DecodeError::InvalidByte(self.input_consumed_len + offset, byte) + } + } + } + DecodeError::InvalidLength(len) => { + DecodeError::InvalidLength(self.input_consumed_len + len) } - _ => DecodeError::InvalidByte(self.input_consumed_len + offset, byte), + DecodeError::InvalidLastSymbol(offset, byte) => { + DecodeError::InvalidLastSymbol(self.input_consumed_len + offset, byte) + } + DecodeError::InvalidPadding => DecodeError::InvalidPadding, } } - DecodeError::InvalidLength(len) => { - DecodeError::InvalidLength(self.input_consumed_len + len) - } - DecodeError::InvalidLastSymbol(offset, byte) => { - DecodeError::InvalidLastSymbol(self.input_consumed_len + offset, byte) + DecodeSliceError::OutputSliceTooSmall => { + unreachable!("buf is sized correctly in calling code") } - DecodeError::InvalidPadding => DecodeError::InvalidPadding, }) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; From 2b91084a31ad11624acd81e06455ba0cbd21d4a8 Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Fri, 1 Mar 2024 18:07:50 -0700 Subject: [PATCH 5/6] Add some tests to boost coverage The logic isn't important, but it helps make actual coverage gaps more visible. --- src/decode.rs | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/decode.rs b/src/decode.rs index d042b09..6df8aba 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -340,3 +340,47 @@ mod tests { } } } + +#[allow(deprecated)] +#[cfg(test)] +mod coverage_gaming { + use super::*; + use std::error::Error; + + #[test] + fn decode_error() { + let _ = format!("{:?}", DecodeError::InvalidPadding.clone()); + let _ = format!( + "{} {} {} {}", + DecodeError::InvalidByte(0, 0), + DecodeError::InvalidLength(0), + DecodeError::InvalidLastSymbol(0, 0), + DecodeError::InvalidPadding, + ); + } + + #[test] + fn decode_slice_error() { + let _ = format!("{:?}", DecodeSliceError::OutputSliceTooSmall.clone()); + let _ = format!( + "{} {}", + DecodeSliceError::OutputSliceTooSmall, + DecodeSliceError::DecodeError(DecodeError::InvalidPadding) + ); + let _ = DecodeSliceError::OutputSliceTooSmall.source(); + let _ = DecodeSliceError::DecodeError(DecodeError::InvalidPadding).source(); + } + + #[test] + fn deprecated_fns() { + let _ = decode(""); + let _ = decode_engine("", &crate::prelude::BASE64_STANDARD); + let _ = decode_engine_vec("", &mut Vec::new(), &crate::prelude::BASE64_STANDARD); + let _ = decode_engine_slice("", &mut [], &crate::prelude::BASE64_STANDARD); + } + + #[test] + fn decoded_len_est() { + assert_eq!(3, decoded_len_estimate(4)); + } +} From efb6c006c75ddbe60c084c2e3e0e084cd18b0122 Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Fri, 1 Mar 2024 18:14:38 -0700 Subject: [PATCH 6/6] Release notes --- Cargo.toml | 2 +- RELEASE-NOTES.md | 6 ++++++ src/engine/general_purpose/decode.rs | 5 ++++- src/engine/naive.rs | 8 ++++++-- src/lib.rs | 2 +- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4db5d26..c2670d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "base64" -version = "0.21.7" +version = "0.22.0" authors = ["Alice Maz ", "Marshall Pierce "] description = "encodes and decodes base64 as bytes or utf8" repository = "https://github.com/marshallpierce/rust-base64" diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 0031215..46e281e 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,3 +1,9 @@ +# 0.22.0 + +- `DecodeSliceError::OutputSliceTooSmall` is now conservative rather than precise. That is, the error will only occur if the decoded output _cannot_ fit, meaning that `Engine::decode_slice` can now be used with exactly-sized output slices. As part of this, `Engine::internal_decode` now returns `DecodeSliceError` instead of `DecodeError`, but that is not expected to affect any external callers. +- `DecodeError::InvalidLength` now refers specifically to the _number of valid symbols_ being invalid (i.e. `len % 4 == 1`), rather than just the number of input bytes. This avoids confusing scenarios when based on interpretation you could make a case for either `InvalidLength` or `InvalidByte` being appropriate. +- Decoding is somewhat faster (5-10%) + # 0.21.7 - Support getting an alphabet's contents as a str via `Alphabet::as_str()` diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index 98ce043..b55d3fc 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -348,7 +348,10 @@ mod tests { let len_128 = encoded_len as u128; let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!((len_128 + 3) / 4 * 3, estimate.conservative_decoded_len as u128); + assert_eq!( + (len_128 + 3) / 4 * 3, + estimate.conservative_decoded_len as u128 + ); }) } } diff --git a/src/engine/naive.rs b/src/engine/naive.rs index 15c07cc..af509bf 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -112,8 +112,12 @@ impl Engine for Naive { output: &mut [u8], estimate: Self::DecodeEstimate, ) -> Result { - let complete_nonterminal_quads_len = - general_purpose::decode::complete_quads_len(input, estimate.rem, output.len(), &self.decode_table)?; + let complete_nonterminal_quads_len = general_purpose::decode::complete_quads_len( + input, + estimate.rem, + output.len(), + &self.decode_table, + )?; const BOTTOM_BYTE: u32 = 0xFF; diff --git a/src/lib.rs b/src/lib.rs index 6b5cccb..579a722 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -228,7 +228,7 @@ unused_extern_crates, unused_import_braces, unused_results, - variant_size_differences, + variant_size_differences )] #![forbid(unsafe_code)] // Allow globally until https://github.com/rust-lang/rust-clippy/issues/8768 is resolved.