diff --git a/Cargo.toml b/Cargo.toml index 4db5d26..c2670d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "base64" -version = "0.21.7" +version = "0.22.0" authors = ["Alice Maz ", "Marshall Pierce "] description = "encodes and decodes base64 as bytes or utf8" repository = "https://github.com/marshallpierce/rust-base64" diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 0031215..46e281e 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,3 +1,9 @@ +# 0.22.0 + +- `DecodeSliceError::OutputSliceTooSmall` is now conservative rather than precise. That is, the error will only occur if the decoded output _cannot_ fit, meaning that `Engine::decode_slice` can now be used with exactly-sized output slices. As part of this, `Engine::internal_decode` now returns `DecodeSliceError` instead of `DecodeError`, but that is not expected to affect any external callers. +- `DecodeError::InvalidLength` now refers specifically to the _number of valid symbols_ being invalid (i.e. `len % 4 == 1`), rather than just the number of input bytes. This avoids confusing scenarios when based on interpretation you could make a case for either `InvalidLength` or `InvalidByte` being appropriate. +- Decoding is somewhat faster (5-10%) + # 0.21.7 - Support getting an alphabet's contents as a str via `Alphabet::as_str()` diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs index 802c8cc..8f04185 100644 --- a/benches/benchmarks.rs +++ b/benches/benchmarks.rs @@ -102,9 +102,8 @@ fn do_encode_bench_slice(b: &mut Bencher, &size: &usize) { fn do_encode_bench_stream(b: &mut Bencher, &size: &usize) { let mut v: Vec = Vec::with_capacity(size); fill(&mut v); - let mut buf = Vec::new(); + let mut buf = Vec::with_capacity(size * 2); - buf.reserve(size * 2); b.iter(|| { buf.clear(); let mut stream_enc = write::EncoderWriter::new(&mut buf, &STANDARD); diff --git a/src/decode.rs b/src/decode.rs index 5230fd3..6df8aba 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -9,18 +9,20 @@ use std::error; #[derive(Clone, Debug, PartialEq, Eq)] pub enum DecodeError { /// An invalid byte was found in the input. The offset and offending byte are provided. - /// Padding characters (`=`) interspersed in the encoded form will be treated as invalid bytes. + /// + /// Padding characters (`=`) interspersed in the encoded form are invalid, as they may only + /// be present as the last 0-2 bytes of input. + /// + /// This error may also indicate that extraneous trailing input bytes are present, causing + /// otherwise valid padding to no longer be the last bytes of input. InvalidByte(usize, u8), - /// The length of the input is invalid. - /// A typical cause of this is stray trailing whitespace or other separator bytes. - /// In the case where excess trailing bytes have produced an invalid length *and* the last byte - /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte` - /// will be emitted instead of `InvalidLength` to make the issue easier to debug. - InvalidLength, + /// The length of the input, as measured in valid base64 symbols, is invalid. + /// There must be 2-4 symbols in the last input quad. + InvalidLength(usize), /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. /// This is indicative of corrupted or truncated Base64. - /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for - /// symbols that are in the alphabet but represent nonsensical encodings. + /// Unlike [DecodeError::InvalidByte], which reports symbols that aren't in the alphabet, + /// this error is for symbols that are in the alphabet but represent nonsensical encodings. InvalidLastSymbol(usize, u8), /// The nature of the padding was not as configured: absent or incorrect when it must be /// canonical, or present when it must be absent, etc. @@ -30,8 +32,10 @@ pub enum DecodeError { impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Self::InvalidByte(index, byte) => write!(f, "Invalid byte {}, offset {}.", byte, index), - Self::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."), + Self::InvalidByte(index, byte) => { + write!(f, "Invalid symbol {}, offset {}.", byte, index) + } + Self::InvalidLength(len) => write!(f, "Invalid input length: {}", len), Self::InvalidLastSymbol(index, byte) => { write!(f, "Invalid last symbol {}, offset {}.", byte, index) } @@ -48,9 +52,7 @@ impl error::Error for DecodeError {} pub enum DecodeSliceError { /// A [DecodeError] occurred DecodeError(DecodeError), - /// The provided slice _may_ be too small. - /// - /// The check is conservative (assumes the last triplet of output bytes will all be needed). + /// The provided slice is too small. OutputSliceTooSmall, } @@ -338,3 +340,47 @@ mod tests { } } } + +#[allow(deprecated)] +#[cfg(test)] +mod coverage_gaming { + use super::*; + use std::error::Error; + + #[test] + fn decode_error() { + let _ = format!("{:?}", DecodeError::InvalidPadding.clone()); + let _ = format!( + "{} {} {} {}", + DecodeError::InvalidByte(0, 0), + DecodeError::InvalidLength(0), + DecodeError::InvalidLastSymbol(0, 0), + DecodeError::InvalidPadding, + ); + } + + #[test] + fn decode_slice_error() { + let _ = format!("{:?}", DecodeSliceError::OutputSliceTooSmall.clone()); + let _ = format!( + "{} {}", + DecodeSliceError::OutputSliceTooSmall, + DecodeSliceError::DecodeError(DecodeError::InvalidPadding) + ); + let _ = DecodeSliceError::OutputSliceTooSmall.source(); + let _ = DecodeSliceError::DecodeError(DecodeError::InvalidPadding).source(); + } + + #[test] + fn deprecated_fns() { + let _ = decode(""); + let _ = decode_engine("", &crate::prelude::BASE64_STANDARD); + let _ = decode_engine_vec("", &mut Vec::new(), &crate::prelude::BASE64_STANDARD); + let _ = decode_engine_slice("", &mut [], &crate::prelude::BASE64_STANDARD); + } + + #[test] + fn decoded_len_est() { + assert_eq!(3, decoded_len_estimate(4)); + } +} diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index 21a386f..b55d3fc 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -1,47 +1,28 @@ use crate::{ engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode}, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, PAD_BYTE, }; -// decode logic operates on chunks of 8 input bytes without padding -const INPUT_CHUNK_LEN: usize = 8; -const DECODED_CHUNK_LEN: usize = 6; - -// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last -// 2 bytes of any output u64 should not be counted as written to (but must be available in a -// slice). -const DECODED_CHUNK_SUFFIX: usize = 2; - -// how many u64's of input to handle at a time -const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; - -const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; - -// includes the trailing 2 bytes for the final u64 write -const DECODED_BLOCK_LEN: usize = - CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; - #[doc(hidden)] pub struct GeneralPurposeEstimate { - /// Total number of decode chunks, including a possibly partial last chunk - num_chunks: usize, - decoded_len_estimate: usize, + /// input len % 4 + rem: usize, + conservative_decoded_len: usize, } impl GeneralPurposeEstimate { pub(crate) fn new(encoded_len: usize) -> Self { - // Formulas that won't overflow + let rem = encoded_len % 4; Self { - num_chunks: encoded_len / INPUT_CHUNK_LEN - + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, - decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, + rem, + conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3, } } } impl DecodeEstimate for GeneralPurposeEstimate { fn decoded_len_estimate(&self) -> usize { - self.decoded_len_estimate + self.conservative_decoded_len } } @@ -58,263 +39,262 @@ pub(crate) fn decode_helper( decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, -) -> Result { - let remainder_len = input.len() % INPUT_CHUNK_LEN; - - // Because the fast decode loop writes in groups of 8 bytes (unrolled to - // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of - // which only 6 are valid data), we need to be sure that we stop using the fast decode loop - // soon enough that there will always be 2 more bytes of valid data written after that loop. - let trailing_bytes_to_skip = match remainder_len { - // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, - // and the fast decode logic cannot handle padding - 0 => INPUT_CHUNK_LEN, - // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte - 1 | 5 => { - // trailing whitespace is so common that it's worth it to check the last byte to - // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } - } +) -> Result { + let input_complete_nonterminal_quads_len = + complete_quads_len(input, estimate.rem, output.len(), decode_table)?; - return Err(DecodeError::InvalidLength); - } - // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes - // written by the fast decode loop. So, we have to ignore both these 2 bytes and the - // previous chunk. - 2 => INPUT_CHUNK_LEN + 2, - // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this - // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail - // with an error, not panic from going past the bounds of the output slice, so we let it - // use stage 3 + 4. - 3 => INPUT_CHUNK_LEN + 3, - // This can also decode to one output byte because it may be 2 input chars + 2 padding - // chars, which would decode to 1 byte. - 4 => INPUT_CHUNK_LEN + 4, - // Everything else is a legal decode len (given that we don't require padding), and will - // decode to at least 2 bytes of output. - _ => remainder_len, - }; + const UNROLLED_INPUT_CHUNK_SIZE: usize = 32; + const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3; - // rounded up to include partial chunks - let mut remaining_chunks = estimate.num_chunks; + let input_complete_quads_after_unrolled_chunks_len = + input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE; - let mut input_index = 0; - let mut output_index = 0; + let input_unrolled_loop_len = + input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len; + // chunks of 32 bytes + for (chunk_index, chunk) in input[..input_unrolled_loop_len] + .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE) + .enumerate() { - let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); - - // Fast loop, stage 1 - // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { - while input_index <= max_start_index { - let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; - let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; - - decode_chunk( - &input_slice[0..], - input_index, - decode_table, - &mut output_slice[0..], - )?; - decode_chunk( - &input_slice[8..], - input_index + 8, - decode_table, - &mut output_slice[6..], - )?; - decode_chunk( - &input_slice[16..], - input_index + 16, - decode_table, - &mut output_slice[12..], - )?; - decode_chunk( - &input_slice[24..], - input_index + 24, - decode_table, - &mut output_slice[18..], - )?; - - input_index += INPUT_BLOCK_LEN; - output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; - remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; - } - } - - // Fast loop, stage 2 (aka still pretty fast loop) - // 8 bytes at a time for whatever we didn't do in stage 1. - if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { - while input_index < max_start_index { - decode_chunk( - &input[input_index..(input_index + INPUT_CHUNK_LEN)], - input_index, - decode_table, - &mut output - [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], - )?; - - output_index += DECODED_CHUNK_LEN; - input_index += INPUT_CHUNK_LEN; - remaining_chunks -= 1; - } - } - } + let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE; + let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE + ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE]; - // Stage 3 - // If input length was such that a chunk had to be deferred until after the fast loop - // because decoding it would have produced 2 trailing bytes that wouldn't then be - // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 - // trailing bytes. - // However, we still need to avoid the last chunk (partial or complete) because it could - // have padding, so we always do 1 fewer to avoid the last chunk. - for _ in 1..remaining_chunks { - decode_chunk_precise( - &input[input_index..], + decode_chunk_8( + &chunk[0..8], input_index, decode_table, - &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], + &mut chunk_output[0..6], + )?; + decode_chunk_8( + &chunk[8..16], + input_index + 8, + decode_table, + &mut chunk_output[6..12], + )?; + decode_chunk_8( + &chunk[16..24], + input_index + 16, + decode_table, + &mut chunk_output[12..18], + )?; + decode_chunk_8( + &chunk[24..32], + input_index + 24, + decode_table, + &mut chunk_output[18..24], )?; - - input_index += INPUT_CHUNK_LEN; - output_index += DECODED_CHUNK_LEN; } - // always have one more (possibly partial) block of 8 input - debug_assert!(input.len() - input_index > 1 || input.is_empty()); - debug_assert!(input.len() - input_index <= 8); + // remaining quads, except for the last possibly partial one, as it may have padding + let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3; + let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3; + { + let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len]; + + for (chunk_index, chunk) in input + [input_unrolled_loop_len..input_complete_nonterminal_quads_len] + .chunks_exact(4) + .enumerate() + { + let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3]; + + decode_chunk_4( + chunk, + input_unrolled_loop_len + chunk_index * 4, + decode_table, + chunk_output, + )?; + } + } super::decode_suffix::decode_suffix( input, - input_index, + input_complete_nonterminal_quads_len, output, - output_index, + output_complete_quad_len, decode_table, decode_allow_trailing_bits, padding_mode, ) } -/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the -/// first 6 of those contain meaningful data. +/// Returns the length of complete quads, except for the last one, even if it is complete. +/// +/// Returns an error if the output len is not big enough for decoding those complete quads, or if +/// the input % 4 == 1, and that last byte is an invalid value other than a pad byte. +/// +/// - `input` is the base64 input +/// - `input_len_rem` is input len % 4 +/// - `output_len` is the length of the output slice +pub(crate) fn complete_quads_len( + input: &[u8], + input_len_rem: usize, + output_len: usize, + decode_table: &[u8; 256], +) -> Result { + debug_assert!(input.len() % 4 == input_len_rem); + + // detect a trailing invalid byte, like a newline, as a user convenience + if input_len_rem == 1 { + let last_byte = input[input.len() - 1]; + // exclude pad bytes; might be part of padding that extends from earlier in the input + if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into()); + } + }; + + // skip last quad, even if it's complete, as it may have padding + let input_complete_nonterminal_quads_len = input + .len() + .saturating_sub(input_len_rem) + // if rem was 0, subtract 4 to avoid padding + .saturating_sub((input_len_rem == 0) as usize * 4); + debug_assert!( + input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) + ); + + // check that everything except the last quad handled by decode_suffix will fit + if output_len < input_complete_nonterminal_quads_len / 4 * 3 { + return Err(DecodeSliceError::OutputSliceTooSmall); + }; + Ok(input_complete_nonterminal_quads_len) +} + +/// Decode 8 bytes of input into 6 bytes of output. /// -/// `input` is the bytes to decode, of which the first 8 bytes will be processed. +/// `input` is the 8 bytes to decode. /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors /// accurately) /// `decode_table` is the lookup table for the particular base64 alphabet. -/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded -/// data. +/// `output` will have its first 6 bytes overwritten // yes, really inline (worth 30-50% speedup) #[inline(always)] -fn decode_chunk( +fn decode_chunk_8( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let morsel = decode_table[input[0] as usize]; + let morsel = decode_table[usize::from(input[0])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); } - let mut accum = (morsel as u64) << 58; + let mut accum = u64::from(morsel) << 58; - let morsel = decode_table[input[1] as usize]; + let morsel = decode_table[usize::from(input[1])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 1, input[1], )); } - accum |= (morsel as u64) << 52; + accum |= u64::from(morsel) << 52; - let morsel = decode_table[input[2] as usize]; + let morsel = decode_table[usize::from(input[2])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 2, input[2], )); } - accum |= (morsel as u64) << 46; + accum |= u64::from(morsel) << 46; - let morsel = decode_table[input[3] as usize]; + let morsel = decode_table[usize::from(input[3])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 3, input[3], )); } - accum |= (morsel as u64) << 40; + accum |= u64::from(morsel) << 40; - let morsel = decode_table[input[4] as usize]; + let morsel = decode_table[usize::from(input[4])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 4, input[4], )); } - accum |= (morsel as u64) << 34; + accum |= u64::from(morsel) << 34; - let morsel = decode_table[input[5] as usize]; + let morsel = decode_table[usize::from(input[5])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 5, input[5], )); } - accum |= (morsel as u64) << 28; + accum |= u64::from(morsel) << 28; - let morsel = decode_table[input[6] as usize]; + let morsel = decode_table[usize::from(input[6])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 6, input[6], )); } - accum |= (morsel as u64) << 22; + accum |= u64::from(morsel) << 22; - let morsel = decode_table[input[7] as usize]; + let morsel = decode_table[usize::from(input[7])]; if morsel == INVALID_VALUE { return Err(DecodeError::InvalidByte( index_at_start_of_input + 7, input[7], )); } - accum |= (morsel as u64) << 16; + accum |= u64::from(morsel) << 16; - write_u64(output, accum); + output[..6].copy_from_slice(&accum.to_be_bytes()[..6]); Ok(()) } -/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 -/// trailing garbage bytes. -#[inline] -fn decode_chunk_precise( +/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output. +#[inline(always)] +fn decode_chunk_4( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError> { - let mut tmp_buf = [0_u8; 8]; + let morsel = decode_table[usize::from(input[0])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = u32::from(morsel) << 26; - decode_chunk( - input, - index_at_start_of_input, - decode_table, - &mut tmp_buf[..], - )?; + let morsel = decode_table[usize::from(input[1])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= u32::from(morsel) << 20; + + let morsel = decode_table[usize::from(input[2])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= u32::from(morsel) << 14; - output[0..6].copy_from_slice(&tmp_buf[0..6]); + let morsel = decode_table[usize::from(input[3])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= u32::from(morsel) << 8; - Ok(()) -} + output[..3].copy_from_slice(&accum.to_be_bytes()[..3]); -#[inline] -fn write_u64(output: &mut [u8], value: u64) { - output[..8].copy_from_slice(&value.to_be_bytes()); + Ok(()) } #[cfg(test)] @@ -324,37 +304,36 @@ mod tests { use crate::engine::general_purpose::STANDARD; #[test] - fn decode_chunk_precise_writes_only_6_bytes() { + fn decode_chunk_8_writes_only_6_bytes() { let input = b"Zm9vYmFy"; // "foobar" let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; - decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); } #[test] - fn decode_chunk_writes_8_bytes() { - let input = b"Zm9vYmFy"; // "foobar" - let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + fn decode_chunk_4_writes_only_3_bytes() { + let input = b"Zm9v"; // "foobar" + let mut output = [0_u8, 1, 2, 3]; - decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); - assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); + decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', 3], &output); } #[test] fn estimate_short_lengths() { - for (range, (num_chunks, decoded_len_estimate)) in [ - (0..=0, (0, 0)), - (1..=4, (1, 3)), - (5..=8, (1, 6)), - (9..=12, (2, 9)), - (13..=16, (2, 12)), - (17..=20, (3, 15)), + for (range, decoded_len_estimate) in [ + (0..=0, 0), + (1..=4, 3), + (5..=8, 6), + (9..=12, 9), + (13..=16, 12), + (17..=20, 15), ] { for encoded_len in range { let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!(num_chunks, estimate.num_chunks); - assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate); + assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate()); } } } @@ -370,13 +349,8 @@ mod tests { let estimate = GeneralPurposeEstimate::new(encoded_len); assert_eq!( - ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128)) - as usize, - estimate.num_chunks - ); - assert_eq!( - ((len_128 + 3) / 4 * 3) as usize, - estimate.decoded_len_estimate + (len_128 + 3) / 4 * 3, + estimate.conservative_decoded_len as u128 ); }) } diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index e1e005d..02aaf51 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -1,9 +1,9 @@ use crate::{ engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode}, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, PAD_BYTE, }; -/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided +/// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided /// parameters. /// /// Returns the decode metadata representing the total number of bytes decoded, including the ones @@ -16,17 +16,19 @@ pub(crate) fn decode_suffix( decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, -) -> Result { - // Decode any leftovers that aren't a complete input block of 8 bytes. - // Use a u64 as a stack-resident 8 byte buffer. - let mut leftover_bits: u64 = 0; +) -> Result { + debug_assert!((input.len() - input_index) <= 4); + + // Decode any leftovers that might not be a complete input chunk of 4 bytes. + // Use a u32 as a stack-resident 4 byte buffer. let mut morsels_in_leftover = 0; - let mut padding_bytes = 0; - let mut first_padding_index: usize = 0; + let mut padding_bytes_count = 0; + // offset from input_index + let mut first_padding_offset: usize = 0; let mut last_symbol = 0_u8; - let start_of_leftovers = input_index; + let mut morsels = [0_u8; 4]; - for (i, &b) in input[start_of_leftovers..].iter().enumerate() { + for (leftover_index, &b) in input[input_index..].iter().enumerate() { // '=' padding if b == PAD_BYTE { // There can be bad padding bytes in a few ways: @@ -41,30 +43,22 @@ pub(crate) fn decode_suffix( // Per config, non-canonical but still functional non- or partially-padded base64 // may be treated as an error condition. - if i % 4 < 2 { - // Check for case #2. - let bad_padding_index = start_of_leftovers - + if padding_bytes > 0 { - // If we've already seen padding, report the first padding index. - // This is to be consistent with the normal decode logic: it will report an - // error on the first padding character (since it doesn't expect to see - // anything but actual encoded data). - // This could only happen if the padding started in the previous quad since - // otherwise this case would have been hit at i % 4 == 0 if it was the same - // quad. - first_padding_index - } else { - // haven't seen padding before, just use where we are now - i - }; - return Err(DecodeError::InvalidByte(bad_padding_index, b)); + if leftover_index < 2 { + // Check for error #2. + // Either the previous byte was padding, in which case we would have already hit + // this case, or it wasn't, in which case this is the first such error. + debug_assert!( + leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0) + ); + let bad_padding_index = input_index + leftover_index; + return Err(DecodeError::InvalidByte(bad_padding_index, b).into()); } - if padding_bytes == 0 { - first_padding_index = i; + if padding_bytes_count == 0 { + first_padding_offset = leftover_index; } - padding_bytes += 1; + padding_bytes_count += 1; continue; } @@ -72,39 +66,44 @@ pub(crate) fn decode_suffix( // To make '=' handling consistent with the main loop, don't allow // non-suffix '=' in trailing chunk either. Report error as first // erroneous padding. - if padding_bytes > 0 { - return Err(DecodeError::InvalidByte( - start_of_leftovers + first_padding_index, - PAD_BYTE, - )); + if padding_bytes_count > 0 { + return Err( + DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(), + ); } last_symbol = b; // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. // Pack the leftovers from left to right. - let shift = 64 - (morsels_in_leftover + 1) * 6; let morsel = decode_table[b as usize]; if morsel == INVALID_VALUE { - return Err(DecodeError::InvalidByte(start_of_leftovers + i, b)); + return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into()); } - leftover_bits |= (morsel as u64) << shift; + morsels[morsels_in_leftover] = morsel; morsels_in_leftover += 1; } + // If there was 1 trailing byte, and it was valid, and we got to this point without hitting + // an invalid byte, now we can report invalid length + if !input.is_empty() && morsels_in_leftover < 2 { + return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into()); + } + match padding_mode { DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ } DecodePaddingMode::RequireCanonical => { - if (padding_bytes + morsels_in_leftover) % 4 != 0 { - return Err(DecodeError::InvalidPadding); + // allow empty input + if (padding_bytes_count + morsels_in_leftover) % 4 != 0 { + return Err(DecodeError::InvalidPadding.into()); } } DecodePaddingMode::RequireNone => { - if padding_bytes > 0 { + if padding_bytes_count > 0 { // check at the end to make sure we let the cases of padding that should be InvalidByte // get hit - return Err(DecodeError::InvalidPadding); + return Err(DecodeError::InvalidPadding.into()); } } } @@ -117,50 +116,45 @@ pub(crate) fn decode_suffix( // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a // mask based on how many bits are used for just the canonical encoding, and optionally // error if any other bits are set. In the example of one encoded byte -> 2 symbols, - // 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and + // 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and // useless since there are no more symbols to provide the necessary 4 additional bits // to finish the second original byte. - let leftover_bits_ready_to_append = match morsels_in_leftover { - 0 => 0, - 2 => 8, - 3 => 16, - 4 => 24, - 6 => 32, - 7 => 40, - 8 => 48, - // can also be detected as case #2 bad padding above - _ => unreachable!( - "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths" - ), - }; + let leftover_bytes_to_append = morsels_in_leftover * 6 / 8; + // Put the up to 6 complete bytes as the high bytes. + // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split. + let mut leftover_num = (u32::from(morsels[0]) << 26) + | (u32::from(morsels[1]) << 20) + | (u32::from(morsels[2]) << 14) + | (u32::from(morsels[3]) << 8); // if there are bits set outside the bits we care about, last symbol encodes trailing bits that // will not be included in the output - let mask = !0 >> leftover_bits_ready_to_append; - if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 { + let mask = !0_u32 >> (leftover_bytes_to_append * 8); + if !decode_allow_trailing_bits && (leftover_num & mask) != 0 { // last morsel is at `morsels_in_leftover` - 1 return Err(DecodeError::InvalidLastSymbol( - start_of_leftovers + morsels_in_leftover - 1, + input_index + morsels_in_leftover - 1, last_symbol, - )); + ) + .into()); } - // TODO benchmark simply converting to big endian bytes - let mut leftover_bits_appended_to_buf = 0; - while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { - // `as` simply truncates the higher bits, which is what we want here - let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; - output[output_index] = selected_bits; + // Strangely, this approach benchmarks better than writing bytes one at a time, + // or copy_from_slice into output. + for _ in 0..leftover_bytes_to_append { + let hi_byte = (leftover_num >> 24) as u8; + leftover_num <<= 8; + *output + .get_mut(output_index) + .ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte; output_index += 1; - - leftover_bits_appended_to_buf += 8; } Ok(DecodeMetadata::new( output_index, - if padding_bytes > 0 { - Some(input_index + first_padding_index) + if padding_bytes_count > 0 { + Some(input_index + first_padding_offset) } else { None }, diff --git a/src/engine/general_purpose/mod.rs b/src/engine/general_purpose/mod.rs index e0227f3..6fe9580 100644 --- a/src/engine/general_purpose/mod.rs +++ b/src/engine/general_purpose/mod.rs @@ -3,11 +3,11 @@ use crate::{ alphabet, alphabet::Alphabet, engine::{Config, DecodeMetadata, DecodePaddingMode}, - DecodeError, + DecodeSliceError, }; use core::convert::TryInto; -mod decode; +pub(crate) mod decode; pub(crate) mod decode_suffix; pub use decode::GeneralPurposeEstimate; @@ -173,7 +173,7 @@ impl super::Engine for GeneralPurpose { input: &[u8], output: &mut [u8], estimate: Self::DecodeEstimate, - ) -> Result { + ) -> Result { decode::decode_helper( input, estimate, diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 16c05d7..77dcd14 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -83,17 +83,13 @@ pub trait Engine: Send + Sync { /// /// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as /// errors unless the engine is configured otherwise. - /// - /// # Panics - /// - /// Panics if `output` is too small. #[doc(hidden)] fn internal_decode( &self, input: &[u8], output: &mut [u8], decode_estimate: Self::DecodeEstimate, - ) -> Result; + ) -> Result; /// Returns the config for this engine. fn config(&self) -> &Self::Config; @@ -253,7 +249,13 @@ pub trait Engine: Send + Sync { let mut buffer = vec![0; estimate.decoded_len_estimate()]; let bytes_written = engine - .internal_decode(input_bytes, &mut buffer, estimate)? + .internal_decode(input_bytes, &mut buffer, estimate) + .map_err(|e| match e { + DecodeSliceError::DecodeError(e) => e, + DecodeSliceError::OutputSliceTooSmall => { + unreachable!("Vec is sized conservatively") + } + })? .decoded_len; buffer.truncate(bytes_written); @@ -318,7 +320,13 @@ pub trait Engine: Send + Sync { let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..]; let bytes_written = engine - .internal_decode(input_bytes, buffer_slice, estimate)? + .internal_decode(input_bytes, buffer_slice, estimate) + .map_err(|e| match e { + DecodeSliceError::DecodeError(e) => e, + DecodeSliceError::OutputSliceTooSmall => { + unreachable!("Vec is sized conservatively") + } + })? .decoded_len; buffer.truncate(starting_output_len + bytes_written); @@ -354,15 +362,12 @@ pub trait Engine: Send + Sync { where E: Engine + ?Sized, { - let estimate = engine.internal_decoded_len_estimate(input_bytes.len()); - - if output.len() < estimate.decoded_len_estimate() { - return Err(DecodeSliceError::OutputSliceTooSmall); - } - engine - .internal_decode(input_bytes, output, estimate) - .map_err(|e| e.into()) + .internal_decode( + input_bytes, + output, + engine.internal_decoded_len_estimate(input_bytes.len()), + ) .map(|dm| dm.decoded_len) } @@ -400,6 +405,12 @@ pub trait Engine: Send + Sync { engine.internal_decoded_len_estimate(input_bytes.len()), ) .map(|dm| dm.decoded_len) + .map_err(|e| match e { + DecodeSliceError::DecodeError(e) => e, + DecodeSliceError::OutputSliceTooSmall => { + panic!("Output slice is too small") + } + }) } inner(self, input.as_ref(), output) diff --git a/src/engine/naive.rs b/src/engine/naive.rs index 6a50cbe..af509bf 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -4,7 +4,7 @@ use crate::{ general_purpose::{self, decode_table, encode_table}, Config, DecodeEstimate, DecodeMetadata, DecodePaddingMode, Engine, }, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, }; use std::ops::{BitAnd, BitOr, Shl, Shr}; @@ -111,63 +111,40 @@ impl Engine for Naive { input: &[u8], output: &mut [u8], estimate: Self::DecodeEstimate, - ) -> Result { - if estimate.rem == 1 { - // trailing whitespace is so common that it's worth it to check the last byte to - // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE - && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE - { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); - } - } - - return Err(DecodeError::InvalidLength); - } + ) -> Result { + let complete_nonterminal_quads_len = general_purpose::decode::complete_quads_len( + input, + estimate.rem, + output.len(), + &self.decode_table, + )?; - let mut input_index = 0_usize; - let mut output_index = 0_usize; const BOTTOM_BYTE: u32 = 0xFF; - // can only use the main loop on non-trailing chunks - if input.len() > Self::DECODE_INPUT_CHUNK_SIZE { - // skip the last chunk, whether it's partial or full, since it might - // have padding, and start at the beginning of the chunk before that - let last_complete_chunk_start_index = estimate.complete_chunk_len - - if estimate.rem == 0 { - // Trailing chunk is also full chunk, so there must be at least 2 chunks, and - // this won't underflow - Self::DECODE_INPUT_CHUNK_SIZE * 2 - } else { - // Trailing chunk is partial, so it's already excluded in - // complete_chunk_len - Self::DECODE_INPUT_CHUNK_SIZE - }; - - while input_index <= last_complete_chunk_start_index { - let chunk = &input[input_index..input_index + Self::DECODE_INPUT_CHUNK_SIZE]; - let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) - | self - .decode_byte_into_u32(input_index + 1, chunk[1])? - .shl(12) - | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) - | self.decode_byte_into_u32(input_index + 3, chunk[3])?; - - output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; - output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; - output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; - - input_index += Self::DECODE_INPUT_CHUNK_SIZE; - output_index += 3; - } + for (chunk_index, chunk) in input[..complete_nonterminal_quads_len] + .chunks_exact(4) + .enumerate() + { + let input_index = chunk_index * 4; + let output_index = chunk_index * 3; + + let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) + | self + .decode_byte_into_u32(input_index + 1, chunk[1])? + .shl(12) + | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) + | self.decode_byte_into_u32(input_index + 3, chunk[3])?; + + output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; } general_purpose::decode_suffix::decode_suffix( input, - input_index, + complete_nonterminal_quads_len, output, - output_index, + complete_nonterminal_quads_len / 4 * 3, &self.decode_table, self.config.decode_allow_trailing_bits, self.config.decode_padding_mode, diff --git a/src/engine/tests.rs b/src/engine/tests.rs index b048005..72bbf4b 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -19,7 +19,7 @@ use crate::{ }, read::DecoderReader, tests::{assert_encode_sanity, random_alphabet, random_config}, - DecodeError, PAD_BYTE, + DecodeError, DecodeSliceError, PAD_BYTE, }; // the case::foo syntax includes the "foo" in the generated test method names @@ -365,26 +365,49 @@ fn decode_detect_invalid_last_symbol(engine_wrapper: E) { } #[apply(all_engines)] -fn decode_detect_invalid_last_symbol_when_length_is_also_invalid( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // check across enough lengths that it would likely cover any implementation's various internal - // small/large input division +fn decode_detect_1_valid_symbol_in_last_quad_invalid_length(engine_wrapper: E) { for len in (0_usize..256).map(|len| len * 4 + 1) { - let engine = E::random_alphabet(&mut rng, &STANDARD); + for mode in all_pad_modes() { + let mut input = vec![b'A'; len]; - let mut input = vec![b'A'; len]; + let engine = E::standard_with_pad_mode(true, mode); - // with a valid last char, it's InvalidLength - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&input)); - // after mangling the last char, it's InvalidByte - input[len - 1] = b'"'; - assert_eq!( - Err(DecodeError::InvalidByte(len - 1, b'"')), - engine.decode(&input) - ); + assert_eq!(Err(DecodeError::InvalidLength(len)), engine.decode(&input)); + // if we add padding, then the first pad byte in the quad is invalid because it should + // be the second symbol + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(len, PAD_BYTE)), + engine.decode(&input) + ); + } + } + } +} + +#[apply(all_engines)] +fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte(engine_wrapper: E) { + for prefix_len in (0_usize..256).map(|len| len * 4) { + for mode in all_pad_modes() { + let mut input = vec![b'A'; prefix_len]; + input.push(b'*'); + + let engine = E::standard_with_pad_mode(true, mode); + + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + // adding padding doesn't matter + for _ in 0..3 { + input.push(PAD_BYTE); + assert_eq!( + Err(DecodeError::InvalidByte(prefix_len, b'*')), + engine.decode(&input) + ); + } + } } } @@ -471,8 +494,10 @@ fn decode_detect_invalid_last_symbol_every_possible_three_symbols(engine_wrapper: E) { /// Any amount of padding anywhere before the final non padding character = invalid byte at first /// pad byte. -/// From this, we know padding must extend to the end of the input. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_before_final_non_padding_char_error_invalid_byte( +/// From this and [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix_all_modes], +/// we know padding must extend contiguously to the end of the input. +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes< + E: EngineWrapper, +>( engine_wrapper: E, ) { - let mut rng = seeded_rng(); + // Different amounts of padding, w/ offset from end for the last non-padding char. + // Only canonical padding, so Canonical mode will work. + let suffixes = &[("AA==", 2), ("AAA=", 1), ("AAAA", 0)]; - // the different amounts of proper padding, w/ offset from end for the last non-padding char - let suffixes = [("/w==", 2), ("iYu=", 1), ("zzzz", 0)]; + for mode in pad_modes_allowing_padding() { + // We don't encode, so we don't care about encode padding. + let engine = E::standard_with_pad_mode(true, mode); - let prefix_quads_range = distributions::Uniform::from(0..=256); + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ); + } +} - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); +/// See [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes] +#[apply(all_engines)] +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix< + E: EngineWrapper, +>( + engine_wrapper: E, +) { + // Different amounts of padding, w/ offset from end for the last non-padding char, and + // non-canonical padding. + let suffixes = [ + ("AA==", 2), + ("AA=", 1), + ("AA", 0), + ("AAA=", 1), + ("AAA", 0), + ("AAAA", 0), + ]; - for _ in 0..100_000 { - for (suffix, offset) in suffixes.iter() { - let mut s = "ABCD".repeat(prefix_quads_range.sample(&mut rng)); - s.push_str(suffix); - let mut encoded = s.into_bytes(); + // We don't encode, so we don't care about encode padding. + // Decoding is indifferent so that we don't get caught by missing padding on the last quad + let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); - // calculate a range to write padding into that leaves at least one non padding char - let last_non_padding_offset = encoded.len() - 1 - offset; + decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine, + suffixes.as_slice(), + ) +} - // don't include last non padding char as it must stay not padding - let padding_end = rng.gen_range(0..last_non_padding_offset); +fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( + engine: impl Engine, + suffixes: &[(&str, usize)], +) { + let mut rng = seeded_rng(); - // don't use more than 100 bytes of padding, but also use shorter lengths when - // padding_end is near the start of the encoded data to avoid biasing to padding - // the entire prefix on short lengths - let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); - let padding_start = padding_end.saturating_sub(padding_len); + let prefix_quads_range = distributions::Uniform::from(0..=256); - encoded[padding_start..=padding_end].fill(PAD_BYTE); + for _ in 0..100_000 { + for (suffix, suffix_offset) in suffixes.iter() { + let mut s = "AAAA".repeat(prefix_quads_range.sample(&mut rng)); + s.push_str(suffix); + let mut encoded = s.into_bytes(); - assert_eq!( - Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), - engine.decode(&encoded), - ); - } + // calculate a range to write padding into that leaves at least one non padding char + let last_non_padding_offset = encoded.len() - 1 - suffix_offset; + + // don't include last non padding char as it must stay not padding + let padding_end = rng.gen_range(0..last_non_padding_offset); + + // don't use more than 100 bytes of padding, but also use shorter lengths when + // padding_end is near the start of the encoded data to avoid biasing to padding + // the entire prefix on short lengths + let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1)); + let padding_start = padding_end.saturating_sub(padding_len); + + encoded[padding_start..=padding_end].fill(PAD_BYTE); + + // should still have non-padding before any final padding + assert_ne!(PAD_BYTE, encoded[last_non_padding_offset]); + assert_eq!( + Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)), + engine.decode(&encoded), + "len: {}, input: {}", + encoded.len(), + String::from_utf8(encoded).unwrap() + ); } } } -/// Any amount of padding before final chunk that crosses over into final chunk with 2-4 bytes = +/// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes = /// invalid byte at first pad byte. -/// From this and [decode_padding_starts_before_final_chunk_error_invalid_length] we know the -/// padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it -// can end a decode on the quad that happens to contain the start of the padding -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_byte( +/// From this we know the padding must start in the final chunk. +#[apply(all_engines)] +fn decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad( engine_wrapper: E, ) { let mut rng = seeded_rng(); // must have at least one prefix quad let prefix_quads_range = distributions::Uniform::from(1..256); - // excluding 1 since we don't care about invalid length in this test - let suffix_pad_len_range = distributions::Uniform::from(2..=4); - for mode in all_pad_modes() { + let suffix_pad_len_range = distributions::Uniform::from(1..=4); + // don't use no-padding mode, as the reader decode might decode a block that ends with + // valid padding, which should then be referenced when encountering the later invalid byte + for mode in pad_modes_allowing_padding() { // we don't encode so we don't care about encode padding let engine = E::standard_with_pad_mode(true, mode); for _ in 0..100_000 { let suffix_len = suffix_pad_len_range.sample(&mut rng); - let mut encoded = "ABCD" + // all 0 bits so we don't hit InvalidLastSymbol with the reader decoder + let mut encoded = "AAAA" .repeat(prefix_quads_range.sample(&mut rng)) .into_bytes(); encoded.resize(encoded.len() + suffix_len, PAD_BYTE); @@ -705,40 +774,6 @@ fn decode_padding_starts_before_final_chunk_error_invalid_byte } } -/// Any amount of padding before final chunk that crosses over into final chunk with 1 byte = -/// invalid length. -/// From this we know the padding must start in the final chunk. -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_starts_before_final_chunk_error_invalid_length( - engine_wrapper: E, -) { - let mut rng = seeded_rng(); - - // must have at least one prefix quad - let prefix_quads_range = distributions::Uniform::from(1..256); - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - for _ in 0..100_000 { - let mut encoded = "ABCD" - .repeat(prefix_quads_range.sample(&mut rng)) - .into_bytes(); - encoded.resize(encoded.len() + 1, PAD_BYTE); - - // amount of padding must be long enough to extend back from suffix into previous - // quads - let padding_len = rng.gen_range(1 + 1..encoded.len()); - // no non-padding after padding in this test, so padding goes to the end - let padding_start = encoded.len() - padding_len; - encoded[padding_start..].fill(PAD_BYTE); - - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } - } -} - /// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding /// is not valid data (consistent with error for pad bytes in earlier chunks). /// From this we know there must be 2-3 bytes of data before padding @@ -756,29 +791,23 @@ fn decode_too_little_data_before_padding_error_invalid_byte(en let suffix_data_len = suffix_data_len_range.sample(&mut rng); let prefix_quad_len = prefix_quads_range.sample(&mut rng); - // ensure there is a suffix quad - let min_padding = usize::from(suffix_data_len == 0); - // for all possible padding lengths - for padding_len in min_padding..=(4 - suffix_data_len) { + for padding_len in 1..=(4 - suffix_data_len) { let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes(); encoded.resize(encoded.len() + suffix_data_len, b'A'); encoded.resize(encoded.len() + padding_len, PAD_BYTE); - if suffix_data_len + padding_len == 1 { - assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),); - } else { - assert_eq!( - Err(DecodeError::InvalidByte( - prefix_quad_len * 4 + suffix_data_len, - PAD_BYTE, - )), - engine.decode(&encoded), - "suffix data len {} pad len {}", - suffix_data_len, - padding_len - ); - } + assert_eq!( + Err(DecodeError::InvalidByte( + prefix_quad_len * 4 + suffix_data_len, + PAD_BYTE, + )), + engine.decode(&encoded), + "input {} suffix data len {} pad len {}", + String::from_utf8(encoded).unwrap(), + suffix_data_len, + padding_len + ); } } } @@ -918,258 +947,64 @@ fn decode_pad_mode_indifferent_padding_accepts_anything(engine ); } -//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3 -// DecoderReader pseudo-engine finds the first padding, but doesn't report it as an error, -// because in the next decode it finds more padding, which is reported as InvalidByte, just -// with an offset at its position in the second decode, rather than being linked to the start -// of the padding that was first seen in the previous decode. -#[apply(all_engines_except_decoder_reader)] -fn decode_pad_byte_in_penultimate_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave room for at least one pad byte in penultimate quad - for num_valid_bytes_penultimate_quad in 0..4 { - // can't have 1 or it would be invalid length - for num_pad_bytes_in_final_quad in 2..=4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // varying amounts of padding in the penultimate quad - for _ in 0..num_valid_bytes_penultimate_quad { - s.push('A'); - } - // finish penultimate quad with padding - for _ in num_valid_bytes_penultimate_quad..4 { - s.push('='); - } - // and more padding in the final quad - for _ in 0..num_pad_bytes_in_final_quad { - s.push('='); - } - - // padding should be an invalid byte before the final quad. - // Could argue that the *next* padding byte (in the next quad) is technically the first - // erroneous one, but reporting that accurately is more complex and probably nobody cares - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + num_valid_bytes_penultimate_quad, - b'=', - ), - engine.decode(&s).unwrap_err(), - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_bytes_after_padding_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // leave at least one byte in the quad for padding - for bytes_after_padding in 1..4 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - // every invalid padding position with a 3-byte final quad: 1 to 3 bytes after padding - for _ in 0..(3 - bytes_after_padding) { - s.push('A'); - } - s.push('='); - for _ in 0..bytes_after_padding { - s.push('A'); - } - - // First (and only) padding byte is invalid. - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + (3 - bytes_after_padding), - b'=' - ), - engine.decode(&s).unwrap_err() - ); - } - } - } -} - -#[apply(all_engines)] -fn decode_absurd_pad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("==Y=Wx===pY=2U====="); - - // first padding byte - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_too_much_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // add enough padding to ensure that we'll hit all decode stages at the different lengths - for pad_bytes in 1..=64 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - - if pad_bytes % 4 == 1 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by -// decoding only the available complete quads -#[apply(all_engines_except_decoder_reader)] -fn decode_padding_followed_by_non_padding_returns_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - for pad_bytes in 0..=32 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - let padding: String = "=".repeat(pad_bytes); - s.push_str(&padding); - s.push('E'); - - if pad_bytes % 4 == 0 { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } else { - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4, b'='), - engine.decode(&s).unwrap_err() - ); - } - } - } - } -} - -#[apply(all_engines)] -fn decode_one_char_in_final_quad_with_padding_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("E="); - - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - // more padding doesn't change the error - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - - s.push('='); - assert_eq!( - DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='), - engine.decode(&s).unwrap_err() - ); - } - } -} - -#[apply(all_engines)] -fn decode_too_few_symbols_in_final_quad_error(engine_wrapper: E) { - for mode in all_pad_modes() { - // we don't encode so we don't care about encode padding - let engine = E::standard_with_pad_mode(true, mode); - - for num_prefix_quads in 0..256 { - // <2 is invalid - for final_quad_symbols in 0..2 { - for padding_symbols in 0..=(4 - final_quad_symbols) { - let mut s: String = "ABCD".repeat(num_prefix_quads); - - for _ in 0..final_quad_symbols { - s.push('A'); - } - for _ in 0..padding_symbols { - s.push('='); - } - - match final_quad_symbols + padding_symbols { - 0 => continue, - 1 => { - assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err()); - } - _ => { - // error reported at first padding byte - assert_eq!( - DecodeError::InvalidByte( - num_prefix_quads * 4 + final_quad_symbols, - b'=', - ), - engine.decode(&s).unwrap_err() - ); - } - } - } - } - } - } -} - +/// 1 trailing byte that's not padding is detected as invalid byte even though there's padding +/// in the middle of the input. This is essentially mandating the eager check for 1 trailing byte +/// to catch the \n suffix case. // DecoderReader pseudo-engine can't handle DecodePaddingMode::RequireNone since it will decode // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_bytes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_all_pad_modes_invalid_byte(engine_wrapper: E) { for mode in all_pad_modes() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } #[apply(all_engines)] -fn decode_invalid_trailing_bytes_all_modes(engine_wrapper: E) { +fn decode_invalid_trailing_bytes_invalid_byte(engine_wrapper: E) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode); } } +fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { + for last_byte in [b'*', b'\n'] { + for num_prefix_quads in 0..256 { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str("Cg=="); + let mut input = s.into_bytes(); + input.push(last_byte); + // The case of trailing newlines is common enough to warrant a test for a good error + // message. + assert_eq!( + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + 4, + last_byte + )), + engine.decode(&input), + "mode: {:?}, input: {}", + mode, + String::from_utf8(input).unwrap() + ); + } + } +} + +/// When there's 1 trailing byte, but it's padding, it's only InvalidByte if there isn't padding +/// earlier. #[apply(all_engines)] -fn decode_invalid_trailing_padding_as_invalid_length(engine_wrapper: E) { +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_pad_byte( + engine_wrapper: E, +) { // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with // InvalidPadding because it will decode the last complete quad with padding first for mode in pad_modes_allowing_padding() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } @@ -1177,48 +1012,36 @@ fn decode_invalid_trailing_padding_as_invalid_length(engine_wr // a complete quad with padding in it before encountering the stray byte that makes it an invalid // length #[apply(all_engines_except_decoder_reader)] -fn decode_invalid_trailing_padding_as_invalid_length_all_modes( +fn decode_invalid_trailing_padding_as_invalid_byte_at_first_byte_all_modes( engine_wrapper: E, ) { for mode in all_pad_modes() { - do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode); + do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + E::standard_with_pad_mode(true, mode), + mode, + ); } } - -#[apply(all_engines)] -fn decode_wrong_length_error(engine_wrapper: E) { - let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent); - +fn do_invalid_trailing_padding_as_invalid_byte_at_first_padding( + engine: impl Engine, + mode: DecodePaddingMode, +) { for num_prefix_quads in 0..256 { - // at least one token, otherwise it wouldn't be a final quad - for num_tokens_final_quad in 1..=4 { - for num_padding in 0..=(4 - num_tokens_final_quad) { - let mut s: String = "IIII".repeat(num_prefix_quads); - for _ in 0..num_tokens_final_quad { - s.push('g'); - } - for _ in 0..num_padding { - s.push('='); - } + for (suffix, pad_offset) in [("AA===", 2), ("AAA==", 3), ("AAAA=", 4)] { + let mut s: String = "ABCD".repeat(num_prefix_quads); + s.push_str(suffix); - let res = engine.decode(&s); - if num_tokens_final_quad >= 2 { - assert!(res.is_ok()); - } else if num_tokens_final_quad == 1 && num_padding > 0 { - // = is invalid if it's too early - assert_eq!( - Err(DecodeError::InvalidByte( - num_prefix_quads * 4 + num_tokens_final_quad, - 61 - )), - res - ); - } else if num_padding > 2 { - assert_eq!(Err(DecodeError::InvalidPadding), res); - } else { - assert_eq!(Err(DecodeError::InvalidLength), res); - } - } + assert_eq!( + // pad after `g`, not the last one + Err(DecodeError::InvalidByte( + num_prefix_quads * 4 + pad_offset, + PAD_BYTE + )), + engine.decode(&s), + "mode: {:?}, input: {}", + mode, + s + ); } } } @@ -1248,12 +1071,20 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); decode_buf.resize(input_len, 0); - // decode into the non-empty buf let decode_bytes_written = engine .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]) .unwrap(); + assert_eq!(orig_data.len(), decode_bytes_written); + assert_eq!(orig_data, decode_buf); + // same for checked variant + decode_buf.clear(); + decode_buf.resize(input_len, 0); + // decode into the non-empty buf + let decode_bytes_written = engine + .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..]) + .unwrap(); assert_eq!(orig_data.len(), decode_bytes_written); assert_eq!(orig_data, decode_buf); } @@ -1287,7 +1118,10 @@ fn inner_decode_reports_padding_position(engine_wrapper: E) { if pad_position % 4 < 2 { // impossible padding assert_eq!( - Err(DecodeError::InvalidByte(pad_position, PAD_BYTE)), + Err(DecodeSliceError::DecodeError(DecodeError::InvalidByte( + pad_position, + PAD_BYTE + ))), decode_res ); } else { @@ -1355,35 +1189,60 @@ fn estimate_via_u128_inflation(engine_wrapper: E) { }) } -fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==\n"); - - // The case of trailing newlines is common enough to warrant a test for a good error - // message. - assert_eq!( - Err(DecodeError::InvalidByte(num_prefix_quads * 4 + 4, b'\n')), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); - } -} +#[apply(all_engines)] +fn decode_slice_checked_fails_gracefully_at_all_output_lengths( + engine_wrapper: E, +) { + let mut rng = seeded_rng(); + for original_len in 0..1000 { + let mut original = vec![0; original_len]; + rng.fill(&mut original[..]); + + for mode in all_pad_modes() { + let engine = E::standard_with_pad_mode( + match mode { + DecodePaddingMode::Indifferent | DecodePaddingMode::RequireCanonical => true, + DecodePaddingMode::RequireNone => false, + }, + mode, + ); -fn do_invalid_trailing_padding_as_invalid_length(engine: impl Engine, mode: DecodePaddingMode) { - for num_prefix_quads in 0..256 { - let mut s: String = "ABCD".repeat(num_prefix_quads); - s.push_str("Cg==="); + let encoded = engine.encode(&original); + let mut decode_buf = Vec::with_capacity(original_len); + for decode_buf_len in 0..original_len { + decode_buf.resize(decode_buf_len, 0); + assert_eq!( + DecodeSliceError::OutputSliceTooSmall, + engine + .decode_slice(&encoded, &mut decode_buf[..]) + .unwrap_err(), + "original len: {}, encoded len: {}, buf len: {}, mode: {:?}", + original_len, + encoded.len(), + decode_buf_len, + mode + ); + // internal method works the same + assert_eq!( + DecodeSliceError::OutputSliceTooSmall, + engine + .internal_decode( + encoded.as_bytes(), + &mut decode_buf[..], + engine.internal_decoded_len_estimate(encoded.len()) + ) + .unwrap_err() + ); + } - assert_eq!( - Err(DecodeError::InvalidLength), - engine.decode(&s), - "mode: {:?}, input: {}", - mode, - s - ); + decode_buf.resize(original_len, 0); + rng.fill(&mut decode_buf[..]); + assert_eq!( + original_len, + engine.decode_slice(&encoded, &mut decode_buf[..]).unwrap() + ); + assert_eq!(original, decode_buf); + } } } @@ -1547,7 +1406,7 @@ impl EngineWrapper for NaiveWrapper { naive::Naive::new( &STANDARD, naive::NaiveConfig { - encode_padding: false, + encode_padding: encode_pad, decode_allow_trailing_bits: false, decode_padding_mode: decode_pad_mode, }, @@ -1616,7 +1475,7 @@ impl Engine for DecoderReaderEngine { input: &[u8], output: &mut [u8], decode_estimate: Self::DecodeEstimate, - ) -> Result { + ) -> Result { let mut reader = DecoderReader::new(input, &self.engine); let mut buf = vec![0; input.len()]; // to avoid effects like not detecting invalid length due to progressively growing @@ -1635,6 +1494,9 @@ impl Engine for DecoderReaderEngine { .and_then(|inner| inner.downcast::().ok()) .unwrap() })?; + if output.len() < buf.len() { + return Err(DecodeSliceError::OutputSliceTooSmall); + } output[..buf.len()].copy_from_slice(&buf); Ok(DecodeMetadata::new( buf.len(), diff --git a/src/lib.rs b/src/lib.rs index 6ec3c12..579a722 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -228,8 +228,7 @@ unused_extern_crates, unused_import_braces, unused_results, - variant_size_differences, - warnings + variant_size_differences )] #![forbid(unsafe_code)] // Allow globally until https://github.com/rust-lang/rust-clippy/issues/8768 is resolved. diff --git a/src/read/decoder.rs b/src/read/decoder.rs index b656ae3..781f6f8 100644 --- a/src/read/decoder.rs +++ b/src/read/decoder.rs @@ -1,4 +1,4 @@ -use crate::{engine::Engine, DecodeError, PAD_BYTE}; +use crate::{engine::Engine, DecodeError, DecodeSliceError, PAD_BYTE}; use std::{cmp, fmt, io}; // This should be large, but it has to fit on the stack. @@ -35,37 +35,39 @@ pub struct DecoderReader<'e, E: Engine, R: io::Read> { /// Where b64 data is read from inner: R, - // Holds b64 data read from the delegate reader. + /// Holds b64 data read from the delegate reader. b64_buffer: [u8; BUF_SIZE], - // The start of the pending buffered data in b64_buffer. + /// The start of the pending buffered data in `b64_buffer`. b64_offset: usize, - // The amount of buffered b64 data. + /// The amount of buffered b64 data after `b64_offset` in `b64_len`. b64_len: usize, - // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a - // decoded chunk in to, we have to be able to hang on to a few decoded bytes. - // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to - // decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest - // into here, which seems like a lot of complexity for 1 extra byte of storage. - decoded_buffer: [u8; DECODED_CHUNK_SIZE], - // index of start of decoded data + /// Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a + /// decoded chunk in to, we have to be able to hang on to a few decoded bytes. + /// Technically we only need to hold 2 bytes, but then we'd need a separate temporary buffer to + /// decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest + /// into here, which seems like a lot of complexity for 1 extra byte of storage. + decoded_chunk_buffer: [u8; DECODED_CHUNK_SIZE], + /// Index of start of decoded data in `decoded_chunk_buffer` decoded_offset: usize, - // length of decoded data + /// Length of decoded data after `decoded_offset` in `decoded_chunk_buffer` decoded_len: usize, - // used to provide accurate offsets in errors - total_b64_decoded: usize, - // offset of previously seen padding, if any + /// Input length consumed so far. + /// Used to provide accurate offsets in errors + input_consumed_len: usize, + /// offset of previously seen padding, if any padding_offset: Option, } +// exclude b64_buffer as it's uselessly large impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("DecoderReader") .field("b64_offset", &self.b64_offset) .field("b64_len", &self.b64_len) - .field("decoded_buffer", &self.decoded_buffer) + .field("decoded_chunk_buffer", &self.decoded_chunk_buffer) .field("decoded_offset", &self.decoded_offset) .field("decoded_len", &self.decoded_len) - .field("total_b64_decoded", &self.total_b64_decoded) + .field("input_consumed_len", &self.input_consumed_len) .field("padding_offset", &self.padding_offset) .finish() } @@ -80,10 +82,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { b64_buffer: [0; BUF_SIZE], b64_offset: 0, b64_len: 0, - decoded_buffer: [0; DECODED_CHUNK_SIZE], + decoded_chunk_buffer: [0; DECODED_CHUNK_SIZE], decoded_offset: 0, decoded_len: 0, - total_b64_decoded: 0, + input_consumed_len: 0, padding_offset: None, } } @@ -100,7 +102,7 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { debug_assert!(copy_len <= self.decoded_len); buf[..copy_len].copy_from_slice( - &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len], + &self.decoded_chunk_buffer[self.decoded_offset..self.decoded_offset + copy_len], ); self.decoded_offset += copy_len; @@ -131,6 +133,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { /// caller's responsibility to choose the number of b64 bytes to decode correctly. /// /// Returns a Result with the number of decoded bytes written to `buf`. + /// + /// # Panics + /// + /// panics if `buf` is too small fn decode_to_buf(&mut self, b64_len_to_decode: usize, buf: &mut [u8]) -> io::Result { debug_assert!(self.b64_len >= b64_len_to_decode); debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE); @@ -144,22 +150,35 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { buf, self.engine.internal_decoded_len_estimate(b64_len_to_decode), ) - .map_err(|e| match e { - DecodeError::InvalidByte(offset, byte) => { - // This can be incorrect, but not in a way that probably matters to anyone: - // if there was padding handled in a previous decode, and we are now getting - // InvalidByte due to more padding, we should arguably report InvalidByte with - // PAD_BYTE at the original padding position (`self.padding_offset`), but we - // don't have a good way to tie those two cases together, so instead we - // just report the invalid byte as if the previous padding, and its possibly - // related downgrade to a now invalid byte, didn't happen. - DecodeError::InvalidByte(self.total_b64_decoded + offset, byte) + .map_err(|dse| match dse { + DecodeSliceError::DecodeError(de) => { + match de { + DecodeError::InvalidByte(offset, byte) => { + match (byte, self.padding_offset) { + // if there was padding in a previous block of decoding that happened to + // be correct, and we now find more padding that happens to be incorrect, + // to be consistent with non-reader decodes, record the error at the first + // padding + (PAD_BYTE, Some(first_pad_offset)) => { + DecodeError::InvalidByte(first_pad_offset, PAD_BYTE) + } + _ => { + DecodeError::InvalidByte(self.input_consumed_len + offset, byte) + } + } + } + DecodeError::InvalidLength(len) => { + DecodeError::InvalidLength(self.input_consumed_len + len) + } + DecodeError::InvalidLastSymbol(offset, byte) => { + DecodeError::InvalidLastSymbol(self.input_consumed_len + offset, byte) + } + DecodeError::InvalidPadding => DecodeError::InvalidPadding, + } } - DecodeError::InvalidLength => DecodeError::InvalidLength, - DecodeError::InvalidLastSymbol(offset, byte) => { - DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte) + DecodeSliceError::OutputSliceTooSmall => { + unreachable!("buf is sized correctly in calling code") } - DecodeError::InvalidPadding => DecodeError::InvalidPadding, }) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; @@ -176,8 +195,8 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { self.padding_offset = self.padding_offset.or(decode_metadata .padding_offset - .map(|offset| self.total_b64_decoded + offset)); - self.total_b64_decoded += b64_len_to_decode; + .map(|offset| self.input_consumed_len + offset)); + self.input_consumed_len += b64_len_to_decode; self.b64_offset += b64_len_to_decode; self.b64_len -= b64_len_to_decode; @@ -283,7 +302,7 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> { let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE); let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?; - self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); + self.decoded_chunk_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]); self.decoded_offset = 0; self.decoded_len = decoded;