From 463448e17702b43d6a3633b8f10c4ad21d4f1a66 Mon Sep 17 00:00:00 2001 From: Cameron Bytheway Date: Thu, 8 Jul 2021 11:34:58 -0700 Subject: [PATCH] platform: add GSO support --- quic/s2n-quic-core/src/io/tx.rs | 7 + quic/s2n-quic-platform/src/buffer.rs | 8 +- quic/s2n-quic-platform/src/buffer/segment.rs | 59 ------ quic/s2n-quic-platform/src/buffer/vec.rs | 34 ++-- quic/s2n-quic-platform/src/features.rs | 20 ++ quic/s2n-quic-platform/src/features/gso.rs | 83 ++++++++ quic/s2n-quic-platform/src/io/tokio.rs | 7 +- quic/s2n-quic-platform/src/lib.rs | 1 + quic/s2n-quic-platform/src/message.rs | 23 ++- quic/s2n-quic-platform/src/message/macros.rs | 16 +- quic/s2n-quic-platform/src/message/mmsg.rs | 93 +++++---- quic/s2n-quic-platform/src/message/msg.rs | 178 +++++++++++++++--- quic/s2n-quic-platform/src/message/queue.rs | 21 ++- .../src/message/queue/behavior.rs | 6 +- .../src/message/queue/slice.rs | 162 ++++++++++++++-- quic/s2n-quic-platform/src/message/simple.rs | 67 +++++-- quic/s2n-quic-platform/src/socket/mmsg.rs | 32 +++- quic/s2n-quic-platform/src/socket/msg.rs | 35 +++- quic/s2n-quic-platform/src/socket/std.rs | 40 ++-- .../src/connection/close_sender.rs | 10 + .../src/connection/transmission.rs | 9 + quic/s2n-quic-transport/src/endpoint/retry.rs | 10 + .../src/endpoint/stateless_reset.rs | 10 + .../src/endpoint/version.rs | 10 + .../src/sync/data_sender.rs | 1 + .../src/sync/data_sender/writer.rs | 2 + .../src/transmission/context.rs | 24 +++ 27 files changed, 771 insertions(+), 197 deletions(-) delete mode 100644 quic/s2n-quic-platform/src/buffer/segment.rs create mode 100644 quic/s2n-quic-platform/src/features.rs create mode 100644 quic/s2n-quic-platform/src/features/gso.rs diff --git a/quic/s2n-quic-core/src/io/tx.rs b/quic/s2n-quic-core/src/io/tx.rs index c1914ad354..4a274eb071 100644 --- a/quic/s2n-quic-core/src/io/tx.rs +++ b/quic/s2n-quic-core/src/io/tx.rs @@ -85,6 +85,9 @@ pub trait Message { /// Returns the IPv6 flow label for the message fn 
ipv6_flow_label(&mut self) -> u32; + /// Returns true if the packet can be used in a GSO packet + fn can_gso(&self) -> bool; + /// Writes the payload of the message to an output buffer fn write_payload(&mut self, buffer: &mut [u8]) -> usize; } @@ -106,6 +109,10 @@ impl> Message for (SocketAddress, Payload) { 0 } + fn can_gso(&self) -> bool { + true + } + fn write_payload(&mut self, buffer: &mut [u8]) -> usize { let payload = self.1.as_ref(); let len = payload.len(); diff --git a/quic/s2n-quic-platform/src/buffer.rs b/quic/s2n-quic-platform/src/buffer.rs index 8837c43ac4..5cb0a6f1c1 100644 --- a/quic/s2n-quic-platform/src/buffer.rs +++ b/quic/s2n-quic-platform/src/buffer.rs @@ -1,16 +1,14 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -mod segment; +use core::ops::{Deref, DerefMut}; + mod vec; // TODO support mmap buffers -pub use segment::*; pub use vec::*; -use core::ops::{Index, IndexMut}; - -pub trait Buffer: Index + IndexMut { +pub trait Buffer: Deref + DerefMut { fn len(&self) -> usize; fn is_empty(&self) -> bool { diff --git a/quic/s2n-quic-platform/src/buffer/segment.rs b/quic/s2n-quic-platform/src/buffer/segment.rs deleted file mode 100644 index d2315ff6f2..0000000000 --- a/quic/s2n-quic-platform/src/buffer/segment.rs +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-// SPDX-License-Identifier: Apache-2.0 - -use crate::buffer::Buffer; -use core::ops::{Deref, DerefMut, Index, IndexMut, Range}; -use s2n_quic_core::path::MINIMUM_MTU; - -#[derive(Debug)] -pub struct SegmentBuffer { - region: Region, - mtu: usize, -} - -impl SegmentBuffer { - pub fn new(region: Region, mtu: usize) -> Self { - assert!( - mtu >= (MINIMUM_MTU as usize), - "MTU must be at least {} for spec compatibility", - MINIMUM_MTU - ); - Self { region, mtu } - } - - const fn byte_range(&self, index: usize) -> Range { - let start = index * self.mtu; - let end = start + self.mtu; - start..end - } -} - -impl + DerefMut> Buffer for SegmentBuffer { - fn len(&self) -> usize { - self.region.len() / self.mtu - } - - fn is_empty(&self) -> bool { - self.region.is_empty() - } - - fn mtu(&self) -> usize { - self.mtu - } -} - -impl> Index for SegmentBuffer { - type Output = [u8]; - - fn index(&self, index: usize) -> &Self::Output { - let range = self.byte_range(index); - &self.region[range] - } -} - -impl> IndexMut for SegmentBuffer { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - let range = self.byte_range(index); - &mut self.region[range] - } -} diff --git a/quic/s2n-quic-platform/src/buffer/vec.rs b/quic/s2n-quic-platform/src/buffer/vec.rs index 2b06730e9d..3e275fcbc3 100644 --- a/quic/s2n-quic-platform/src/buffer/vec.rs +++ b/quic/s2n-quic-platform/src/buffer/vec.rs @@ -1,24 +1,28 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::buffer::{Buffer, SegmentBuffer}; +use crate::buffer::Buffer; use core::{ fmt, - ops::{Index, IndexMut}, + ops::{Deref, DerefMut}, }; use s2n_quic_core::path::DEFAULT_MAX_MTU; // TODO decide on better defaults -const DEFAULT_MESSAGE_COUNT: usize = 4096; +const DEFAULT_MESSAGE_COUNT: usize = 1024; -pub struct VecBuffer(SegmentBuffer>); +pub struct VecBuffer { + region: alloc::vec::Vec, + mtu: usize, +} impl VecBuffer { /// Create a contiguous buffer with the specified number of messages pub fn new(message_count: usize, mtu: usize) -> Self { let len = message_count * mtu; - let vec = alloc::vec![0; len]; - Self(SegmentBuffer::new(vec, mtu)) + let region = alloc::vec![0; len]; + + Self { region, mtu } } } @@ -45,24 +49,24 @@ impl fmt::Debug for VecBuffer { impl Buffer for VecBuffer { fn len(&self) -> usize { - self.0.len() + self.region.len() } fn mtu(&self) -> usize { - self.0.mtu() + self.mtu } } -impl Index for VecBuffer { - type Output = [u8]; +impl Deref for VecBuffer { + type Target = [u8]; - fn index(&self, index: usize) -> &Self::Output { - self.0.index(index) + fn deref(&self) -> &[u8] { + self.region.as_ref() } } -impl IndexMut for VecBuffer { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - self.0.index_mut(index) +impl DerefMut for VecBuffer { + fn deref_mut(&mut self) -> &mut [u8] { + self.region.as_mut() } } diff --git a/quic/s2n-quic-platform/src/features.rs b/quic/s2n-quic-platform/src/features.rs new file mode 100644 index 0000000000..7acd46894e --- /dev/null +++ b/quic/s2n-quic-platform/src/features.rs @@ -0,0 +1,20 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use lazy_static::lazy_static; + +mod gso; +pub use gso::Gso; + +lazy_static! 
{ + static ref FEATURES: Features = Features::default(); +} + +pub fn get() -> &'static Features { + &*FEATURES +} + +#[derive(Debug, Default)] +pub struct Features { + pub gso: Gso, +} diff --git a/quic/s2n-quic-platform/src/features/gso.rs b/quic/s2n-quic-platform/src/features/gso.rs new file mode 100644 index 0000000000..5553863fee --- /dev/null +++ b/quic/s2n-quic-platform/src/features/gso.rs @@ -0,0 +1,83 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use core::num::NonZeroUsize; + +#[derive(Debug)] +pub struct Gso { + max_segments: NonZeroUsize, +} + +impl Default for Gso { + fn default() -> Self { + let max_segments = if cfg!(target_os = "linux") { + // https://github.com/torvalds/linux/blob/e9f1cbc0c4114880090c7a578117d3b9cf184ad4/tools/testing/selftests/net/udpgso.c#L37 + // ``` + // #define UDP_MAX_SEGMENTS (1 << 6UL) + // ``` + 1 << 6 + } else { + 1 + }; + + let max_segments = NonZeroUsize::new(max_segments).unwrap(); + + Self { max_segments } + } +} + +#[cfg(target_os = "linux")] +impl Gso { + pub fn set(cmsg: &mut [u8], segment_size: usize) -> usize { + use core::mem::size_of; + type SegmentType = u16; + + let len = unsafe { libc::CMSG_SPACE(size_of::() as _) } as usize; + debug_assert_ne!(len, 0); + assert!( + cmsg.len() >= len, + "out of space in cmsg: needed {}, got {}", + len, + cmsg.len() + ); + + let cmsg = unsafe { &mut *(&mut cmsg[0] as *mut u8 as *mut libc::cmsghdr) }; + cmsg.cmsg_level = libc::SOL_UDP; + cmsg.cmsg_type = libc::UDP_SEGMENT; + cmsg.cmsg_len = unsafe { libc::CMSG_LEN(size_of::() as _) } as _; + unsafe { + core::ptr::write( + libc::CMSG_DATA(cmsg) as *const _ as *mut _, + segment_size as SegmentType, + ); + } + + len + } + + pub fn max_segments(&self) -> usize { + self.max_segments.get() + } + + pub fn default_max_segments(&self) -> usize { + // TODO profile a good default + const DEFAULT_MAX_SEGMENTS: usize = 16; + + 
self.max_segments().min(DEFAULT_MAX_SEGMENTS) + } +} + +#[cfg(not(target_os = "linux"))] +impl Gso { + pub fn set(_cmsg: &mut [u8], _segment_size: usize) -> usize { + panic!("cannot use GSO on the current platform") + } + + pub fn max_segments(&self) -> usize { + 1 + } + + pub fn default_max_segments(&self) -> usize { + 1 + } +} diff --git a/quic/s2n-quic-platform/src/io/tokio.rs b/quic/s2n-quic-platform/src/io/tokio.rs index fac8e07f14..ff6d1df7fc 100644 --- a/quic/s2n-quic-platform/src/io/tokio.rs +++ b/quic/s2n-quic-platform/src/io/tokio.rs @@ -664,8 +664,11 @@ mod tests { let len = entries.len(); for entry in entries { let payload: &[u8] = entry.payload_mut(); - let payload = payload.try_into().unwrap(); - let id = u32::from_be_bytes(payload); + if payload.len() != 4 { + panic!("invalid payload {:?}", payload); + } + let id = payload.try_into().unwrap(); + let id = u32::from_be_bytes(id); self.messages.remove(&id); } queue.finish(len); diff --git a/quic/s2n-quic-platform/src/lib.rs b/quic/s2n-quic-platform/src/lib.rs index 1319659af5..b24687f4c5 100644 --- a/quic/s2n-quic-platform/src/lib.rs +++ b/quic/s2n-quic-platform/src/lib.rs @@ -12,6 +12,7 @@ extern crate alloc; mod macros; pub mod buffer; +pub mod features; pub mod io; pub mod message; pub mod socket; diff --git a/quic/s2n-quic-platform/src/message.rs b/quic/s2n-quic-platform/src/message.rs index 11ae8d80a9..324fc40191 100644 --- a/quic/s2n-quic-platform/src/message.rs +++ b/quic/s2n-quic-platform/src/message.rs @@ -18,6 +18,8 @@ use s2n_quic_core::inet::{ExplicitCongestionNotification, SocketAddress}; /// An abstract message that can be sent and received on a network pub trait Message { + const SUPPORTS_GSO: bool; + /// Returns the ECN values for the message fn ecn(&self) -> ExplicitCongestionNotification; @@ -30,9 +32,6 @@ pub trait Message { /// Sets the `SocketAddress` for the message fn set_remote_address(&mut self, remote_address: &SocketAddress); - /// Resets the `SocketAddress` for the message - 
fn reset_remote_address(&mut self); - /// Returns the length of the payload fn payload_len(&self) -> usize; @@ -65,6 +64,15 @@ pub trait Message { unsafe { core::slice::from_raw_parts_mut(self.payload_ptr_mut(), self.payload_len()) } } + /// Sets the segment size for the message payload + fn set_segment_size(&mut self, size: usize); + + /// Resets the message for future use + /// + /// # Safety + /// This method should only set the MTU to the original value + unsafe fn reset(&mut self, mtu: usize); + /// Returns a pointer to the Message fn as_ptr(&self) -> *const c_void { self as *const _ as *const _ @@ -96,6 +104,15 @@ pub trait Ring { /// Returns the maximum transmission unit for the ring fn mtu(&self) -> usize; + /// Returns the maximum number of GSO segments that can be used + fn max_gso(&self) -> usize; + + /// Disables the ability for the ring to send GSO messages + /// + /// This will be called in case the runtime encounters an IO error and will + /// try again with GSO disabled. + fn disable_gso(&mut self); + /// Returns all of the messages in the ring /// /// The first half of the slice should be duplicated into the second half diff --git a/quic/s2n-quic-platform/src/message/macros.rs b/quic/s2n-quic-platform/src/message/macros.rs index 868b21e2a8..a83eb2274a 100644 --- a/quic/s2n-quic-platform/src/message/macros.rs +++ b/quic/s2n-quic-platform/src/message/macros.rs @@ -4,8 +4,10 @@ #![allow(unused_macros)] macro_rules! impl_message_delegate { - ($name:ident, $field:tt) => { + ($name:ident, $field:tt, $field_ty:ty) => { impl $crate::message::Message for $name { + const SUPPORTS_GSO: bool = <$field_ty as $crate::message::Message>::SUPPORTS_GSO; + fn ecn(&self) -> ExplicitCongestionNotification { $crate::message::Message::ecn(&self.$field) } @@ -22,10 +24,6 @@ macro_rules! 
impl_message_delegate { $crate::message::Message::set_remote_address(&mut self.$field, remote_address) } - fn reset_remote_address(&mut self) { - $crate::message::Message::reset_remote_address(&mut self.$field) - } - fn payload_len(&self) -> usize { $crate::message::Message::payload_len(&self.$field) } @@ -34,6 +32,14 @@ macro_rules! impl_message_delegate { $crate::message::Message::set_payload_len(&mut self.$field, payload_len) } + fn set_segment_size(&mut self, size: usize) { + $crate::message::Message::set_segment_size(&mut self.$field, size) + } + + unsafe fn reset(&mut self, mtu: usize) { + $crate::message::Message::reset(&mut self.$field, mtu) + } + fn replicate_fields_from(&mut self, other: &Self) { $crate::message::Message::replicate_fields_from(&mut self.$field, &other.$field) } diff --git a/quic/s2n-quic-platform/src/message/mmsg.rs b/quic/s2n-quic-platform/src/message/mmsg.rs index 7af397d4aa..0ca15e0cc0 100644 --- a/quic/s2n-quic-platform/src/message/mmsg.rs +++ b/quic/s2n-quic-platform/src/message/mmsg.rs @@ -1,10 +1,13 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
// SPDX-License-Identifier: Apache-2.0 -use crate::message::{msg::Ring as MsgRing, Message as MessageTrait}; +use crate::message::{ + msg::{self, Ring as MsgRing}, + Message as MessageTrait, +}; use alloc::vec::Vec; use core::{fmt, mem::zeroed}; -use libc::{iovec, mmsghdr, sockaddr_in6}; +use libc::mmsghdr; use s2n_quic_core::{ inet::{ExplicitCongestionNotification, SocketAddress}, io::{rx, tx}, @@ -13,7 +16,7 @@ use s2n_quic_core::{ #[repr(transparent)] pub struct Message(pub(crate) mmsghdr); -impl_message_delegate!(Message, 0); +impl_message_delegate!(Message, 0, mmsghdr); impl fmt::Debug for Message { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -26,44 +29,62 @@ impl fmt::Debug for Message { } impl MessageTrait for mmsghdr { + const SUPPORTS_GSO: bool = true; + + #[inline] fn ecn(&self) -> ExplicitCongestionNotification { self.msg_hdr.ecn() } + #[inline] fn set_ecn(&mut self, ecn: ExplicitCongestionNotification) { self.msg_hdr.set_ecn(ecn) } + #[inline] fn remote_address(&self) -> Option { self.msg_hdr.remote_address() } + #[inline] fn set_remote_address(&mut self, remote_address: &SocketAddress) { self.msg_hdr.set_remote_address(remote_address) } - fn reset_remote_address(&mut self) { - self.msg_hdr.reset_remote_address() - } - + #[inline] fn payload_len(&self) -> usize { self.msg_len as usize } + #[inline] unsafe fn set_payload_len(&mut self, len: usize) { debug_assert!(len <= core::u32::MAX as usize); self.msg_len = len as _; self.msg_hdr.set_payload_len(len); } + #[inline] + fn set_segment_size(&mut self, size: usize) { + self.msg_hdr.set_segment_size(size) + } + + #[inline] + unsafe fn reset(&mut self, mtu: usize) { + self.set_payload_len(mtu); + self.msg_hdr.reset(mtu) + } + + #[inline] fn payload_ptr(&self) -> *const u8 { self.msg_hdr.payload_ptr() } + #[inline] fn payload_ptr_mut(&mut self) -> *mut u8 { self.msg_hdr.payload_ptr_mut() } + #[inline] fn replicate_fields_from(&mut self, other: &Self) { self.msg_len = other.msg_len; 
self.msg_hdr.replicate_fields_from(&other.msg_hdr) @@ -72,18 +93,7 @@ impl MessageTrait for mmsghdr { pub struct Ring { messages: Vec, - - // this field holds references to allocated payloads, but is never read directly - #[allow(dead_code)] - payloads: Payloads, - - // this field holds references to allocated iovecs, but is never read directly - #[allow(dead_code)] - iovecs: Vec, - - // this field holds references to allocated msg_names, but is never read directly - #[allow(dead_code)] - msg_names: Vec, + storage: msg::Storage, } /// Even though `Ring` contains raw pointers, it owns all of the data @@ -92,18 +102,19 @@ unsafe impl Send for Ring {} impl Default for Ring { fn default() -> Self { - Self::new(Payloads::default()) + Self::new( + Payloads::default(), + crate::features::get().gso.default_max_segments(), + ) } } impl Ring { - pub fn new(payloads: Payloads) -> Self { + pub fn new(payloads: Payloads, max_gso: usize) -> Self { let MsgRing { mut messages, - payloads, - iovecs, - msg_names, - } = MsgRing::new(payloads); + storage, + } = MsgRing::new(payloads, max_gso); // convert msghdr into mmsghdr let messages = messages @@ -117,30 +128,40 @@ impl Ring { }) .collect(); - Self { - messages, - payloads, - iovecs, - msg_names, - } + Self { messages, storage } } } impl super::Ring for Ring { type Message = Message; + #[inline] fn len(&self) -> usize { - self.payloads.len() + self.messages.len() / 2 } + #[inline] fn mtu(&self) -> usize { - self.payloads.mtu() + self.storage.mtu() + } + + #[inline] + fn max_gso(&self) -> usize { + self.storage.max_gso() + } + + #[inline] + fn disable_gso(&mut self) { + // TODO recompute message offsets + self.storage.disable_gso() } + #[inline] fn as_slice(&self) -> &[Self::Message] { &self.messages[..] } + #[inline] fn as_mut_slice(&mut self) -> &mut [Self::Message] { &mut self.messages[..] 
} @@ -169,28 +190,34 @@ impl tx::Entry for Message { Ok(len) } + #[inline] fn payload(&self) -> &[u8] { MessageTrait::payload(self) } + #[inline] fn payload_mut(&mut self) -> &mut [u8] { MessageTrait::payload_mut(self) } } impl rx::Entry for Message { + #[inline] fn remote_address(&self) -> Option { MessageTrait::remote_address(self) } + #[inline] fn ecn(&self) -> ExplicitCongestionNotification { MessageTrait::ecn(self) } + #[inline] fn payload(&self) -> &[u8] { MessageTrait::payload(self) } + #[inline] fn payload_mut(&mut self) -> &mut [u8] { MessageTrait::payload_mut(self) } diff --git a/quic/s2n-quic-platform/src/message/msg.rs b/quic/s2n-quic-platform/src/message/msg.rs index 032cbcd92b..79e3aafc2b 100644 --- a/quic/s2n-quic-platform/src/message/msg.rs +++ b/quic/s2n-quic-platform/src/message/msg.rs @@ -1,11 +1,12 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::message::Message as MessageTrait; +use crate::{features, message::Message as MessageTrait}; use alloc::vec::Vec; use core::{ fmt, mem::{size_of, zeroed}, + pin::Pin, }; use libc::{c_void, iovec, msghdr, sockaddr_in, sockaddr_in6, AF_INET, AF_INET6}; use s2n_quic_core::{ @@ -19,7 +20,7 @@ use s2n_quic_core::inet::{IpV6Address, SocketAddressV6}; #[repr(transparent)] pub struct Message(pub(crate) msghdr); -impl_message_delegate!(Message, 0); +impl_message_delegate!(Message, 0, msghdr); impl fmt::Debug for Message { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -31,6 +32,12 @@ impl fmt::Debug for Message { } } +/// This is the maximum number of bytes allocated for cmsg data +/// +/// This should be enough for UDP_SEGMENT + IP_TOS + IP_PKTINFO. It may need to be increased +/// to allow for future control messages. 
+const MAX_CMSG_LEN: usize = 128; + impl Message { fn new( iovec: *mut iovec, @@ -55,11 +62,15 @@ impl Message { } impl MessageTrait for msghdr { + const SUPPORTS_GSO: bool = true; + + #[inline] fn ecn(&self) -> ExplicitCongestionNotification { // TODO support ecn ExplicitCongestionNotification::default() } + #[inline] fn set_ecn(&mut self, _ecn: ExplicitCongestionNotification) { // TODO support ecn } @@ -109,30 +120,72 @@ impl MessageTrait for msghdr { } } - fn reset_remote_address(&mut self) { - self.msg_namelen = size_of::() as _; - } - + #[inline] fn payload_len(&self) -> usize { debug_assert!(!self.msg_iov.is_null()); unsafe { (*self.msg_iov).iov_len } } + #[inline] unsafe fn set_payload_len(&mut self, payload_len: usize) { debug_assert!(!self.msg_iov.is_null()); (*self.msg_iov).iov_len = payload_len; } + #[inline] + fn set_segment_size(&mut self, size: usize) { + let cmsg = + unsafe { core::slice::from_raw_parts_mut(self.msg_control as *mut u8, MAX_CMSG_LEN) }; + let remaining = &mut cmsg[(self.msg_controllen as usize)..]; + let len = features::Gso::set(remaining, size); + // add the values as a usize to make sure we work cross-platform + self.msg_controllen = (len + self.msg_controllen as usize) as _; + } + + #[inline] + unsafe fn reset(&mut self, mtu: usize) { + // reset the payload + self.set_payload_len(mtu); + + // reset the address + self.msg_namelen = size_of::() as _; + + if cfg!(debug_assertions) && self.msg_controllen == 0 { + // make sure nothing was written to the control message if it was set to 0 + assert!( + core::slice::from_raw_parts_mut(self.msg_control as *mut u8, MAX_CMSG_LEN) + .iter() + .all(|v| *v == 0) + ) + } + + // reset the control messages + let cmsg = + core::slice::from_raw_parts_mut(self.msg_control as *mut u8, self.msg_controllen as _); + + for byte in cmsg.iter_mut() { + *byte = 0; + } + self.msg_controllen = 0; + } + + #[inline] fn replicate_fields_from(&mut self, other: &Self) { debug_assert_eq!( self.msg_name, 
other.msg_name, "msg_name needs to point to the same data" ); + debug_assert_eq!( + self.msg_control, other.msg_control, + "msg_control needs to point to the same data" + ); debug_assert_eq!(self.msg_iov, other.msg_iov); debug_assert_eq!(self.msg_iovlen, other.msg_iovlen); self.msg_namelen = other.msg_namelen; + self.msg_controllen = other.msg_controllen; } + #[inline] fn payload_ptr(&self) -> *const u8 { unsafe { let iovec = &*self.msg_iov; @@ -140,6 +193,7 @@ impl MessageTrait for msghdr { } } + #[inline] fn payload_ptr_mut(&mut self) -> *mut u8 { unsafe { let iovec = &mut *self.msg_iov; @@ -150,18 +204,49 @@ impl MessageTrait for msghdr { pub struct Ring { pub(crate) messages: Vec, + pub(crate) storage: Storage, +} +pub struct Storage { // this field holds references to allocated payloads, but is never read directly #[allow(dead_code)] - pub(crate) payloads: Payloads, + pub(crate) payloads: Pin, // this field holds references to allocated iovecs, but is never read directly #[allow(dead_code)] - pub(crate) iovecs: Vec, + pub(crate) iovecs: Pin>, + + // this field holds references to allocated msg_names, but is never read directly + #[allow(dead_code)] + pub(crate) msg_names: Pin>, // this field holds references to allocated msg_names, but is never read directly #[allow(dead_code)] - pub(crate) msg_names: Vec, + pub(crate) cmsgs: Pin>, + + /// The maximum payload for any given message + mtu: usize, + + /// The maximum number of segments that can be offloaded in a single message + max_gso: usize, +} + +impl Storage { + #[inline] + pub fn mtu(&self) -> usize { + self.mtu + } + + #[inline] + pub fn max_gso(&self) -> usize { + self.max_gso + } + + #[inline] + pub fn disable_gso(&mut self) { + // TODO recompute message offsets + self.max_gso = 1; + } } /// Even though `Ring` contains raw pointers, it owns all of the data @@ -170,34 +255,47 @@ unsafe impl Send for Ring {} impl Default for Ring { fn default() -> Self { - Self::new(Payloads::default()) + Self::new( + 
Payloads::default(), + crate::features::get().gso.default_max_segments(), + ) } } impl Ring { - pub fn new(mut payloads: Payloads) -> Self { - let capacity = payloads.len(); + pub fn new(payloads: Payloads, max_gso: usize) -> Self { + assert!(max_gso <= crate::features::get().gso.max_segments()); + let mtu = payloads.mtu(); + let capacity = payloads.len() / mtu / max_gso; - let mut iovecs = Vec::with_capacity(capacity); - let mut msg_names = Vec::with_capacity(capacity); + let mut payloads = Pin::new(payloads); + let mut iovecs = Pin::new(vec![unsafe { zeroed() }; capacity].into_boxed_slice()); + let mut msg_names = Pin::new(vec![unsafe { zeroed() }; capacity].into_boxed_slice()); + let mut cmsgs = Pin::new(vec![0u8; capacity * MAX_CMSG_LEN].into_boxed_slice()); // double message capacity to enable contiguous access let mut messages = Vec::with_capacity(capacity * 2); + let mut payload_buf = &mut payloads.as_mut()[..]; + let mut cmsg_buf = &mut cmsgs.as_mut()[..]; + for index in 0..capacity { + let (payload, remaining) = payload_buf.split_at_mut(mtu * max_gso); + payload_buf = remaining; + let (cmsg, remaining) = cmsg_buf.split_at_mut(MAX_CMSG_LEN); + cmsg_buf = remaining; + let mut iovec = unsafe { zeroed::() }; - iovec.iov_base = payloads[index].as_mut_ptr() as _; + iovec.iov_base = payload.as_mut_ptr() as _; iovec.iov_len = mtu; - iovecs.push(iovec); - - msg_names.push(unsafe { zeroed() }); + iovecs[index] = iovec; let msg = Message::new( (&mut iovecs[index]) as *mut _, (&mut msg_names[index]) as *mut _ as *mut _, size_of::(), - core::ptr::null_mut(), + cmsg as *mut _ as *mut _, 0, ); @@ -210,9 +308,14 @@ impl Ring { Self { messages, - payloads, - iovecs, - msg_names, + storage: Storage { + payloads, + iovecs, + msg_names, + cmsgs, + mtu, + max_gso, + }, } } } @@ -220,18 +323,32 @@ impl Ring { impl super::Ring for Ring { type Message = Message; + #[inline] fn len(&self) -> usize { - self.payloads.len() + self.messages.len() / 2 } + #[inline] fn mtu(&self) -> 
usize { - self.payloads.mtu() + self.storage.mtu() } + #[inline] + fn max_gso(&self) -> usize { + self.storage.max_gso() + } + + fn disable_gso(&mut self) { + // TODO recompute message offsets + self.storage.disable_gso() + } + + #[inline] fn as_slice(&self) -> &[Self::Message] { &self.messages[..] } + #[inline] fn as_mut_slice(&mut self) -> &mut [Self::Message] { &mut self.messages[..] } @@ -260,28 +377,34 @@ impl tx::Entry for Message { Ok(len) } + #[inline] fn payload(&self) -> &[u8] { MessageTrait::payload(self) } + #[inline] fn payload_mut(&mut self) -> &mut [u8] { MessageTrait::payload_mut(self) } } impl rx::Entry for Message { + #[inline] fn remote_address(&self) -> Option { MessageTrait::remote_address(self) } + #[inline] fn ecn(&self) -> ExplicitCongestionNotification { MessageTrait::ecn(self) } + #[inline] fn payload(&self) -> &[u8] { MessageTrait::payload(self) } + #[inline] fn payload_mut(&mut self) -> &mut [u8] { MessageTrait::payload_mut(self) } @@ -307,6 +430,9 @@ mod tests { msghdr.msg_name = &mut msgname as *mut _ as *mut _; msghdr.msg_namelen = size_of::() as _; + let mut iovec = unsafe { zeroed::() }; + msghdr.msg_iov = &mut iovec; + let mut message = Message(msghdr); check!() @@ -315,7 +441,9 @@ mod tests { .for_each(|addr| { #[cfg(not(feature = "ipv6"))] let addr = addr.into(); - message.reset_remote_address(); + unsafe { + message.reset(0); + } message.set_remote_address(&addr); #[cfg(all(target_os = "macos", feature = "ipv6"))] diff --git a/quic/s2n-quic-platform/src/message/queue.rs b/quic/s2n-quic-platform/src/message/queue.rs index 3cc988ba28..06b01d4eb0 100644 --- a/quic/s2n-quic-platform/src/message/queue.rs +++ b/quic/s2n-quic-platform/src/message/queue.rs @@ -84,6 +84,14 @@ impl Queue { self.ring.mtu() } + pub fn max_gso(&self) -> usize { + self.ring.max_gso() + } + + pub fn disable_gso(&mut self) { + self.ring.disable_gso() + } + /// Returns the number of slots in the buffer pub fn capacity(&self) -> usize { self.ring.len() @@ -102,22 
+110,28 @@ impl Queue { /// Returns a slice of all of the `free` messages pub fn free_mut(&mut self) -> Free { let mtu = self.mtu(); + let max_gso = self.max_gso(); Slice { messages: self.ring.as_mut_slice(), primary: &mut self.free, secondary: &mut self.occupied, behavior: behavior::Free { mtu }, + max_gso, + gso_segment: None, } } /// Returns a slice of all of the `occupied` messages pub fn occupied_mut(&mut self) -> Occupied { let mtu = self.mtu(); + let max_gso = self.max_gso(); Slice { messages: self.ring.as_mut_slice(), primary: &mut self.occupied, secondary: &mut self.free, behavior: behavior::Occupied { mtu }, + max_gso, + gso_segment: None, } } @@ -126,11 +140,14 @@ impl Queue { /// The messages will be wiped on release. pub fn occupied_wipe_mut(&mut self) -> OccupiedWipe { let mtu = self.mtu(); + let max_gso = self.max_gso(); Slice { messages: self.ring.as_mut_slice(), primary: &mut self.occupied, secondary: &mut self.free, behavior: behavior::OccupiedWipe { mtu }, + max_gso, + gso_segment: None, } } } @@ -284,7 +301,9 @@ mod tests { .for_each(|(capacity, ops)| { use $ring; let payloads = VecBuffer::new(*capacity, MTU); - let ring = Ring::new(payloads); + // limit GSO segments as this harness assumes no GSO + let max_gso = 1; + let ring = Ring::new(payloads, max_gso); let queue = Queue::new(ring); assert_eq!(queue.mtu(), MTU); check(queue, *capacity, ops); diff --git a/quic/s2n-quic-platform/src/message/queue/behavior.rs b/quic/s2n-quic-platform/src/message/queue/behavior.rs index bb6a2a59f7..3667ddc54e 100644 --- a/quic/s2n-quic-platform/src/message/queue/behavior.rs +++ b/quic/s2n-quic-platform/src/message/queue/behavior.rs @@ -172,8 +172,7 @@ fn reset(messages: &mut [Message], mtu: usize) { for message in messages { unsafe { // Safety: the payloads should always be allocated regions of MTU - message.set_payload_len(mtu); - message.reset_remote_address(); + message.reset(mtu); } } } @@ -191,8 +190,7 @@ fn wipe(messages: &mut [Message], mtu: usize) { 
unsafe { // Safety: the payloads should always be allocated regions of MTU - message.set_payload_len(mtu); - message.reset_remote_address(); + message.reset(mtu); } } } diff --git a/quic/s2n-quic-platform/src/message/queue/slice.rs b/quic/s2n-quic-platform/src/message/queue/slice.rs index 42b0535b1b..c1f57609d6 100644 --- a/quic/s2n-quic-platform/src/message/queue/slice.rs +++ b/quic/s2n-quic-platform/src/message/queue/slice.rs @@ -8,7 +8,7 @@ use s2n_quic_core::io::{rx, tx}; /// A view of the currently enqueued messages for a given segment #[derive(Debug)] -pub struct Slice<'a, Message, Behavior> { +pub struct Slice<'a, Message: message::Message, Behavior> { /// A slice of all of the messages in the buffer pub(crate) messages: &'a mut [Message], /// Reference to the primary segment @@ -17,17 +17,25 @@ pub struct Slice<'a, Message, Behavior> { pub(crate) secondary: &'a mut Segment, /// Reset the messages after use pub(crate) behavior: Behavior, + /// The maximum allowed number of GSO segments + pub(crate) max_gso: usize, + /// The index to the previously pushed segment + pub(crate) gso_segment: Option, } -impl<'a, Message: message::Message, B: Behavior> Slice<'a, Message, B> { - pub fn into_slice_mut(self) -> &'a mut [Message] { - &mut self.messages[self.primary.range()] - } +#[derive(Debug, Default)] +pub struct GsoSegment { + index: usize, + count: usize, + size: usize, +} +impl<'a, Message: message::Message, B: Behavior> Slice<'a, Message, B> { /// Finishes the borrow of the `Slice` with a specified `count` /// /// Calling this method will move `count` messages from one segment /// to the other; e.g. `ready` to `pending`. 
+ #[inline] pub fn finish(mut self, count: usize) { self.advance(count); } @@ -39,6 +47,8 @@ impl<'a, Message: message::Message, B: Behavior> Slice<'a, Message, B> { "cannot finish more messages than available" ); + self.flush_gso(); + let (start, end, overflow, capacity) = self.compute_behavior_arguments(count); let (primary, secondary) = self.messages.split_at_mut(capacity); @@ -49,7 +59,9 @@ impl<'a, Message: message::Message, B: Behavior> Slice<'a, Message, B> { } /// Preserves the messages in the current segment - pub fn cancel(self, count: usize) { + pub fn cancel(mut self, count: usize) { + self.flush_gso(); + let (start, end, overflow, capacity) = self.compute_behavior_arguments(count); let (primary, secondary) = self.messages.split_at_mut(capacity); @@ -72,15 +84,50 @@ impl<'a, Message: message::Message, B: Behavior> Slice<'a, Message, B> { } } -impl<'a, Message, R> Deref for Slice<'a, Message, R> { +impl<'a, Message: message::Message, B> Slice<'a, Message, B> { + fn flush_gso(&mut self) { + if let Some(gso) = self.gso_segment.take() { + if gso.count > 1 { + let mid = self.messages.len() / 2; + let (primary, secondary) = self.messages.split_at_mut(mid); + let index = gso.index; + + // try to wrap around the midpoint + let (primary, secondary) = if let Some(index) = index.checked_sub(mid) { + let primary = &mut primary[index]; + let secondary = &mut secondary[index]; + (secondary, primary) + } else { + let primary = &mut primary[index]; + let secondary = &mut secondary[index]; + (primary, secondary) + }; + + primary.set_segment_size(gso.size); + secondary.replicate_fields_from(primary); + } + } + } +} + +impl<'a, Message: message::Message, R> Drop for Slice<'a, Message, R> { + #[inline] + fn drop(&mut self) { + self.flush_gso() + } +} + +impl<'a, Message: message::Message, R> Deref for Slice<'a, Message, R> { type Target = [Message]; + #[inline] fn deref(&self) -> &Self::Target { &self.messages[self.primary.range()] } } -impl<'a, Message, R> DerefMut for 
Slice<'a, Message, R> { +impl<'a, Message: message::Message, R> DerefMut for Slice<'a, Message, R> { + #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.messages[self.primary.range()] } @@ -89,15 +136,18 @@ impl<'a, Message, R> DerefMut for Slice<'a, Message, R> { impl<'a, Message: rx::Entry + message::Message, B: Behavior> rx::Queue for Slice<'a, Message, B> { type Entry = Message; + #[inline] fn as_slice_mut(&mut self) -> &mut [Message] { let range = self.primary.range(); &mut self.messages[range] } + #[inline] fn len(&self) -> usize { self.primary.len } + #[inline] fn finish(&mut self, count: usize) { self.advance(count) } @@ -106,26 +156,116 @@ impl<'a, Message: rx::Entry + message::Message, B: Behavior> rx::Queue for Slice impl<'a, Message: tx::Entry + message::Message, B: Behavior> tx::Queue for Slice<'a, Message, B> { type Entry = Message; - fn push(&mut self, message: M) -> Result { + fn push(&mut self, mut message: M) -> Result { + if Message::SUPPORTS_GSO { + if let Some(gso) = self.gso_segment.as_mut() { + let max_segments = self.max_gso; + debug_assert!( + max_segments > 1, + "gso_segment should only be set when max_gso > 1" + ); + + let prev_message = &mut self.messages[gso.index]; + + if message.can_gso() + && prev_message.remote_address() == Some(message.remote_address()) + { + debug_assert!( + gso.count < max_segments, + "{} cannot exceed {}", + gso.count, + max_segments + ); + + let payload_len = prev_message.payload_len(); + + unsafe { + // Safety: all payloads should have enough capacity to extend max_segments * + // gso.size + prev_message.set_payload_len(payload_len + gso.size); + } + + // allow the message to write up to `gso.size` bytes + let buffer = &mut message::Message::payload_mut(prev_message)[payload_len..]; + + match message.write_payload(buffer) { + 0 => { + unsafe { + // revert the len to what it was before + prev_message.set_payload_len(payload_len); + } + return Err(tx::Error::EmptyPayload); + } + size => { + 
unsafe { + // set the len to the actual amount written to the payload + prev_message.set_payload_len(payload_len + size); + } + // increment the number of segments + gso.count += 1; + + debug_assert!( + gso.count <= max_segments, + "{} cannot exceed {}", + gso.count, + max_segments + ); + + let index = gso.index; + + // the last segment can be smaller but we can't write any more if it is + let size_mismatch = gso.size != size; + + // we're bounded by the max_segments amount + let at_segment_limit = gso.count >= max_segments; + + // we also can't write more data than u16::MAX + let at_payload_limit = gso.size * (gso.count + 1) > u16::MAX as usize; + + if size_mismatch || at_segment_limit || at_payload_limit { + self.flush_gso(); + } + + return Ok(index); + } + } + } + + // move on to the next index + self.flush_gso(); + } + } + let index = self .primary - .index(&self.secondary) + .index(self.secondary) .ok_or(tx::Error::AtCapacity)?; - self.messages[index].set(message)?; + let size = self.messages[index].set(message)?; self.advance(1); + if Message::SUPPORTS_GSO && self.max_gso > 1 { + self.gso_segment = Some(GsoSegment { + index, + count: 1, + size, + }); + } + Ok(index) } + #[inline] fn as_slice_mut(&mut self) -> &mut [Message] { &mut self.messages[self.secondary.range()] } + #[inline] fn capacity(&self) -> usize { self.primary.len } + #[inline] fn len(&self) -> usize { self.secondary.len } diff --git a/quic/s2n-quic-platform/src/message/simple.rs b/quic/s2n-quic-platform/src/message/simple.rs index 86418ca91d..14993a6846 100644 --- a/quic/s2n-quic-platform/src/message/simple.rs +++ b/quic/s2n-quic-platform/src/message/simple.rs @@ -3,6 +3,7 @@ use crate::message::Message as MessageTrait; use alloc::vec::Vec; +use core::pin::Pin; use s2n_quic_core::{ inet::{ExplicitCongestionNotification, SocketAddress}, io::{rx, tx}, @@ -19,6 +20,8 @@ pub struct Message { } impl MessageTrait for Message { + const SUPPORTS_GSO: bool = false; + fn ecn(&self) -> 
ExplicitCongestionNotification { ExplicitCongestionNotification::default() } @@ -41,10 +44,6 @@ impl MessageTrait for Message { self.address = remote_address; } - fn reset_remote_address(&mut self) { - self.address = Default::default(); - } - fn payload_len(&self) -> usize { self.payload_len as usize } @@ -53,6 +52,15 @@ impl MessageTrait for Message { self.payload_len = len; } + fn set_segment_size(&mut self, _size: usize) { + panic!("segments are not supported in simple messages"); + } + + unsafe fn reset(&mut self, mtu: usize) { + self.address = Default::default(); + self.set_payload_len(mtu) + } + fn payload_ptr(&self) -> *const u8 { self.payload_ptr as *const _ } @@ -73,7 +81,9 @@ pub struct Ring { // this field holds references to allocated payloads, but is never read directly #[allow(dead_code)] - payloads: Payloads, + payloads: Pin, + + mtu: usize, } /// Even though `Ring` contains raw pointers, it owns all of the data @@ -82,20 +92,27 @@ unsafe impl Send for Ring {} impl Default for Ring { fn default() -> Self { - Self::new(Payloads::default()) + Self::new(Payloads::default(), 1) } } impl Ring { - pub fn new(mut payloads: Payloads) -> Self { - let capacity = payloads.len(); + pub fn new(payloads: Payloads, _max_gso: usize) -> Self { let mtu = payloads.mtu(); + let capacity = payloads.len() / mtu; + + let mut payloads = Pin::new(payloads); // double message capacity to enable contiguous access let mut messages = Vec::with_capacity(capacity * 2); - for index in 0..capacity { - let payload_ptr = payloads[index].as_mut_ptr() as _; + let mut buf = &mut payloads.as_mut()[..]; + + for _ in 0..capacity { + let (payload, remaining) = buf.split_at_mut(mtu); + buf = remaining; + + let payload_ptr = payload.as_mut_ptr() as _; messages.push(Message { payload_ptr, payload_len: mtu, @@ -107,25 +124,43 @@ impl Ring { messages.push(messages[index]); } - Self { messages, payloads } + Self { + messages, + payloads, + mtu, + } } } impl super::Ring for Ring { type Message = 
Message; + #[inline] fn len(&self) -> usize { - self.payloads.len() + self.messages.len() / 2 } + #[inline] fn mtu(&self) -> usize { - self.payloads.mtu() + self.mtu + } + + #[inline] + fn max_gso(&self) -> usize { + 1 + } + + #[inline] + fn disable_gso(&mut self) { + panic!("GSO is not supported by simple messages"); } + #[inline] fn as_slice(&self) -> &[Self::Message] { &self.messages[..] } + #[inline] fn as_mut_slice(&mut self) -> &mut [Self::Message] { &mut self.messages[..] } @@ -152,28 +187,34 @@ impl tx::Entry for Message { Ok(len) } + #[inline] fn payload(&self) -> &[u8] { MessageTrait::payload(self) } + #[inline] fn payload_mut(&mut self) -> &mut [u8] { MessageTrait::payload_mut(self) } } impl rx::Entry for Message { + #[inline] fn remote_address(&self) -> Option { MessageTrait::remote_address(self) } + #[inline] fn ecn(&self) -> ExplicitCongestionNotification { MessageTrait::ecn(self) } + #[inline] fn payload(&self) -> &[u8] { MessageTrait::payload(self) } + #[inline] fn payload_mut(&mut self) -> &mut [u8] { MessageTrait::payload_mut(self) } diff --git a/quic/s2n-quic-platform/src/socket/mmsg.rs b/quic/s2n-quic-platform/src/socket/mmsg.rs index 7b00c9fa29..964112928c 100644 --- a/quic/s2n-quic-platform/src/socket/mmsg.rs +++ b/quic/s2n-quic-platform/src/socket/mmsg.rs @@ -14,8 +14,8 @@ use std::{io, os::unix::io::AsRawFd}; pub struct Queue(queue::Queue>); impl Queue { - pub fn new(buffer: B) -> Self { - let queue = queue::Queue::new(Ring::new(buffer)); + pub fn new(buffer: B, max_gso: usize) -> Self { + let queue = queue::Queue::new(Ring::new(buffer, max_gso)); Self(queue) } @@ -79,6 +79,30 @@ impl Queue { entries.finish(count); Ok(count) } + Err(err) if err.kind() == io::ErrorKind::Interrupted => { + entries.cancel(0); + Ok(0) + } + Err(err) if err.kind() == io::ErrorKind::PermissionDenied => { + // just drop the packets on permission errors - most likely a firewall issue + let count = vlen as usize; + entries.finish(count); + Ok(count) + } + // check to 
see if we need to disable GSO + Err(err) if unsafe { *libc::__errno_location() } == libc::EIO => { + let count = vlen as usize; + entries.finish(count); + + if self.0.max_gso() > 1 { + self.0.disable_gso(); + // unfortunately we've already assembled GSO packets so just drop them + // and wait for a retransmission + Ok(count) + } else { + Err(err) + } + } Err(err) => { entries.cancel(0); Err(err) @@ -143,6 +167,10 @@ impl Queue { entries.finish(count); Ok(count) } + Err(err) if err.kind() == io::ErrorKind::Interrupted => { + entries.cancel(0); + Ok(0) + } Err(err) => { entries.cancel(0); Err(err) diff --git a/quic/s2n-quic-platform/src/socket/msg.rs b/quic/s2n-quic-platform/src/socket/msg.rs index d527f69d9c..2405dae152 100644 --- a/quic/s2n-quic-platform/src/socket/msg.rs +++ b/quic/s2n-quic-platform/src/socket/msg.rs @@ -14,8 +14,8 @@ use std::{io, os::unix::io::AsRawFd}; pub struct Queue(queue::Queue>); impl Queue { - pub fn new(buffer: B) -> Self { - let queue = queue::Queue::new(Ring::new(buffer)); + pub fn new(buffer: B, max_segments: usize) -> Self { + let queue = queue::Queue::new(Ring::new(buffer, max_segments)); Self(queue) } @@ -59,14 +59,34 @@ impl Queue { Ok(_len) => { count += 1; } - Err(err) => { - if count > 0 && err.kind() == io::ErrorKind::WouldBlock { - break; + Err(err) if count > 0 && err.kind() == io::ErrorKind::WouldBlock => { + break; + } + Err(err) if err.kind() == io::ErrorKind::Interrupted => { + break; + } + Err(err) if err.kind() == io::ErrorKind::PermissionDenied => { + // just drop the packets on permission errors - most likely a firewall issue + count += 1; + } + // check to see if we need to disable GSO + Err(err) if unsafe { *libc::__errno_location() } == libc::EIO => { + // unfortunately we've already assembled GSO packets so just drop them + // and wait for a retransmission + let len = entries.len(); + entries.finish(len); + + if self.0.max_gso() > 1 { + self.0.disable_gso(); + return Ok(count); } else { - entries.finish(count); 
return Err(err); } } + Err(err) => { + entries.finish(count); + return Err(err); + } } } @@ -127,6 +147,9 @@ impl Queue { count += 1; } + Err(err) if err.kind() == io::ErrorKind::Interrupted => { + break; + } Err(err) => { if count > 0 && err.kind() == io::ErrorKind::WouldBlock { break; diff --git a/quic/s2n-quic-platform/src/socket/std.rs b/quic/s2n-quic-platform/src/socket/std.rs index c4181a3083..95b1c58f5a 100644 --- a/quic/s2n-quic-platform/src/socket/std.rs +++ b/quic/s2n-quic-platform/src/socket/std.rs @@ -40,6 +40,8 @@ impl Socket for std::net::UdpSocket { pub trait Error { fn would_block(&self) -> bool; + fn was_interrupted(&self) -> bool; + fn permission_denied(&self) -> bool; } #[cfg(feature = "std")] @@ -47,6 +49,14 @@ impl Error for std::io::Error { fn would_block(&self) -> bool { self.kind() == std::io::ErrorKind::WouldBlock } + + fn was_interrupted(&self) -> bool { + self.kind() == std::io::ErrorKind::Interrupted + } + + fn permission_denied(&self) -> bool { + self.kind() == std::io::ErrorKind::PermissionDenied + } } #[derive(Debug, Default)] @@ -54,7 +64,7 @@ pub struct Queue(queue::Queue>); impl Queue { pub fn new(buffer: B) -> Self { - let queue = queue::Queue::new(Ring::new(buffer)); + let queue = queue::Queue::new(Ring::new(buffer, 1)); Self(queue) } @@ -77,13 +87,15 @@ impl Queue { Ok(_) => { count += 1; } + Err(err) if count > 0 && err.would_block() => { + break; + } + Err(err) if err.was_interrupted() || err.permission_denied() => { + break; + } Err(err) => { - if count > 0 && err.would_block() { - break; - } else { - entries.finish(count); - return Err(err); - } + entries.finish(count); + return Err(err); } } } @@ -114,13 +126,15 @@ impl Queue { count += 1; } Ok((_payload_len, None)) => {} + Err(err) if count > 0 && err.would_block() => { + break; + } + Err(err) if err.was_interrupted() => { + break; + } Err(err) => { - if count > 0 && err.would_block() { - break; - } else { - entries.finish(count); - return Err(err); - } + 
entries.finish(count); + return Err(err); } } } diff --git a/quic/s2n-quic-transport/src/connection/close_sender.rs b/quic/s2n-quic-transport/src/connection/close_sender.rs index a7ce908959..ad0dbc21d1 100644 --- a/quic/s2n-quic-transport/src/connection/close_sender.rs +++ b/quic/s2n-quic-transport/src/connection/close_sender.rs @@ -115,22 +115,32 @@ pub struct Transmission<'a, CC: CongestionController> { } impl<'a, CC: CongestionController> tx::Message for Transmission<'a, CC> { + #[inline] fn remote_address(&mut self) -> SocketAddress { self.path.peer_socket_address } + #[inline] fn ecn(&mut self) -> ExplicitCongestionNotification { ExplicitCongestionNotification::default() } + #[inline] fn ipv6_flow_label(&mut self) -> u32 { 0 } + #[inline] + fn can_gso(&self) -> bool { + true + } + + #[inline] fn delay(&mut self) -> Duration { Duration::default() } + #[inline] fn write_payload(&mut self, buffer: &mut [u8]) -> usize { let len = self.packet.len(); diff --git a/quic/s2n-quic-transport/src/connection/transmission.rs b/quic/s2n-quic-transport/src/connection/transmission.rs index ae1a00f6cb..742d7e0650 100644 --- a/quic/s2n-quic-transport/src/connection/transmission.rs +++ b/quic/s2n-quic-transport/src/connection/transmission.rs @@ -57,24 +57,33 @@ pub struct ConnectionTransmission<'a, 'sub, Config: endpoint::Config> { } impl<'a, 'sub, Config: endpoint::Config> tx::Message for ConnectionTransmission<'a, 'sub, Config> { + #[inline] fn remote_address(&mut self) -> SocketAddress { self.context.path().peer_socket_address } + #[inline] fn ecn(&mut self) -> ExplicitCongestionNotification { self.context.ecn } + #[inline] fn delay(&mut self) -> Duration { // TODO return delay from pacer Default::default() } + #[inline] fn ipv6_flow_label(&mut self) -> u32 { // TODO compute flow label from connection id 0 } + #[inline] + fn can_gso(&self) -> bool { + !self.context.transmission_mode.is_mtu_probing() + } + fn write_payload(&mut self, buffer: &mut [u8]) -> usize { let 
shared_state = &mut self.shared_state; let space_manager = &mut shared_state.space_manager; diff --git a/quic/s2n-quic-transport/src/endpoint/retry.rs b/quic/s2n-quic-transport/src/endpoint/retry.rs index 6cca9f1664..49505cd762 100644 --- a/quic/s2n-quic-transport/src/endpoint/retry.rs +++ b/quic/s2n-quic-transport/src/endpoint/retry.rs @@ -112,22 +112,32 @@ impl AsRef<[u8]> for Transmission { } impl tx::Message for &Transmission { + #[inline] fn remote_address(&mut self) -> SocketAddress { self.remote_address } + #[inline] fn ecn(&mut self) -> ExplicitCongestionNotification { Default::default() } + #[inline] fn delay(&mut self) -> time::Duration { Default::default() } + #[inline] fn ipv6_flow_label(&mut self) -> u32 { 0 } + #[inline] + fn can_gso(&self) -> bool { + true + } + + #[inline] fn write_payload(&mut self, buffer: &mut [u8]) -> usize { let packet = self.as_ref(); buffer[..packet.len()].copy_from_slice(packet); diff --git a/quic/s2n-quic-transport/src/endpoint/stateless_reset.rs b/quic/s2n-quic-transport/src/endpoint/stateless_reset.rs index 517767f1e8..fe4eabb9c6 100644 --- a/quic/s2n-quic-transport/src/endpoint/stateless_reset.rs +++ b/quic/s2n-quic-transport/src/endpoint/stateless_reset.rs @@ -107,22 +107,32 @@ impl AsRef<[u8]> for Transmission { } impl tx::Message for &Transmission { + #[inline] fn remote_address(&mut self) -> SocketAddress { self.remote_address } + #[inline] fn ecn(&mut self) -> ExplicitCongestionNotification { Default::default() } + #[inline] fn delay(&mut self) -> time::Duration { Default::default() } + #[inline] fn ipv6_flow_label(&mut self) -> u32 { 0 } + #[inline] + fn can_gso(&self) -> bool { + true + } + + #[inline] fn write_payload(&mut self, buffer: &mut [u8]) -> usize { let packet = self.as_ref(); buffer[..packet.len()].copy_from_slice(packet); diff --git a/quic/s2n-quic-transport/src/endpoint/version.rs b/quic/s2n-quic-transport/src/endpoint/version.rs index 27da917096..0a45baa2a6 100644 --- 
a/quic/s2n-quic-transport/src/endpoint/version.rs +++ b/quic/s2n-quic-transport/src/endpoint/version.rs @@ -225,22 +225,32 @@ impl AsRef<[u8]> for Transmission { } impl tx::Message for &Transmission { + #[inline] fn remote_address(&mut self) -> SocketAddress { self.remote_address } + #[inline] fn ecn(&mut self) -> ExplicitCongestionNotification { Default::default() } + #[inline] fn delay(&mut self) -> Duration { Default::default() } + #[inline] fn ipv6_flow_label(&mut self) -> u32 { 0 } + #[inline] + fn can_gso(&self) -> bool { + true + } + + #[inline] fn write_payload(&mut self, buffer: &mut [u8]) -> usize { let packet = self.as_ref(); buffer[..packet.len()].copy_from_slice(packet); diff --git a/quic/s2n-quic-transport/src/sync/data_sender.rs b/quic/s2n-quic-transport/src/sync/data_sender.rs index 1bc0d21edb..174152acda 100644 --- a/quic/s2n-quic-transport/src/sync/data_sender.rs +++ b/quic/s2n-quic-transport/src/sync/data_sender.rs @@ -356,6 +356,7 @@ impl } /// Queries the component for any outgoing frames that need to get sent + #[inline] pub fn on_transmit( &mut self, writer_context: Writer::Context, diff --git a/quic/s2n-quic-transport/src/sync/data_sender/writer.rs b/quic/s2n-quic-transport/src/sync/data_sender/writer.rs index 1d38bfc1df..e0788de2a7 100644 --- a/quic/s2n-quic-transport/src/sync/data_sender/writer.rs +++ b/quic/s2n-quic-transport/src/sync/data_sender/writer.rs @@ -11,6 +11,7 @@ pub struct Stream; impl FrameWriter for Stream { type Context = VarInt; + #[inline] fn write_chunk( &self, offset: VarInt, @@ -48,6 +49,7 @@ impl FrameWriter for Stream { Ok(()) } + #[inline] fn write_fin( &self, offset: VarInt, diff --git a/quic/s2n-quic-transport/src/transmission/context.rs b/quic/s2n-quic-transport/src/transmission/context.rs index 5b165552e3..5aba4a5238 100644 --- a/quic/s2n-quic-transport/src/transmission/context.rs +++ b/quic/s2n-quic-transport/src/transmission/context.rs @@ -65,22 +65,27 @@ impl<'a, 'b, Config: endpoint::Config> Context<'a, 'b, 
Config> { } impl<'a, 'b, Config: endpoint::Config> WriteContext for Context<'a, 'b, Config> { + #[inline] fn current_time(&self) -> Timestamp { self.timestamp } + #[inline] fn transmission_constraint(&self) -> transmission::Constraint { self.transmission_constraint } + #[inline] fn transmission_mode(&self) -> Mode { self.transmission_mode } + #[inline] fn remaining_capacity(&self) -> usize { self.buffer.remaining_capacity() } + #[inline] fn write_frame< Frame: EncoderValue + AckElicitable + CongestionControlled + PathValidationProbing, >( @@ -91,6 +96,7 @@ impl<'a, 'b, Config: endpoint::Config> WriteContext for Context<'a, 'b, Config> self.write_frame_forced(frame) } + #[inline] fn write_fitted_frame< Frame: EncoderValue + AckElicitable + CongestionControlled + PathValidationProbing, >( @@ -107,6 +113,7 @@ impl<'a, 'b, Config: endpoint::Config> WriteContext for Context<'a, 'b, Config> self.packet_number } + #[inline] fn write_frame_forced( &mut self, frame: &Frame, @@ -122,22 +129,27 @@ impl<'a, 'b, Config: endpoint::Config> WriteContext for Context<'a, 'b, Config> Some(self.packet_number) } + #[inline] fn ack_elicitation(&self) -> AckElicitation { self.outcome.ack_elicitation } + #[inline] fn packet_number(&self) -> PacketNumber { self.packet_number } + #[inline] fn local_endpoint_type(&self) -> endpoint::Type { Config::ENDPOINT_TYPE } + #[inline] fn header_len(&self) -> usize { self.header_len } + #[inline] fn tag_len(&self) -> usize { self.tag_len } @@ -156,10 +168,12 @@ impl<'a, C: WriteContext> RetransmissionContext<'a, C> { } impl<'a, C: WriteContext> WriteContext for RetransmissionContext<'a, C> { + #[inline] fn current_time(&self) -> Timestamp { self.context.current_time() } + #[inline] fn transmission_constraint(&self) -> transmission::Constraint { debug_assert!( self.context.transmission_constraint().can_retransmit(), @@ -169,14 +183,17 @@ impl<'a, C: WriteContext> WriteContext for RetransmissionContext<'a, C> { 
transmission::Constraint::RetransmissionOnly } + #[inline] fn transmission_mode(&self) -> Mode { self.context.transmission_mode() } + #[inline] fn remaining_capacity(&self) -> usize { self.context.remaining_capacity() } + #[inline] fn write_frame< Frame: EncoderValue + AckElicitable + CongestionControlled + PathValidationProbing, >( @@ -186,6 +203,7 @@ impl<'a, C: WriteContext> WriteContext for RetransmissionContext<'a, C> { self.context.write_frame(frame) } + #[inline] fn write_fitted_frame< Frame: EncoderValue + AckElicitable + CongestionControlled + PathValidationProbing, >( @@ -195,6 +213,7 @@ impl<'a, C: WriteContext> WriteContext for RetransmissionContext<'a, C> { self.context.write_fitted_frame(frame) } + #[inline] fn write_frame_forced( &mut self, frame: &Frame, @@ -202,22 +221,27 @@ impl<'a, C: WriteContext> WriteContext for RetransmissionContext<'a, C> { self.context.write_frame_forced(frame) } + #[inline] fn ack_elicitation(&self) -> AckElicitation { self.context.ack_elicitation() } + #[inline] fn packet_number(&self) -> PacketNumber { self.context.packet_number() } + #[inline] fn local_endpoint_type(&self) -> endpoint::Type { self.context.local_endpoint_type() } + #[inline] fn header_len(&self) -> usize { self.context.header_len() } + #[inline] fn tag_len(&self) -> usize { self.context.tag_len() }