diff --git a/quinn-proto/src/connection/assembler.rs b/quinn-proto/src/connection/assembler.rs
index 24c93cb55f..e43e39c581 100644
--- a/quinn-proto/src/connection/assembler.rs
+++ b/quinn-proto/src/connection/assembler.rs
@@ -13,11 +13,15 @@ use crate::range_set::RangeSet;
 pub(crate) struct Assembler {
     state: State,
     data: BinaryHeap<Buffer>,
-    defragmented: usize,
+    /// Total number of buffered bytes, including duplicates in ordered mode.
+    buffered: usize,
+    /// Estimated number of allocated bytes; will never be less than `buffered`.
+    allocated: usize,
     /// Number of bytes read by the application. When only ordered reads have been used, this is the
     /// length of the contiguous prefix of the stream which has been consumed by the application,
     /// aka the stream offset.
     bytes_read: u64,
+    end: u64,
 }
 
 impl Assembler {
@@ -58,8 +62,9 @@ impl Assembler {
                     return None;
                 } else if (chunk.offset + chunk.bytes.len() as u64) <= self.bytes_read {
                     // Next chunk is useless as the read index is beyond its end
+                    self.buffered -= chunk.bytes.len();
+                    self.allocated -= chunk.allocation_size;
                     PeekMut::pop(chunk);
-                    self.defragmented = self.defragmented.saturating_sub(1);
                     continue;
                 }
 
@@ -68,6 +73,7 @@ impl Assembler {
                 if start > 0 {
                     chunk.bytes.advance(start);
                     chunk.offset += start as u64;
+                    self.buffered -= start;
                 }
             }
 
@@ -75,61 +81,87 @@ impl Assembler {
                 self.bytes_read += max_length as u64;
                 let offset = chunk.offset;
                 chunk.offset += max_length as u64;
+                self.buffered -= max_length;
                 Chunk::new(offset, chunk.bytes.split_to(max_length))
             } else {
                 self.bytes_read += chunk.bytes.len() as u64;
-                self.defragmented = self.defragmented.saturating_sub(1);
+                self.buffered -= chunk.bytes.len();
+                self.allocated -= chunk.allocation_size;
                 let chunk = PeekMut::pop(chunk);
                 Chunk::new(chunk.offset, chunk.bytes)
             });
         }
     }
 
-    // Copy the buffered chunk data to new chunks backed by a single buffer to
-    // make sure we're not unnecessarily holding on to many larger allocations.
-    // Merge contiguous chunks in the process of doing so. Reset the `defragmented`
-    // counter to the new number of chunks left in the heap so that we can decide
-    // when to defragment the queue again if necessary.
+    /// Copy fragmented chunk data to new chunks backed by a single buffer
+    ///
+    /// This makes sure we're not unnecessarily holding on to many larger allocations.
+    /// We merge contiguous chunks in the process of doing so.
     fn defragment(&mut self) {
-        let buffered = self.data.iter().map(|c| c.bytes.len()).sum::<usize>();
-        let mut buffer = BytesMut::with_capacity(buffered);
-        let mut offset = self
-            .data
-            .peek()
-            .as_ref()
-            .expect("defragment is only called when data is buffered")
-            .offset;
-
         let new = BinaryHeap::with_capacity(self.data.len());
         let old = mem::replace(&mut self.data, new);
-        for chunk in old.into_sorted_vec().into_iter().rev() {
-            let end = offset + (buffer.len() as u64);
-            if let Some(overlap) = end.checked_sub(chunk.offset) {
-                if let Some(bytes) = chunk.bytes.get(overlap as usize..) {
-                    buffer.extend_from_slice(bytes);
+        let mut buffers = old.into_sorted_vec();
+        self.buffered = 0;
+        let mut fragmented_buffered = 0;
+        let mut offset = 0;
+        for chunk in buffers.iter_mut().rev() {
+            chunk.try_mark_defragment(offset);
+            let size = chunk.bytes.len();
+            offset = chunk.offset + size as u64;
+            self.buffered += size;
+            if !chunk.defragmented {
+                fragmented_buffered += size;
+            }
+        }
+        self.allocated = self.buffered;
+        let mut buffer = BytesMut::with_capacity(fragmented_buffered);
+        let mut offset = 0;
+        for chunk in buffers.into_iter().rev() {
+            if chunk.defragmented {
+                // bytes might be empty after try_mark_defragment
+                if !chunk.bytes.is_empty() {
+                    self.data.push(chunk);
+                }
+                continue;
+            }
+            // Overlap is resolved by try_mark_defragment
+            if chunk.offset != offset + (buffer.len() as u64) {
+                if !buffer.is_empty() {
+                    self.data
+                        .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
                 }
-            } else {
-                let bytes = buffer.split().freeze();
-                self.data.push(Buffer { offset, bytes });
                 offset = chunk.offset;
-                buffer.extend_from_slice(&chunk.bytes);
             }
+            buffer.extend_from_slice(&chunk.bytes);
+        }
+        if !buffer.is_empty() {
+            self.data
+                .push(Buffer::new_defragmented(offset, buffer.split().freeze()));
         }
-
-        let bytes = buffer.split().freeze();
-        self.data.push(Buffer { offset, bytes });
-        self.defragmented = self.data.len();
     }
 
-    pub(crate) fn insert(&mut self, mut offset: u64, mut bytes: Bytes) {
+    // Note: If a packet contains many frames from the same stream, the estimated over-allocation
+    // will be much higher because we are counting the same allocation multiple times.
+    pub(crate) fn insert(&mut self, mut offset: u64, mut bytes: Bytes, allocation_size: usize) {
+        debug_assert!(
+            bytes.len() <= allocation_size,
+            "allocation_size less than bytes.len(): {:?} < {:?}",
+            allocation_size,
+            bytes.len()
+        );
+        self.end = self.end.max(offset + bytes.len() as u64);
         if let State::Unordered { ref mut recvd } = self.state {
             // Discard duplicate data
             for duplicate in recvd.replace(offset..offset + bytes.len() as u64) {
                 if duplicate.start > offset {
-                    self.data.push(Buffer {
+                    let buffer = Buffer::new(
                         offset,
-                        bytes: bytes.split_to((duplicate.start - offset) as usize),
-                    });
+                        bytes.split_to((duplicate.start - offset) as usize),
+                        allocation_size,
+                    );
+                    self.buffered += buffer.bytes.len();
+                    self.allocated += buffer.allocation_size;
+                    self.data.push(buffer);
                     offset = duplicate.start;
                 }
                 bytes.advance((duplicate.end - offset) as usize);
@@ -148,16 +180,25 @@ impl Assembler {
         if bytes.is_empty() {
            return;
         }
-
-        self.data.push(Buffer { offset, bytes });
-        // Why 32: on the one hand, we want to defragment rarely, ideally never
+        let buffer = Buffer::new(offset, bytes, allocation_size);
+        self.buffered += buffer.bytes.len();
+        self.allocated += buffer.allocation_size;
+        self.data.push(buffer);
+        // `self.buffered` also counts duplicate bytes, therefore we use
+        // `self.end - self.bytes_read` as an upper bound of buffered unique
+        // bytes. This will cause a defragmentation if the amount of duplicate
+        // bytes exceeds a proportion of the receive window size.
+        let buffered = self.buffered.min((self.end - self.bytes_read) as usize);
+        let over_allocation = self.allocated - buffered;
+        // Rationale: on the one hand, we want to defragment rarely, ideally never
         // in non-pathological scenarios. However, a pathological or malicious
         // peer could send us one-byte frames, and since we use reference-counted
         // buffers in order to prevent copying, this could result in keeping a lot
-        // of memory allocated. In the worst case scenario of 32 1-byte chunks,
-        // each one from a ~1000-byte datagram, using 32 limits us to having a
-        // maximum pathological over-allocation of about 32k bytes.
-        if self.data.len() - self.defragmented > 32 {
+        // of memory allocated. This limits over-allocation in proportion to the
+        // buffered data. The constants are chosen somewhat arbitrarily and try to
+        // balance between defragmentation overhead and over-allocation.
+        let threshold = 32768.max(buffered * 3 / 2);
+        if over_allocation > threshold {
            self.defragment()
         }
     }
@@ -174,7 +215,8 @@ impl Assembler {
     /// Discard all buffered data
     pub(crate) fn clear(&mut self) {
         self.data.clear();
-        self.defragmented = 0;
+        self.buffered = 0;
+        self.allocated = 0;
     }
 }
 
@@ -197,6 +239,50 @@ impl Chunk {
 struct Buffer {
     offset: u64,
     bytes: Bytes,
+    allocation_size: usize,
+    defragmented: bool,
+}
+
+impl Buffer {
+    /// Constructs a new fragmented Buffer
+    fn new(offset: u64, bytes: Bytes, allocation_size: usize) -> Self {
+        Self {
+            offset,
+            bytes,
+            allocation_size,
+            defragmented: false,
+        }
+    }
+
+    /// Constructs a new defragmented Buffer
+    fn new_defragmented(offset: u64, bytes: Bytes) -> Self {
+        let allocation_size = bytes.len();
+        Self {
+            offset,
+            bytes,
+            allocation_size,
+            defragmented: true,
+        }
+    }
+
+    /// Discards data before `offset` and flags `self` as defragmented if it has good utilization
+    fn try_mark_defragment(&mut self, offset: u64) {
+        let duplicate = offset.saturating_sub(self.offset) as usize;
+        self.offset = self.offset.max(offset);
+        if duplicate >= self.bytes.len() {
+            self.bytes = Bytes::new();
+            self.defragmented = true;
+            self.allocation_size = 0;
+            return;
+        } else {
+            self.bytes.advance(duplicate);
+        }
+        self.defragmented = self.defragmented || self.bytes.len() * 6 / 5 >= self.allocation_size;
+        if self.defragmented {
+            // Make sure that defragmented buffers do not contribute to over-allocation
+            self.allocation_size = self.bytes.len();
+        }
+    }
 }
 
 impl Ord for Buffer {
@@ -257,13 +343,13 @@ mod test {
     fn assemble_ordered() {
         let mut x = Assembler::new();
         assert_matches!(next(&mut x, 32), None);
-        x.insert(0, Bytes::from_static(b"123"));
+        x.insert(0, Bytes::from_static(b"123"), 3);
         assert_matches!(next(&mut x, 1), Some(ref y) if &y[..] == b"1");
         assert_matches!(next(&mut x, 3), Some(ref y) if &y[..] == b"23");
-        x.insert(3, Bytes::from_static(b"456"));
+        x.insert(3, Bytes::from_static(b"456"), 3);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
-        x.insert(6, Bytes::from_static(b"789"));
-        x.insert(9, Bytes::from_static(b"10"));
+        x.insert(6, Bytes::from_static(b"789"), 3);
+        x.insert(9, Bytes::from_static(b"10"), 2);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"789");
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"10");
         assert_matches!(next(&mut x, 32), None);
@@ -273,9 +359,9 @@ mod test {
     fn assemble_unordered() {
         let mut x = Assembler::new();
         x.ensure_ordering(false).unwrap();
-        x.insert(3, Bytes::from_static(b"456"));
+        x.insert(3, Bytes::from_static(b"456"), 3);
         assert_matches!(next(&mut x, 32), None);
-        x.insert(0, Bytes::from_static(b"123"));
+        x.insert(0, Bytes::from_static(b"123"), 3);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"456");
         assert_matches!(next(&mut x, 32), None);
@@ -284,8 +370,8 @@ mod test {
     #[test]
     fn assemble_duplicate() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"123"));
-        x.insert(0, Bytes::from_static(b"123"));
+        x.insert(0, Bytes::from_static(b"123"), 3);
+        x.insert(0, Bytes::from_static(b"123"), 3);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -293,8 +379,8 @@ mod test {
     #[test]
     fn assemble_duplicate_compact() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"123"));
-        x.insert(0, Bytes::from_static(b"123"));
+        x.insert(0, Bytes::from_static(b"123"), 3);
+        x.insert(0, Bytes::from_static(b"123"), 3);
         x.defragment();
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -303,8 +389,8 @@ mod test {
     #[test]
     fn assemble_contained() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"12345"));
-        x.insert(1, Bytes::from_static(b"234"));
+        x.insert(0, Bytes::from_static(b"12345"), 5);
+        x.insert(1, Bytes::from_static(b"234"), 3);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -312,8 +398,8 @@ mod test {
     #[test]
     fn assemble_contained_compact() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"12345"));
-        x.insert(1, Bytes::from_static(b"234"));
+        x.insert(0, Bytes::from_static(b"12345"), 5);
+        x.insert(1, Bytes::from_static(b"234"), 3);
         x.defragment();
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -322,8 +408,8 @@ mod test {
     #[test]
     fn assemble_contains() {
         let mut x = Assembler::new();
-        x.insert(1, Bytes::from_static(b"234"));
-        x.insert(0, Bytes::from_static(b"12345"));
+        x.insert(1, Bytes::from_static(b"234"), 3);
+        x.insert(0, Bytes::from_static(b"12345"), 5);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -331,8 +417,8 @@ mod test {
     #[test]
     fn assemble_contains_compact() {
         let mut x = Assembler::new();
-        x.insert(1, Bytes::from_static(b"234"));
-        x.insert(0, Bytes::from_static(b"12345"));
+        x.insert(1, Bytes::from_static(b"234"), 3);
+        x.insert(0, Bytes::from_static(b"12345"), 5);
         x.defragment();
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"12345");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -341,8 +427,8 @@ mod test {
     #[test]
     fn assemble_overlapping() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"123"));
-        x.insert(1, Bytes::from_static(b"234"));
+        x.insert(0, Bytes::from_static(b"123"), 3);
+        x.insert(1, Bytes::from_static(b"234"), 3);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123");
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"4");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -351,8 +437,8 @@ mod test {
     #[test]
     fn assemble_overlapping_compact() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"123"));
-        x.insert(1, Bytes::from_static(b"234"));
+        x.insert(0, Bytes::from_static(b"123"), 4);
+        x.insert(1, Bytes::from_static(b"234"), 4);
         x.defragment();
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
         assert_matches!(next(&mut x, 32), None);
@@ -361,10 +447,10 @@ mod test {
     #[test]
     fn assemble_complex() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"1"));
-        x.insert(2, Bytes::from_static(b"3"));
-        x.insert(4, Bytes::from_static(b"5"));
-        x.insert(0, Bytes::from_static(b"123456"));
+        x.insert(0, Bytes::from_static(b"1"), 1);
+        x.insert(2, Bytes::from_static(b"3"), 1);
+        x.insert(4, Bytes::from_static(b"5"), 1);
+        x.insert(0, Bytes::from_static(b"123456"), 6);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -372,10 +458,10 @@ mod test {
     #[test]
     fn assemble_complex_compact() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"1"));
-        x.insert(2, Bytes::from_static(b"3"));
-        x.insert(4, Bytes::from_static(b"5"));
-        x.insert(0, Bytes::from_static(b"123456"));
+        x.insert(0, Bytes::from_static(b"1"), 1);
+        x.insert(2, Bytes::from_static(b"3"), 1);
+        x.insert(4, Bytes::from_static(b"5"), 1);
+        x.insert(0, Bytes::from_static(b"123456"), 6);
         x.defragment();
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"123456");
         assert_matches!(next(&mut x, 32), None);
     }
@@ -384,19 +470,19 @@ mod test {
     #[test]
     fn assemble_old() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"1234"));
+        x.insert(0, Bytes::from_static(b"1234"), 4);
         assert_matches!(next(&mut x, 32), Some(ref y) if &y[..] == b"1234");
-        x.insert(0, Bytes::from_static(b"1234"));
+        x.insert(0, Bytes::from_static(b"1234"), 4);
         assert_matches!(next(&mut x, 32), None);
     }
 
     #[test]
     fn compact() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"abc"));
-        x.insert(3, Bytes::from_static(b"def"));
-        x.insert(9, Bytes::from_static(b"jkl"));
-        x.insert(12, Bytes::from_static(b"mno"));
+        x.insert(0, Bytes::from_static(b"abc"), 4);
+        x.insert(3, Bytes::from_static(b"def"), 4);
+        x.insert(9, Bytes::from_static(b"jkl"), 4);
+        x.insert(12, Bytes::from_static(b"mno"), 4);
         x.defragment();
         assert_eq!(
             next_unordered(&mut x),
@@ -411,7 +497,7 @@ mod test {
     #[test]
     fn defrag_with_missing_prefix() {
         let mut x = Assembler::new();
-        x.insert(3, Bytes::from_static(b"def"));
+        x.insert(3, Bytes::from_static(b"def"), 3);
         x.defragment();
         assert_eq!(
             next_unordered(&mut x),
@@ -422,17 +508,17 @@ mod test {
     #[test]
     fn defrag_read_chunk() {
         let mut x = Assembler::new();
-        x.insert(3, Bytes::from_static(b"def"));
-        x.insert(0, Bytes::from_static(b"abc"));
-        x.insert(7, Bytes::from_static(b"hij"));
-        x.insert(11, Bytes::from_static(b"lmn"));
+        x.insert(3, Bytes::from_static(b"def"), 4);
+        x.insert(0, Bytes::from_static(b"abc"), 4);
+        x.insert(7, Bytes::from_static(b"hij"), 4);
+        x.insert(11, Bytes::from_static(b"lmn"), 4);
         x.defragment();
         assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"abcdef");
-        x.insert(5, Bytes::from_static(b"fghijklmn"));
+        x.insert(5, Bytes::from_static(b"fghijklmn"), 9);
         assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"ghijklmn");
-        x.insert(13, Bytes::from_static(b"nopq"));
+        x.insert(13, Bytes::from_static(b"nopq"), 4);
         assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"opq");
-        x.insert(15, Bytes::from_static(b"pqrs"));
+        x.insert(15, Bytes::from_static(b"pqrs"), 4);
         assert_matches!(x.read(usize::MAX, true), Some(ref y) if &y.bytes[..] == b"rs");
         assert_matches!(x.read(usize::MAX, true), None);
     }
@@ -441,13 +527,13 @@ mod test {
     fn unordered_happy_path() {
         let mut x = Assembler::new();
         x.ensure_ordering(false).unwrap();
-        x.insert(0, Bytes::from_static(b"abc"));
+        x.insert(0, Bytes::from_static(b"abc"), 3);
         assert_eq!(
             next_unordered(&mut x),
             Chunk::new(0, Bytes::from_static(b"abc"))
         );
         assert_eq!(x.read(usize::MAX, false), None);
-        x.insert(3, Bytes::from_static(b"def"));
+        x.insert(3, Bytes::from_static(b"def"), 3);
         assert_eq!(
             next_unordered(&mut x),
             Chunk::new(3, Bytes::from_static(b"def"))
@@ -459,15 +545,15 @@ mod test {
     fn unordered_dedup() {
         let mut x = Assembler::new();
         x.ensure_ordering(false).unwrap();
-        x.insert(3, Bytes::from_static(b"def"));
+        x.insert(3, Bytes::from_static(b"def"), 3);
         assert_eq!(
             next_unordered(&mut x),
             Chunk::new(3, Bytes::from_static(b"def"))
         );
         assert_eq!(x.read(usize::MAX, false), None);
-        x.insert(0, Bytes::from_static(b"a"));
-        x.insert(0, Bytes::from_static(b"abcdefghi"));
-        x.insert(0, Bytes::from_static(b"abcd"));
+        x.insert(0, Bytes::from_static(b"a"), 1);
+        x.insert(0, Bytes::from_static(b"abcdefghi"), 9);
+        x.insert(0, Bytes::from_static(b"abcd"), 4);
         assert_eq!(
             next_unordered(&mut x),
             Chunk::new(0, Bytes::from_static(b"a"))
@@ -481,30 +567,30 @@ mod test {
             Chunk::new(6, Bytes::from_static(b"ghi"))
         );
         assert_eq!(x.read(usize::MAX, false), None);
-        x.insert(8, Bytes::from_static(b"ijkl"));
+        x.insert(8, Bytes::from_static(b"ijkl"), 4);
         assert_eq!(
             next_unordered(&mut x),
             Chunk::new(9, Bytes::from_static(b"jkl"))
         );
         assert_eq!(x.read(usize::MAX, false), None);
-        x.insert(12, Bytes::from_static(b"mno"));
+        x.insert(12, Bytes::from_static(b"mno"), 3);
         assert_eq!(
             next_unordered(&mut x),
             Chunk::new(12, Bytes::from_static(b"mno"))
         );
         assert_eq!(x.read(usize::MAX, false), None);
-        x.insert(2, Bytes::from_static(b"cde"));
+        x.insert(2, Bytes::from_static(b"cde"), 3);
         assert_eq!(x.read(usize::MAX, false), None);
     }
 
     #[test]
     fn chunks_dedup() {
         let mut x = Assembler::new();
-        x.insert(3, Bytes::from_static(b"def"));
+        x.insert(3, Bytes::from_static(b"def"), 3);
         assert_eq!(x.read(usize::MAX, true), None);
-        x.insert(0, Bytes::from_static(b"a"));
-        x.insert(1, Bytes::from_static(b"bcdefghi"));
-        x.insert(0, Bytes::from_static(b"abcd"));
+        x.insert(0, Bytes::from_static(b"a"), 1);
+        x.insert(1, Bytes::from_static(b"bcdefghi"), 9);
+        x.insert(0, Bytes::from_static(b"abcd"), 4);
         assert_eq!(
             x.read(usize::MAX, true),
             Some(Chunk::new(0, Bytes::from_static(b"abcd")))
@@ -514,48 +600,45 @@ mod test {
             Some(Chunk::new(4, Bytes::from_static(b"efghi")))
         );
         assert_eq!(x.read(usize::MAX, true), None);
-        x.insert(8, Bytes::from_static(b"ijkl"));
+        x.insert(8, Bytes::from_static(b"ijkl"), 4);
         assert_eq!(
             x.read(usize::MAX, true),
             Some(Chunk::new(9, Bytes::from_static(b"jkl")))
         );
         assert_eq!(x.read(usize::MAX, true), None);
-        x.insert(12, Bytes::from_static(b"mno"));
+        x.insert(12, Bytes::from_static(b"mno"), 3);
         assert_eq!(
             x.read(usize::MAX, true),
             Some(Chunk::new(12, Bytes::from_static(b"mno")))
         );
         assert_eq!(x.read(usize::MAX, true), None);
-        x.insert(2, Bytes::from_static(b"cde"));
+        x.insert(2, Bytes::from_static(b"cde"), 3);
         assert_eq!(x.read(usize::MAX, true), None);
     }
 
     #[test]
     fn ordered_eager_discard() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"abc"));
+        x.insert(0, Bytes::from_static(b"abc"), 3);
         assert_eq!(x.data.len(), 1);
         assert_eq!(
             x.read(usize::MAX, true),
             Some(Chunk::new(0, Bytes::from_static(b"abc")))
         );
-        x.insert(0, Bytes::from_static(b"ab"));
+        x.insert(0, Bytes::from_static(b"ab"), 2);
         assert_eq!(x.data.len(), 0);
-        x.insert(2, Bytes::from_static(b"cd"));
+        x.insert(2, Bytes::from_static(b"cd"), 2);
         assert_eq!(
             x.data.peek(),
-            Some(&Buffer {
-                offset: 3,
-                bytes: Bytes::from_static(b"d")
-            })
+            Some(&Buffer::new(3, Bytes::from_static(b"d"), 2))
         );
     }
 
     #[test]
     fn ordered_insert_unordered_read() {
         let mut x = Assembler::new();
-        x.insert(0, Bytes::from_static(b"abc"));
-        x.insert(0, Bytes::from_static(b"abc"));
+        x.insert(0, Bytes::from_static(b"abc"), 3);
+        x.insert(0, Bytes::from_static(b"abc"), 3);
         x.ensure_ordering(false).unwrap();
         assert_eq!(
             x.read(3, false),
diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs
index 84c6fa783e..f3adac313a 100644
--- a/quinn-proto/src/connection/mod.rs
+++ b/quinn-proto/src/connection/mod.rs
@@ -1762,6 +1762,7 @@ where
         &mut self,
         space: SpaceId,
         crypto: &frame::Crypto,
+        payload_len: usize,
     ) -> Result<(), TransportError> {
         let expected = if !self.state.is_handshake() {
             SpaceId::Data
@@ -1796,7 +1797,7 @@ where
 
         space
             .crypto_stream
-            .insert(crypto.offset, crypto.data.clone());
+            .insert(crypto.offset, crypto.data.clone(), payload_len);
         while let Some(chunk) = space.crypto_stream.read(usize::MAX, true) {
             trace!("consumed {} CRYPTO bytes", chunk.bytes.len());
             if self.crypto.read_handshake(&chunk.bytes)? {
@@ -2331,6 +2332,7 @@ where
         packet: Packet,
     ) -> Result<(), TransportError> {
         debug_assert_ne!(packet.header.space(), SpaceId::Data);
+        let payload_len = packet.payload.len();
         for frame in frame::Iter::new(packet.payload.freeze()) {
             let span = match frame {
                 Frame::Padding => continue,
@@ -2351,7 +2353,7 @@ where
             match frame {
                 Frame::Padding | Frame::Ping => {}
                 Frame::Crypto(frame) => {
-                    self.read_crypto(packet.header.space(), &frame)?;
+                    self.read_crypto(packet.header.space(), &frame, payload_len)?;
                 }
                 Frame::Ack(ack) => {
                     self.on_ack_received(now, packet.header.space(), ack)?;
@@ -2389,6 +2391,7 @@ where
         let is_0rtt = self.spaces[SpaceId::Data].crypto.is_none();
         let mut is_probing_packet = true;
         let mut close = None;
+        let payload_len = payload.len();
         for frame in frame::Iter::new(payload) {
             let span = match frame {
                 Frame::Padding => continue,
@@ -2433,10 +2436,10 @@ where
                     return Err(err);
                 }
                 Frame::Crypto(frame) => {
-                    self.read_crypto(SpaceId::Data, &frame)?;
+                    self.read_crypto(SpaceId::Data, &frame, payload_len)?;
                 }
                 Frame::Stream(frame) => {
-                    if self.streams.received(frame)?.should_transmit() {
+                    if self.streams.received(frame, payload_len)?.should_transmit() {
                         self.spaces[SpaceId::Data].pending.max_data = true;
                     }
                 }
diff --git a/quinn-proto/src/connection/streams.rs b/quinn-proto/src/connection/streams.rs
index 4e6a8c577b..a2dfc33656 100644
--- a/quinn-proto/src/connection/streams.rs
+++ b/quinn-proto/src/connection/streams.rs
@@ -281,7 +281,11 @@ impl Streams {
     /// Process incoming stream frame
     ///
     /// If successful, returns whether a `MAX_DATA` frame needs to be transmitted
-    pub fn received(&mut self, frame: frame::Stream) -> Result<ShouldTransmit, TransportError> {
+    pub fn received(
+        &mut self,
+        frame: frame::Stream,
+        payload_len: usize,
+    ) -> Result<ShouldTransmit, TransportError> {
         trace!(id = %frame.id, offset = frame.offset, len = frame.data.len(), fin = frame.fin, "got stream");
         let stream = frame.id;
         self.validate_receive_id(stream).map_err(|e| {
@@ -302,7 +306,7 @@ impl Streams {
             return Ok(ShouldTransmit(false));
         }
 
-        let new_bytes = rs.ingest(frame, self.data_recvd, self.local_max_data)?;
+        let new_bytes = rs.ingest(frame, payload_len, self.data_recvd, self.local_max_data)?;
         self.data_recvd = self.data_recvd.saturating_add(new_bytes);
 
         if !rs.stopped {
@@ -1084,12 +1088,15 @@ mod tests {
         let initial_max = client.local_max_data;
         assert_eq!(
             client
-                .received(frame::Stream {
-                    id,
-                    offset: 0,
-                    fin: false,
-                    data: Bytes::from_static(&[0; 2048]),
-                })
+                .received(
+                    frame::Stream {
+                        id,
+                        offset: 0,
+                        fin: false,
+                        data: Bytes::from_static(&[0; 2048]),
+                    },
+                    2048
+                )
                 .unwrap(),
             ShouldTransmit(false)
         );
@@ -1118,12 +1125,15 @@ mod tests {
         let initial_max = client.local_max_data;
         assert_eq!(
             client
-                .received(frame::Stream {
-                    id,
-                    offset: 4096,
-                    fin: false,
-                    data: Bytes::from_static(&[0; 0]),
-                })
+                .received(
+                    frame::Stream {
+                        id,
+                        offset: 4096,
+                        fin: false,
+                        data: Bytes::from_static(&[0; 0]),
+                    },
+                    0
+                )
                 .unwrap(),
             ShouldTransmit(false)
         );
@@ -1178,12 +1188,15 @@ mod tests {
         let initial_max = client.local_max_data;
         assert_eq!(
             client
-                .received(frame::Stream {
-                    id,
-                    offset: 0,
-                    fin: false,
-                    data: Bytes::from_static(&[0; 32]),
-                })
+                .received(
+                    frame::Stream {
+                        id,
+                        offset: 0,
+                        fin: false,
+                        data: Bytes::from_static(&[0; 32]),
+                    },
+                    32
+                )
                 .unwrap(),
             ShouldTransmit(false)
         );
@@ -1204,12 +1217,15 @@ mod tests {
         assert_eq!(client.local_max_data - initial_max, 32);
         assert_eq!(
             client
-                .received(frame::Stream {
-                    id,
-                    offset: 32,
-                    fin: true,
-                    data: Bytes::from_static(&[0; 16]),
-                })
+                .received(
+                    frame::Stream {
+                        id,
+                        offset: 32,
+                        fin: true,
+                        data: Bytes::from_static(&[0; 16]),
+                    },
+                    16
+                )
                 .unwrap(),
             ShouldTransmit(false)
        );
@@ -1224,12 +1240,15 @@ mod tests {
         // Server opens stream
         assert_eq!(
             client
-                .received(frame::Stream {
-                    id,
-                    offset: 0,
-                    fin: false,
-                    data: Bytes::from_static(&[0; 32]),
-                })
+                .received(
+                    frame::Stream {
+                        id,
+                        offset: 0,
+                        fin: false,
+                        data: Bytes::from_static(&[0; 32])
+                    },
+                    32
+                )
                 .unwrap(),
             ShouldTransmit(false)
         );
diff --git a/quinn-proto/src/connection/streams/recv.rs b/quinn-proto/src/connection/streams/recv.rs
index 7c305e7c4a..bcbeeec76a 100644
--- a/quinn-proto/src/connection/streams/recv.rs
+++ b/quinn-proto/src/connection/streams/recv.rs
@@ -29,6 +29,7 @@ impl Recv {
     pub(super) fn ingest(
         &mut self,
         frame: frame::Stream,
+        payload_len: usize,
         received: u64,
         max_data: u64,
     ) -> Result<u64, TransportError> {
@@ -60,7 +61,7 @@ impl Recv {
         self.end = self.end.max(end);
 
         if !self.stopped {
-            self.assembler.insert(frame.offset, frame.data);
+            self.assembler.insert(frame.offset, frame.data, payload_len);
         } else {
             self.assembler.set_bytes_read(end);
         }
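
Note on the new defragmentation trigger (illustrative commentary, not part of the patch): `insert` now compares estimated over-allocation against a size-based threshold instead of counting fragments. A minimal standalone sketch of that predicate, with the patch's `buffered`, `allocated`, `end`, and `bytes_read` fields passed in as plain arguments and the function name `should_defragment` invented here for illustration:

    fn should_defragment(buffered: usize, allocated: usize, end: u64, bytes_read: u64) -> bool {
        // `buffered` counts duplicate bytes in ordered mode, so clamp it to the
        // number of distinct stream bytes currently outstanding.
        let unique = buffered.min((end - bytes_read) as usize);
        // Bytes kept alive by shared allocations but not usable as stream data.
        let over_allocation = allocated - unique;
        // Defragment once the waste exceeds 32 KiB or 1.5x the unique buffered
        // data, whichever is larger, mirroring `insert` above.
        over_allocation > 32768.max(unique * 3 / 2)
    }

For example, 40 one-byte STREAM frames that each retain a roughly 1200-byte datagram allocation give `unique = 40` and `allocated` near 48,000, so `over_allocation` (about 47,960) exceeds the 32,768 floor and the assembler defragments; the same 40 bytes arriving as one contiguous chunk never trip the threshold.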